From 3a42765cab1fd5ba3d78d17bfaac8741518e88c9 Mon Sep 17 00:00:00 2001
From: drbh
Date: Fri, 2 Feb 2024 21:50:51 +0000
Subject: [PATCH] feat: use cache when decoding

---
 .../models/custom_modeling/mamba_modeling.py   | 129 ++++++------------
 server/text_generation_server/models/mamba.py |   4 +-
 2 files changed, 47 insertions(+), 86 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/mamba_modeling.py b/server/text_generation_server/models/custom_modeling/mamba_modeling.py
index 5a6a1d2d..3180d686 100644
--- a/server/text_generation_server/models/custom_modeling/mamba_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/mamba_modeling.py
@@ -1,15 +1,15 @@
 import torch
 import torch.distributed
+from mamba_ssm.ops.triton.selective_state_update import selective_state_update
 from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
 from mamba_ssm.utils.generation import InferenceParams
 from torch import nn
-from typing import Optional, List, Tuple, Any, Dict
+from typing import Optional, Tuple, Any
 from transformers.configuration_utils import PretrainedConfig
 import torch.nn.functional as F
 from text_generation_server.utils.layers import (
-    TensorParallelColumnLinear,
     TensorParallelEmbedding,
     FastRMSNorm,
     FastLinear,
@@ -72,38 +72,22 @@ class MambaBlock(nn.Module):
 
     # inference_params
     def forward(self, hidden_states: torch.Tensor, inference_params=None):
-        seqlen = hidden_states.shape[1]
+        _, seqlen, _ = hidden_states.shape
+        conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
 
-        # TODO: use the inference_params to get the previous states when decoding
-        conv_state, ssm_state = None, None
-        if inference_params is not None:
-            if hidden_states.shape[1] == 1:
-                print("Decoding")
-                conv_state, ssm_state = self._get_states_from_cache(inference_params, hidden_states.shape[0])
-            if inference_params.seqlen_offset > 0:
-                # The states are updated inplace
-                out, _conv_state, _ssm_state = self.step(hidden_states, conv_state, ssm_state)
-                # import ipdb; ipdb.set_trace()
-                return out, _conv_state, _ssm_state
+        if inference_params.seqlen_offset > 0:
+            out, conv_state, ssm_state = self.step(hidden_states, conv_state, ssm_state)
+            return out, conv_state, ssm_state
 
         projected_states = self.in_proj(hidden_states).transpose(1,2)
         x, z = projected_states.chunk(2, dim=1)
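+        # NOTE: F.pad zero-pads on the left when seqlen < d_conv and truncates
+        # from the left otherwise, so conv_state holds exactly the last d_conv
+        # columns of x: the window that `step` consumes while decoding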
+        conv_state = F.pad(x, (self.d_conv - seqlen, 0))
+        x = causal_conv1d_fn(
+            x=x,
+            weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)),
+            bias=self.conv1d.bias,
+            activation=self.activation,
+        )
 
         # We're careful here about the layout, to avoid extra transposes.
         # We want dt to have d as the slowest moving dimension
@@ -114,8 +98,7 @@ class MambaBlock(nn.Module):
         dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
         B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
         C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
-        assert self.activation in ["silu", "swish"]
-        y, last_ssm_state = selective_scan_fn(
+        y, last_state = selective_scan_fn(
             x,
             dt,
             self.negA,
@@ -125,57 +108,32 @@ class MambaBlock(nn.Module):
             z=z,
             delta_bias=self.dt_proj.bias.float(),
             delta_softplus=True,
-            return_last_state=True,  # ssm_state is not None,
+            return_last_state=True,
         )
         y = rearrange(y, "b d l -> b l d")
         attn_outputs = self.out_proj(y)
+        return attn_outputs, conv_state, last_state
 
-        return attn_outputs, conv_state, last_ssm_state
-
-    def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
-        assert self.layer_idx is not None
-        conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
-        return conv_state, ssm_state
-
     def step(self, hidden_states, conv_state, ssm_state):
-        dtype = hidden_states.dtype
-        assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
+        # only supports decoding 1 token at a time
         xz = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
         x, z = xz.chunk(2, dim=-1)  # (B D)
-
-        # Conv step
-        if causal_conv1d_update is None:
-            conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)
-            conv_state[:, :, -1] = x
-            x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)  # (B D)
-            if self.conv1d.bias is not None:
-                x = x + self.conv1d.bias
-            x = self.act(x).to(dtype=dtype)
-        else:
-            x = causal_conv1d_update(
-                x,
-                conv_state,
-                rearrange(self.conv1d.weight, "d 1 w -> d w"),
-                self.conv1d.bias,
-                self.activation,
-            )
-
+        x = causal_conv1d_update(
+            x,
+            conv_state,
+            self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)),
+            self.conv1d.bias,
+            self.activation,
+        )
         x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
         dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
         # Don't add dt_bias here
         dt = F.linear(dt, self.dt_proj.weight)  # (B d_inner)
-
-        # SSM step
-        # Discretize A and B
-        dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
-        dA = torch.exp(torch.einsum("bd,dn->bdn", dt, self.negA))
-        dB = torch.einsum("bd,bn->bdn", dt, B)
-        ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
-        y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
-        y = y + self.D.to(dtype) * x
-        y = y * self.act(z)  # (B D)
-
+        y = selective_state_update(
+            ssm_state, x, dt, self.negA, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
+        )
         out = self.out_proj(y)
+        # conv_state and ssm_state are updated in place, but we return them to
+        # make the control flow more explicit
         return out.unsqueeze(1), conv_state, ssm_state
 
 class ResidualBlock(nn.Module):
@@ -187,14 +145,14 @@ class ResidualBlock(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
         inference_params: Optional[Any] = None,
     ):
-        residual = hidden_states
-        shape = hidden_states.shape
-        hidden_states, _ = self.layer_norm(hidden_states.view(-1, shape[-1]))
-        hidden_states, _conv_state, last_ssm_state = self.mamba_block(hidden_states.view(*shape), inference_params)
-        hidden_states = residual + hidden_states
-        return hidden_states, _conv_state, last_ssm_state
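+        # fused add-then-norm: accumulate the residual stream first, feed the
+        # normalized sum to the mamba block, and pass the raw sum forward as
+        # `residual` so the un-normalized stream keeps flowing between blocks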
+        residual = (hidden_states + residual) if residual is not None else hidden_states
+        shape = residual.shape
+        hidden_states, _ = self.layer_norm(residual.view(-1, shape[-1]))
+        hidden_states, conv_state, last_ssm_state = self.mamba_block(hidden_states.view(*shape), inference_params)
+        return hidden_states, residual, conv_state, last_ssm_state
 
 class MambaModel(nn.Module):
     def __init__(self, config, weights):
@@ -208,14 +166,17 @@ class MambaModel(nn.Module):
         self.lm_head = FastLinear.load(config, f"{prefix}.embedding", weights, bias=False)
         self.config = config
 
-    def forward(self, input_ids: torch.Tensor, inference_params=None):
+    def forward(self, input_ids: torch.Tensor, inference_params=None, residual=None) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]:
         hidden_states = self.embed_tokens(input_ids)
-        print("Input ids: ", input_ids)
 
         for block in self.blocks:
-            hidden_states, _conv_state, last_ssm_state = block(hidden_states, inference_params)
-            # inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = (_conv_state, last_ssm_state)
+            hidden_states, residual, conv_state, ssm_state = block(hidden_states, residual, inference_params)
+            inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = (conv_state, ssm_state)
 
-        shape = hidden_states.shape
-        final_hidden_states, _ = self.norm_f(hidden_states.view(-1, shape[-1]))
-        return self.lm_head(final_hidden_states.view(*shape)), input_ids, inference_params
+        hidden_states = hidden_states + residual if residual is not None else hidden_states
+        hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1)))
+        hidden_states = hidden_states.view(residual.shape)
+        logits = self.lm_head(hidden_states)
+
+        # update the offset for the next inference using these params
+        inference_params.seqlen_offset += input_ids.size(1)
+        return logits, input_ids, inference_params
\ No newline at end of file
diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py
index efa233e6..f7d950e7 100644
--- a/server/text_generation_server/models/mamba.py
+++ b/server/text_generation_server/models/mamba.py
@@ -406,7 +406,7 @@ class Mamba(Model):
         dtype = input_ids.dtype
 
         # Inference params
-        seqlen_og = max_seqlen
+        seqlen_og = 0
         inf_cache = {}
         lengths_per_sample = torch.ones(batch_size, dtype=torch.int32, device=input_ids.device) * max_seqlen
@@ -592,7 +592,7 @@ class Mamba(Model):
             generations.append(generation)
 
             # Update values
-            batch.input_ids[i, 0] = 0  # next_token_id
+            batch.input_ids[i, 0] = next_token_id
             batch.all_input_ids[i] = all_input_ids
             batch.input_lengths[i] = new_input_length
             batch.prefix_offsets[i] = prefix_offset
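
---
Note (reviewer annotation, not part of the patch): a minimal sketch of how the
cached decode path is meant to be driven, assuming `model` is a `MambaModel`
and `inference_params` was allocated as in mamba.py above, with `seqlen_offset`
starting at 0 and `key_value_memory_dict` pre-allocated per layer;
`prompt_ids` and `max_new_tokens` are illustrative names, and greedy argmax
sampling is only for illustration:

    # Prefill: seqlen_offset == 0, so every MambaBlock runs the
    # selective_scan_fn path and seeds its (conv_state, ssm_state) cache.
    logits, _, inference_params = model(prompt_ids, inference_params)
    next_ids = logits[:, -1].argmax(dim=-1, keepdim=True)

    # Decode: seqlen_offset > 0 now, so each block takes the single-token
    # `step` path against the cached states instead of re-scanning the prompt.
    for _ in range(max_new_tokens):
        logits, _, inference_params = model(next_ids, inference_params)
        next_ids = logits[:, -1].argmax(dim=-1, keepdim=True)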