Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)
feat: use cache when decoding
parent 2d674624a3
commit 3a42765cab
@@ -1,15 +1,15 @@
 import torch
 import torch.distributed

 from mamba_ssm.ops.triton.selective_state_update import selective_state_update
 from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
 from mamba_ssm.utils.generation import InferenceParams
 from torch import nn
-from typing import Optional, List, Tuple, Any, Dict
+from typing import Optional, Tuple, Any
 from transformers.configuration_utils import PretrainedConfig
 import torch.nn.functional as F

 from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     FastRMSNorm,
     FastLinear,
@@ -72,38 +72,22 @@ class MambaBlock(nn.Module):

-    # inference_params
     def forward(self, hidden_states: torch.Tensor, inference_params=None):
-        seqlen = hidden_states.shape[1]
+        _, seqlen, _ = hidden_states.shape
+        conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]

-        # TODO: use the inference_params to get the previous states when decoding
-        conv_state, ssm_state = None, None
-        if inference_params is not None:
-            if hidden_states.shape[1] == 1:
-                print("Decoding")
-                conv_state, ssm_state = self._get_states_from_cache(inference_params, hidden_states.shape[0])
-                if inference_params.seqlen_offset > 0:
-                    # The states are updated inplace
-                    out, _conv_state, _ssm_state = self.step(hidden_states, conv_state, ssm_state)
-                    # import ipdb; ipdb.set_trace()
-                    return out, _conv_state, _ssm_state
+        if inference_params.seqlen_offset > 0:
+            out, conv_state, ssm_state = self.step(hidden_states, conv_state, ssm_state)
+            return out, conv_state, ssm_state

         projected_states = self.in_proj(hidden_states).transpose(1,2)

         x, z = projected_states.chunk(2, dim=1)
-        # Compute short convolution
-        if conv_state is not None:
-            # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
-            # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
-            conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
-        if causal_conv1d_fn is None:
-            x = self.act(self.conv1d(x)[..., :seqlen])
-        else:
-            assert self.activation in ["silu", "swish"]
-            x = causal_conv1d_fn(
-                x=x,
-                weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
-                bias=self.conv1d.bias,
-                activation=self.activation,
-            )
+        conv_state = F.pad(x, (self.d_conv - seqlen, 0))
+        x = causal_conv1d_fn(
+            x=x,
+            weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)),
+            bias=self.conv1d.bias,
+            activation=self.activation,
+        )

         # We're careful here about the layout, to avoid extra transposes.
         # We want dt to have d as the slowest moving dimension
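For orientation: the rewritten forward pulls the per-layer (conv_state, ssm_state) pair straight out of inference_params.key_value_memory_dict and, once seqlen_offset is non-zero, routes the single-token decode through step() instead of re-running the full scan over the prompt. A minimal sketch of that dispatch pattern, with a deliberately toy state (a running sum) and hypothetical names, just to show the cache plumbing:

import torch

class TinyCachedLayer(torch.nn.Module):
    # Illustrative only: the "state" here is a running sum, standing in for (conv_state, ssm_state).
    def __init__(self, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx

    def forward(self, x: torch.Tensor, cache: dict, seqlen_offset: int):
        state = cache.setdefault(self.layer_idx, torch.zeros(x.shape[0], x.shape[-1]))
        if seqlen_offset > 0:
            state = state + x[:, -1, :]      # decode: one new token, reuse the cached state
        else:
            state = state + x.sum(dim=1)     # prefill: build the state from the whole prompt
        cache[self.layer_idx] = state        # written back, like key_value_memory_dict
        return x, state

The important part is that the cache entry is keyed by layer_idx and mutated between calls, so the next step sees the state left behind by the previous one.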
@@ -114,8 +98,7 @@ class MambaBlock(nn.Module):
         dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
         B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
         C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
-        assert self.activation in ["silu", "swish"]
-        y, last_ssm_state = selective_scan_fn(
+        y, last_state = selective_scan_fn(
             x,
             dt,
             self.negA,
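return_last_state=True makes selective_scan_fn hand back the final SSM state together with the outputs, which is what lets the prefill pass seed the cache that step() consumes later. As a rough reference for what that pair means, a naive PyTorch loop (not the fused Triton kernel; shapes assumed from the surrounding code: x and dt are (B, D, L), A is (D, N) and already negative, B and C are (B, N, L), and dt is taken to have delta_bias/softplus already applied):

import torch

def selective_scan_reference(x, dt, A, B, C, D, z=None):
    # Naive O(L) loop over time; returns (y, last_state) like return_last_state=True.
    batch, dim, length = x.shape
    nstate = A.shape[1]
    state = x.new_zeros(batch, dim, nstate)
    ys = []
    for t in range(length):
        dA = torch.exp(dt[:, :, t, None] * A)                               # (B, D, N)
        dBx = dt[:, :, t, None] * B[:, None, :, t] * x[:, :, t, None]       # (B, D, N)
        state = state * dA + dBx
        y_t = (state * C[:, None, :, t]).sum(-1) + D * x[:, :, t]           # (B, D)
        ys.append(y_t)
    y = torch.stack(ys, dim=-1)                                             # (B, D, L)
    if z is not None:
        y = y * torch.nn.functional.silu(z)
    return y, state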
@@ -125,57 +108,32 @@ class MambaBlock(nn.Module):
             z=z,
             delta_bias=self.dt_proj.bias.float(),
             delta_softplus=True,
-            return_last_state=True, # ssm_state is not None,
+            return_last_state=True,
         )
         y = rearrange(y, "b d l -> b l d")
         attn_outputs = self.out_proj(y)
+        return attn_outputs, conv_state, last_state

-        return attn_outputs, conv_state, last_ssm_state
-
-    def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
-        assert self.layer_idx is not None
-        conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
-        return conv_state, ssm_state
-
     def step(self, hidden_states, conv_state, ssm_state):
         dtype = hidden_states.dtype
-        assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
-        # only support decoding with 1 token at a time
         xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
         x, z = xz.chunk(2, dim=-1) # (B D)

-        # Conv step
-        if causal_conv1d_update is None:
-            conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
-            conv_state[:, :, -1] = x
-            x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
-            if self.conv1d.bias is not None:
-                x = x + self.conv1d.bias
-            x = self.act(x).to(dtype=dtype)
-        else:
-            x = causal_conv1d_update(
-                x,
-                conv_state,
-                rearrange(self.conv1d.weight, "d 1 w -> d w"),
-                self.conv1d.bias,
-                self.activation,
-            )
-
+        x = causal_conv1d_update(
+            x,
+            conv_state,
+            self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)),
+            self.conv1d.bias,
+            self.activation,
+        )
         x_db = self.x_proj(x) # (B dt_rank+2*d_state)
         dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
         # Don't add dt_bias here
         dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)

-        # SSM step
-        # Discretize A and B
-        dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
-        dA = torch.exp(torch.einsum("bd,dn->bdn", dt, self.negA))
-        dB = torch.einsum("bd,bn->bdn", dt, B)
-        ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
-        y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
-        y = y + self.D.to(dtype) * x
-        y = y * self.act(z) # (B D)
-
+        y = selective_state_update(
+            ssm_state, x, dt, self.negA, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
+        )
         out = self.out_proj(y)
         # conv and ssm are updated in place but we return them to make the control flow more explicit
         return out.unsqueeze(1), conv_state, ssm_state
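The manual block removed above spelled out, one tensor op at a time, the recurrence that selective_state_update now runs as a fused kernel. A readable reference of that single-token update, under the same shape assumptions (x, dt, z as (B, D); B, C as (B, N); ssm_state as (B, D, N)), can be handy when sanity-checking the cache path:

import torch
import torch.nn.functional as F

def ssm_step_reference(ssm_state, x, dt, A, B, C, D, z=None, dt_bias=None):
    # One decoding step: discretize, update ssm_state in place, return the output (B, D).
    if dt_bias is not None:
        dt = dt + dt_bias
    dt = F.softplus(dt)                                        # matches dt_softplus=True
    dA = torch.exp(dt.unsqueeze(-1) * A)                       # (B, D, N); A is the (negative) state matrix
    dBx = dt.unsqueeze(-1) * B.unsqueeze(1) * x.unsqueeze(-1)  # (B, D, N)
    ssm_state.copy_(ssm_state * dA + dBx)                      # cache updated in place
    y = (ssm_state * C.unsqueeze(1)).sum(-1) + D * x           # (B, D)
    if z is not None:
        y = y * F.silu(z)
    return y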
@@ -187,14 +145,14 @@ class ResidualBlock(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
         inference_params: Optional[Any] = None,
     ):
-        residual = hidden_states
-        shape = hidden_states.shape
-        hidden_states, _ = self.layer_norm(hidden_states.view(-1, shape[-1]))
-        hidden_states, _conv_state, last_ssm_state = self.mamba_block(hidden_states.view(*shape), inference_params)
-        hidden_states = residual + hidden_states
-        return hidden_states, _conv_state, last_ssm_state
+        residual = (hidden_states + residual) if residual is not None else hidden_states
+        shape = residual.shape
+        hidden_states, _ = self.layer_norm(residual.view(-1, shape[-1]))
+        hidden_states, conv_state, last_ssm_state = self.mamba_block(hidden_states.view(*shape), inference_params)
+        return hidden_states, residual, conv_state, last_ssm_state

 class MambaModel(nn.Module):
     def __init__(self, config, weights):
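ResidualBlock now takes the running residual as an argument and returns it, so the skip connection is folded in at the top of the next block (and once more before the final norm) rather than inside each block. A compact sketch of that pre-norm threading pattern, with hypothetical block/norm callables:

import torch

def run_blocks(blocks, norm, hidden_states):
    # Thread (hidden_states, residual) through a stack of pre-norm blocks.
    residual = None
    for block in blocks:
        # each block folds the previous residual in, normalizes, then transforms
        residual = hidden_states if residual is None else hidden_states + residual
        hidden_states = block(norm(residual))
    # final residual add before the last normalization
    hidden_states = hidden_states + residual if residual is not None else hidden_states
    return norm(hidden_states)

With torch.nn.Linear(16, 16) layers and torch.nn.LayerNorm(16) as stand-ins, run_blocks([...], norm, torch.randn(2, 4, 16)) exercises the same flow.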
@@ -208,14 +166,17 @@ class MambaModel(nn.Module):
         self.lm_head = FastLinear.load(config, f"{prefix}.embedding", weights, bias=False)
         self.config = config

-    def forward(self, input_ids: torch.Tensor, inference_params=None):
+    def forward(self, input_ids: torch.Tensor, inference_params=None, residual=None) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]:
         hidden_states = self.embed_tokens(input_ids)
-        print("Input ids: ", input_ids)
         for block in self.blocks:
-            hidden_states, _conv_state, last_ssm_state = block(hidden_states, inference_params)
-            # inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = (_conv_state, last_ssm_state)
+            hidden_states, residual, conv_state, ssm_state = block(hidden_states, residual, inference_params)
+            inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = (conv_state, ssm_state)

+        hidden_states = hidden_states + residual if residual is not None else hidden_states
+        hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1)))
+        hidden_states = hidden_states.view(residual.shape)
+        logits = self.lm_head(hidden_states)

-        shape = hidden_states.shape
-        final_hidden_states, _ = self.norm_f(hidden_states.view(-1, shape[-1]))
-        return self.lm_head(final_hidden_states.view(*shape)), input_ids, inference_params
+        # update the offset for the next inference using these params
+        inference_params.seqlen_offset += input_ids.size(1)
+        return logits, input_ids, inference_params
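Because the per-layer states are written back into key_value_memory_dict and seqlen_offset is advanced after every call, a caller can prefill once on the prompt and then feed a single token per step. A schematic driver loop (sample_fn is a hypothetical stand-in for whatever sampling the server does; the model signature follows the forward above):

import torch

def generate(model, inference_params, prompt_ids, max_new_tokens, sample_fn):
    # sample_fn: (batch, vocab) logits -> (batch, 1) token ids (hypothetical)
    logits, _, inference_params = model(prompt_ids, inference_params)    # prefill: offset -> prompt length
    next_ids = sample_fn(logits[:, -1])
    generated = [next_ids]
    for _ in range(max_new_tokens - 1):
        logits, _, inference_params = model(next_ids, inference_params)  # decode: one token, states from the cache
        next_ids = sample_fn(logits[:, -1])
        generated.append(next_ids)
    return torch.cat(generated, dim=1)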
@@ -406,7 +406,7 @@ class Mamba(Model)
         dtype = input_ids.dtype

         # Inference params
-        seqlen_og = max_seqlen
+        seqlen_og = 0
         inf_cache = {}
         lengths_per_sample = torch.ones(batch_size, dtype=torch.int32, device=input_ids.device) * max_seqlen

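Starting seqlen_og at 0 keeps the offset carried by the inference params at zero until tokens are actually processed, matching the seqlen_offset += input_ids.size(1) bookkeeping above. A sketch of how the cache container can be set up, assuming the usual mamba_ssm InferenceParams constructor (max_seqlen, max_batch_size) and the (batch, d_inner, d_conv) / (batch, d_inner, d_state) state layout used by MambaBlock:

import torch
from mamba_ssm.utils.generation import InferenceParams

def make_inference_params(num_layers, batch_size, max_seqlen, d_inner, d_conv, d_state, device, dtype):
    # assumption: InferenceParams(max_seqlen=..., max_batch_size=...) as in mamba_ssm
    params = InferenceParams(max_seqlen=max_seqlen, max_batch_size=batch_size)
    for layer_idx in range(num_layers):
        conv_state = torch.zeros(batch_size, d_inner, d_conv, device=device, dtype=dtype)
        ssm_state = torch.zeros(batch_size, d_inner, d_state, device=device, dtype=dtype)
        params.key_value_memory_dict[layer_idx] = (conv_state, ssm_state)
    return params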
@@ -592,7 +592,7 @@ class Mamba(Model)
             generations.append(generation)

             # Update values
-            batch.input_ids[i, 0] = 0 # next_token_id
+            batch.input_ids[i, 0] = next_token_id
             batch.all_input_ids[i] = all_input_ids
             batch.input_lengths[i] = new_input_length
             batch.prefix_offsets[i] = prefix_offset
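Writing the real next_token_id back into batch.input_ids (rather than the placeholder 0) is what makes the next generate_token call see exactly one genuine token per sequence; the cached states carry the rest of the history. A tiny sketch of that hand-off, with hypothetical names:

import torch

def next_model_input(all_input_ids: torch.Tensor, next_token_id: int, use_cache: bool = True):
    # all_input_ids: (1, seq_len) history for one sequence
    all_input_ids = torch.cat([all_input_ids, torch.tensor([[next_token_id]])], dim=1)
    # with a state cache only the freshly sampled token goes back in; otherwise the whole prefix would
    next_input = all_input_ids[:, -1:] if use_cache else all_input_ids
    return next_input, all_input_ids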