feat: use cache when decoding

drbh 2024-02-02 21:50:51 +00:00
parent 2d674624a3
commit 3a42765cab
2 changed files with 47 additions and 86 deletions


@@ -1,15 +1,15 @@
 import torch
 import torch.distributed
+from mamba_ssm.ops.triton.selective_state_update import selective_state_update
 from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
 from mamba_ssm.utils.generation import InferenceParams
 from torch import nn
-from typing import Optional, List, Tuple, Any, Dict
+from typing import Optional, Tuple, Any
 from transformers.configuration_utils import PretrainedConfig
 import torch.nn.functional as F
 from text_generation_server.utils.layers import (
-    TensorParallelColumnLinear,
     TensorParallelEmbedding,
     FastRMSNorm,
     FastLinear,
@@ -72,35 +72,19 @@ class MambaBlock(nn.Module):
     # inference_params
     def forward(self, hidden_states: torch.Tensor, inference_params=None):
-        seqlen = hidden_states.shape[1]
-
-        # TODO: use the inference_params to get the previous states when decoding
-        conv_state, ssm_state = None, None
-        if inference_params is not None:
-            if hidden_states.shape[1] == 1:
-                print("Decoding")
-                conv_state, ssm_state = self._get_states_from_cache(inference_params, hidden_states.shape[0])
-            if inference_params.seqlen_offset > 0:
-                # The states are updated inplace
-                out, _conv_state, _ssm_state = self.step(hidden_states, conv_state, ssm_state)
-                # import ipdb; ipdb.set_trace()
-                return out, _conv_state, _ssm_state
+        _, seqlen, _ = hidden_states.shape
+        conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
+
+        if inference_params.seqlen_offset > 0:
+            out, conv_state, ssm_state = self.step(hidden_states, conv_state, ssm_state)
+            return out, conv_state, ssm_state
 
         projected_states = self.in_proj(hidden_states).transpose(1,2)
         x, z = projected_states.chunk(2, dim=1)
-        # Compute short convolution
-        if conv_state is not None:
-            # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
-            # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
-            conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))  # Update state (B D W)
-        if causal_conv1d_fn is None:
-            x = self.act(self.conv1d(x)[..., :seqlen])
-        else:
-            assert self.activation in ["silu", "swish"]
-            x = causal_conv1d_fn(
-                x=x,
-                weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
-                bias=self.conv1d.bias,
-                activation=self.activation,
-            )
+        conv_state = F.pad(x, (self.d_conv - seqlen, 0))
+        x = causal_conv1d_fn(
+            x=x,
+            weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)),
+            bias=self.conv1d.bias,
+            activation=self.activation,
+        )
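Note on the new prefill path: a pad width of self.d_conv - seqlen makes F.pad left-pad with zeros when the prompt is shorter than the conv window and truncate from the left when it is longer, so conv_state always holds exactly the last d_conv inputs. A minimal standalone sketch (hypothetical sizes, not TGI code):

    import torch
    import torch.nn.functional as F

    batch, d_inner, d_conv = 2, 8, 4  # hypothetical sizes
    for seqlen in (2, 4, 9):  # shorter than, equal to, longer than d_conv
        x = torch.randn(batch, d_inner, seqlen)
        # Negative pad widths truncate from the left, positive ones zero-pad,
        # so the state is always the last d_conv timesteps of x.
        conv_state = F.pad(x, (d_conv - seqlen, 0))
        assert conv_state.shape == (batch, d_inner, d_conv)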
@@ -114,8 +98,7 @@ class MambaBlock(nn.Module):
         dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
         B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
         C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
-        assert self.activation in ["silu", "swish"]
-        y, last_ssm_state = selective_scan_fn(
+        y, last_state = selective_scan_fn(
             x,
             dt,
             self.negA,
@@ -125,57 +108,32 @@ class MambaBlock(nn.Module):
             z=z,
             delta_bias=self.dt_proj.bias.float(),
             delta_softplus=True,
-            return_last_state=True,  # ssm_state is not None,
+            return_last_state=True,
         )
         y = rearrange(y, "b d l -> b l d")
         attn_outputs = self.out_proj(y)
-
-        return attn_outputs, conv_state, last_ssm_state
-
-    def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
-        assert self.layer_idx is not None
-        conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
-        return conv_state, ssm_state
+        return attn_outputs, conv_state, last_state
 
     def step(self, hidden_states, conv_state, ssm_state):
-        dtype = hidden_states.dtype
-        assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
+        # only support decoding with 1 token at a time
         xz = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
         x, z = xz.chunk(2, dim=-1)  # (B D)
-        # Conv step
-        if causal_conv1d_update is None:
-            conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)
-            conv_state[:, :, -1] = x
-            x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)  # (B D)
-            if self.conv1d.bias is not None:
-                x = x + self.conv1d.bias
-            x = self.act(x).to(dtype=dtype)
-        else:
-            x = causal_conv1d_update(
-                x,
-                conv_state,
-                rearrange(self.conv1d.weight, "d 1 w -> d w"),
-                self.conv1d.bias,
-                self.activation,
-            )
+        x = causal_conv1d_update(
+            x,
+            conv_state,
+            self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)),
+            self.conv1d.bias,
+            self.activation,
+        )
         x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
         dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
         # Don't add dt_bias here
         dt = F.linear(dt, self.dt_proj.weight)  # (B d_inner)
-
-        # SSM step
-        # Discretize A and B
-        dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
-        dA = torch.exp(torch.einsum("bd,dn->bdn", dt, self.negA))
-        dB = torch.einsum("bd,bn->bdn", dt, B)
-        ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
-        y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
-        y = y + self.D.to(dtype) * x
-        y = y * self.act(z)  # (B D)
+        y = selective_state_update(
+            ssm_state, x, dt, self.negA, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
+        )
         out = self.out_proj(y)
-        # conv and ssm are updated in place but we return them to make the control flow more explicit
         return out.unsqueeze(1), conv_state, ssm_state
 
 class ResidualBlock(nn.Module):
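For reference, the eager-mode branches deleted above remain the clearest description of what the two fused kernels compute per decoded token. A pure-PyTorch sketch reconstructed from that removed code (shapes assumed: x, z, dt are (B, D); B, C are (B, N); conv_state is (B, D, W); ssm_state is (B, D, N); conv_weight is (D, W)); causal_conv1d_update and selective_state_update fuse the same math into single kernel launches:

    import torch
    import torch.nn.functional as F

    def step_reference(x, z, dt, B, C, conv_state, ssm_state,
                       conv_weight, conv_bias, negA, D, dt_bias):
        # Conv step: shift the window left one slot and append the new input.
        conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
        conv_state[:, :, -1] = x
        x = F.silu(torch.sum(conv_state * conv_weight, dim=-1) + conv_bias)  # (B, D)
        # SSM step: discretize A and B with dt, then advance the recurrence in place.
        dt = F.softplus(dt + dt_bias)
        dA = torch.exp(torch.einsum("bd,dn->bdn", dt, negA))
        dB = torch.einsum("bd,bn->bdn", dt, B)
        ssm_state.copy_(ssm_state * dA + x.unsqueeze(-1) * dB)
        y = torch.einsum("bdn,bn->bd", ssm_state, C) + D * x
        return y * F.silu(z)  # gate with z, matching the block above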
@@ -187,14 +145,14 @@ class ResidualBlock(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
         inference_params: Optional[Any] = None,
     ):
-        residual = hidden_states
-        shape = hidden_states.shape
-        hidden_states, _ = self.layer_norm(hidden_states.view(-1, shape[-1]))
-        hidden_states, _conv_state, last_ssm_state = self.mamba_block(hidden_states.view(*shape), inference_params)
-        hidden_states = residual + hidden_states
-        return hidden_states, _conv_state, last_ssm_state
+        residual = (hidden_states + residual) if residual is not None else hidden_states
+        shape = residual.shape
+        hidden_states, _ = self.layer_norm(residual.view(-1, shape[-1]))
+        hidden_states, conv_state, last_ssm_state = self.mamba_block(hidden_states.view(*shape), inference_params)
+        return hidden_states, residual, conv_state, last_ssm_state
 
 class MambaModel(nn.Module):
     def __init__(self, config, weights):
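The signature change above is the residual-threading half of this commit: each block now returns its residual stream instead of adding it internally, and the add is folded into the next block's norm input. A runnable schematic of the pattern with stand-in modules (nn.Linear in place of MambaBlock, nn.LayerNorm in place of FastRMSNorm; not TGI code):

    import torch
    from torch import nn

    d_model, n_layers = 16, 3
    norms = [nn.LayerNorm(d_model) for _ in range(n_layers)]
    mixers = [nn.Linear(d_model, d_model) for _ in range(n_layers)]  # stand-ins for MambaBlock

    hidden_states, residual = torch.randn(2, 5, d_model), None
    for norm, mixer in zip(norms, mixers):
        # Fold the previous block's residual into this block's norm input.
        residual = hidden_states + residual if residual is not None else hidden_states
        hidden_states = mixer(norm(residual))
    # The final add happens once in the caller, before the model's last norm.
    hidden_states = hidden_states + residual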
@@ -208,14 +166,17 @@ class MambaModel(nn.Module):
         self.lm_head = FastLinear.load(config, f"{prefix}.embedding", weights, bias=False)
         self.config = config
 
-    def forward(self, input_ids: torch.Tensor, inference_params=None):
+    def forward(self, input_ids: torch.Tensor, inference_params=None, residual=None) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]:
         hidden_states = self.embed_tokens(input_ids)
-        print("Input ids: ", input_ids)
         for block in self.blocks:
-            hidden_states, _conv_state, last_ssm_state = block(hidden_states, inference_params)
-            # inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = (_conv_state, last_ssm_state)
+            hidden_states, residual, conv_state, ssm_state = block(hidden_states, residual, inference_params)
+            inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = (conv_state, ssm_state)
 
-        shape = hidden_states.shape
-        final_hidden_states, _ = self.norm_f(hidden_states.view(-1, shape[-1]))
-        return self.lm_head(final_hidden_states.view(*shape)), input_ids, inference_params
+        hidden_states = hidden_states + residual if residual is not None else hidden_states
+        hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1)))
+        hidden_states = hidden_states.view(residual.shape)
+        logits = self.lm_head(hidden_states)
+
+        # update the offset for the next inference using these params
+        inference_params.seqlen_offset += input_ids.size(1)
+        return logits, input_ids, inference_params
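Taken together, forward() now implements the standard prefill/decode split: the first call scans the whole prompt and writes (conv_state, ssm_state) per layer into key_value_memory_dict; every later call sees seqlen_offset > 0 and takes the O(1) step() path. A hedged usage sketch (greedy decoding; assumes model is the MambaModel above and that key_value_memory_dict entries were pre-allocated for every layer, since forward() reads them before writing):

    logits, _, inference_params = model(input_ids, inference_params)  # prefill: full selective scan
    next_id = logits[:, -1].argmax(dim=-1, keepdim=True)
    for _ in range(max_new_tokens - 1):
        # seqlen_offset is now > 0, so every block runs the cached step() path
        logits, _, inference_params = model(next_id, inference_params)
        next_id = logits[:, -1].argmax(dim=-1, keepdim=True)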


@@ -406,7 +406,7 @@ class Mamba(Model):
         dtype = input_ids.dtype
 
         # Inference params
-        seqlen_og = max_seqlen
+        seqlen_og = 0
         inf_cache = {}
         lengths_per_sample = torch.ones(batch_size, dtype=torch.int32, device=input_ids.device) * max_seqlen
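The seqlen_og fix matters because this value seeds the seqlen_offset of InferenceParams: starting it at max_seqlen made even the very first forward() take the single-token step() path against empty states. A sketch of the intended setup (an assumption on my part: that mamba_ssm's InferenceParams is a dataclass with max_seqlen/max_batch_size fields and seqlen_offset defaulting to 0, and that max_seqlen and batch_size come from the surrounding method):

    from mamba_ssm.utils.generation import InferenceParams

    # seqlen_offset must start at 0 so the first forward() takes the prefill
    # (full selective-scan) path; MambaModel.forward then advances it by
    # input_ids.size(1), routing subsequent calls through step().
    inference_params = InferenceParams(max_seqlen=max_seqlen, max_batch_size=batch_size)
    assert inference_params.seqlen_offset == 0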
@@ -592,7 +592,7 @@ class Mamba(Model):
             generations.append(generation)
 
             # Update values
-            batch.input_ids[i, 0] = 0  # next_token_id
+            batch.input_ids[i, 0] = next_token_id
             batch.all_input_ids[i] = all_input_ids
             batch.input_lengths[i] = new_input_length
             batch.prefix_offsets[i] = prefix_offset