From 2d674624a31672e8cfe5f11e57c503233b5e0f6e Mon Sep 17 00:00:00 2001
From: drbh
Date: Thu, 1 Feb 2024 05:00:51 +0000
Subject: [PATCH] fix: start to add caching of previous states

---
 .../models/custom_modeling/mamba_modeling.py  | 185 +++++++--
 server/text_generation_server/models/mamba.py | 361 +++++++++++++++++-
 2 files changed, 487 insertions(+), 59 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/mamba_modeling.py b/server/text_generation_server/models/custom_modeling/mamba_modeling.py
index ca5f9765..5a6a1d2d 100644
--- a/server/text_generation_server/models/custom_modeling/mamba_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/mamba_modeling.py
@@ -2,8 +2,9 @@ import torch
 import torch.distributed

 from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
+from mamba_ssm.utils.generation import InferenceParams
 from torch import nn
-from typing import Optional, List, Tuple, Any
+from typing import Optional, List, Tuple, Any, Dict

 from transformers.configuration_utils import PretrainedConfig
 import torch.nn.functional as F
@@ -11,19 +12,27 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     FastRMSNorm,
+    FastLinear,
 )

+from einops import rearrange, repeat
+from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
+import math
+

 class MambaConfig(PretrainedConfig):
     def __init__(
         self,
         vocab_size=50280,
         d_model=768,
+        d_state=16,
         n_layer=32,
         layer_norm_epsilon=1e-5,
         tie_word_embeddings=False,
         pad_token_id=0,
         bos_token_id=1,
         eos_token_id=2,
+        expand=2,
+        dt_rank="auto",
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -32,6 +41,9 @@ class MambaConfig(PretrainedConfig):
         self.d_model = d_model
         self.d_inner = d_model * 2
         self.d_conv = 4
+        self.d_state = d_state
+        self.expand = expand
+        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank

         super().__init__(
             pad_token_id=pad_token_id,
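
# Aside (not part of the patch): a quick sketch of how the new config fields above
# combine into the derived sizes the rest of the patch relies on. The numbers are
# simply the MambaConfig defaults shown above; anything else here is an assumption.
import math

d_model, d_state, d_conv, expand = 768, 16, 4, 2
dt_rank = math.ceil(d_model / 16)   # the "auto" rule added to MambaConfig -> 48
d_inner = d_model * expand          # width of the cached conv/ssm states -> 1536
print(dt_rank, d_inner, d_state, d_conv)
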
@@ -44,41 +56,127 @@ class MambaConfig(PretrainedConfig):
 class MambaBlock(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
-        self.in_proj = TensorParallelColumnLinear.load(
-            config=config, prefix=f"{prefix}.in_proj", weights=weights, bias=False
-        )
-        # helper for loading weights
-        self.load_weights(prefix, weights)
+        self.layer_idx = int(prefix.split(".")[2])
+        self.in_proj = FastLinear.load(config, f"{prefix}.in_proj", weights, bias=False)
+        self.x_proj = FastLinear.load(config, f"{prefix}.x_proj", weights, bias=False)
+        self.dt_proj = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=True)
+        self.out_proj = FastLinear.load(config, f"{prefix}.out_proj", weights, bias=False)
+        self.conv1d = FastLinear.load(config, f"{prefix}.conv1d", weights, bias=True)
+        self.negA = -torch.exp(weights.get_tensor(f"{prefix}.A_log").float())
+        self.D = weights.get_tensor(f"{prefix}.D")
+        self.activation = "silu"
+        self.dt_rank = config.dt_rank
+        self.d_state = config.d_state
+        self.d_conv = config.d_conv
+        self.act = nn.SiLU()

-    def load_weights(self, prefix, weights):
-        weight_names = ["x_proj.weight", "dt_proj.weight", "dt_proj.bias",
-                        "out_proj.weight", "in_proj.weight",
-                        "conv1d.weight", "conv1d.bias", "A_log", "D"]
-        for name in weight_names:
-            param_name = name.replace('.', '_')
-            setattr(self, param_name, nn.Parameter(weights.get_tensor(f"{prefix}.{name}")))
-        self.out_proj_bias = None
-        self.negA = -torch.exp(self.A_log.float())
+    # inference_params
+    def forward(self, hidden_states: torch.Tensor, inference_params=None):
+        seqlen = hidden_states.shape[1]
+
+        # TODO: use the inference_params to get the previous states when decoding
+        conv_state, ssm_state = None, None
+        if inference_params is not None:
+            if hidden_states.shape[1] == 1:
+                conv_state, ssm_state = self._get_states_from_cache(inference_params, hidden_states.shape[0])
+            if inference_params.seqlen_offset > 0:
+                # The states are updated inplace
+                out, _conv_state, _ssm_state = self.step(hidden_states, conv_state, ssm_state)
+                return out, _conv_state, _ssm_state

-    def forward(self, hidden_states: torch.Tensor):
         projected_states = self.in_proj(hidden_states).transpose(1,2)
-        # conv1d, ssm, and selective_scan are all fused into one kernel
-        attn_outputs = mamba_inner_fn(
-            projected_states,
-            self.conv1d_weight,
-            self.conv1d_bias,
-            self.x_proj_weight,
-            self.dt_proj_weight,
-            self.out_proj_weight,
-            self.out_proj_bias,
+
+        x, z = projected_states.chunk(2, dim=1)
+        # Compute short convolution
+        if conv_state is not None:
+            # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
+            # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
+            conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))  # Update state (B D W)
+        if causal_conv1d_fn is None:
+            x = self.act(self.conv1d(x)[..., :seqlen])
+        else:
+            assert self.activation in ["silu", "swish"]
+            x = causal_conv1d_fn(
+                x=x,
+                weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
+                bias=self.conv1d.bias,
+                activation=self.activation,
+            )
+
+        # We're careful here about the layout, to avoid extra transposes.
+        # We want dt to have d as the slowest moving dimension
+        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
+        x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
+        dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
+        dt = self.dt_proj.weight @ dt.t()
+        dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
+        B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
+        C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
+        assert self.activation in ["silu", "swish"]
+        y, last_ssm_state = selective_scan_fn(
+            x,
+            dt,
             self.negA,
-            None,
-            None,
+            B,
+            C,
             self.D.float(),
-            delta_bias=self.dt_proj_bias.float(),
+            z=z,
+            delta_bias=self.dt_proj.bias.float(),
             delta_softplus=True,
+            return_last_state=True,  # ssm_state is not None,
         )
-        return attn_outputs
+        y = rearrange(y, "b d l -> b l d")
+        attn_outputs = self.out_proj(y)
+
+        return attn_outputs, conv_state, last_ssm_state
+
+    def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
+        assert self.layer_idx is not None
+        conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
+        return conv_state, ssm_state
+
+    def step(self, hidden_states, conv_state, ssm_state):
+        dtype = hidden_states.dtype
+        assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
+        xz = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
+        x, z = xz.chunk(2, dim=-1)  # (B D)
+
+        # Conv step
+        if causal_conv1d_update is None:
+            conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)
+            conv_state[:, :, -1] = x
+            x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)  # (B D)
+            if self.conv1d.bias is not None:
+                x = x + self.conv1d.bias
+            x = self.act(x).to(dtype=dtype)
+        else:
+            x = causal_conv1d_update(
+                x,
+                conv_state,
+                rearrange(self.conv1d.weight, "d 1 w -> d w"),
+                self.conv1d.bias,
+                self.activation,
+            )
+
+        x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
+        dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
+        # Don't add dt_bias here
+        dt = F.linear(dt, self.dt_proj.weight)  # (B d_inner)
+
+        # SSM step
+        # Discretize A and B
+        dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
+        dA = torch.exp(torch.einsum("bd,dn->bdn", dt, self.negA))
+        dB = torch.einsum("bd,bn->bdn", dt, B)
+        ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
+        y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
+        y = y + self.D.to(dtype) * x
+        y = y * self.act(z)  # (B D)
+
+        out = self.out_proj(y)
+        return out.unsqueeze(1), conv_state, ssm_state

 class ResidualBlock(nn.Module):
     def __init__(self, layer_id, config, weights):
@@ -89,30 +187,35 @@ class ResidualBlock(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-    ):
+        inference_params: Optional[Any] = None,
+    ):
         residual = hidden_states
-        hidden_states, _ = self.layer_norm(hidden_states.squeeze(0))
-        hidden_states = residual + self.mamba_block(hidden_states.unsqueeze(0))
-        return hidden_states
+        shape = hidden_states.shape
+        hidden_states, _ = self.layer_norm(hidden_states.view(-1, shape[-1]))
+        hidden_states, _conv_state, last_ssm_state = self.mamba_block(hidden_states.view(*shape), inference_params)
+        hidden_states = residual + hidden_states
+        return hidden_states, _conv_state, last_ssm_state

 class MambaModel(nn.Module):
     def __init__(self, config, weights):
         super().__init__()
-        self.tp_rank = weights.process_group.rank()
-        self.tp_world_size = weights.process_group.size()
         prefix = "backbone"
-
         self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights)
         self.blocks = nn.ModuleList(
             [ResidualBlock(f"{prefix}.layers.{i}", config, weights) for i in range(config.n_layer)]
         )
         self.norm_f = FastRMSNorm.load(f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon)
-        self.lm_head = TensorParallelColumnLinear.load(config, f"{prefix}.embedding", weights, False)
+        self.lm_head = FastLinear.load(config, f"{prefix}.embedding", weights, bias=False)
+        self.config = config

-    def forward(self, input_ids: torch.Tensor):
+    def forward(self, input_ids: torch.Tensor, inference_params=None):
         hidden_states = self.embed_tokens(input_ids)
         for block in self.blocks:
-            hidden_states = block(hidden_states)
+            hidden_states, _conv_state, last_ssm_state = block(hidden_states, inference_params)
+            # inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = (_conv_state, last_ssm_state)

-        final_hidden_states, _ = self.norm_f(hidden_states.squeeze(0))
-        return self.lm_head(final_hidden_states.unsqueeze(0)), input_ids
+        shape = hidden_states.shape
+        final_hidden_states, _ = self.norm_f(hidden_states.view(-1, shape[-1]))
+        return self.lm_head(final_hidden_states.view(*shape)), input_ids, inference_params
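
# Aside (not part of the patch): a minimal, self-contained sketch of the single-token
# update that MambaBlock.step performs on the cached states, using plain torch ops and
# made-up sizes. `A` stands in for self.negA (i.e. -exp(A_log)); dt, B and C are random
# here instead of coming from x_proj/dt_proj.
import torch
import torch.nn.functional as F

batch, d_inner, d_conv, d_state = 2, 8, 4, 16
conv_state = torch.zeros(batch, d_inner, d_conv)   # rolling conv window cached per layer
ssm_state = torch.zeros(batch, d_inner, d_state)   # recurrent SSM state cached per layer

x_in = torch.randn(batch, d_inner)                 # "x" half of in_proj for the new token
conv_w = torch.randn(d_inner, d_conv)              # conv1d weight viewed as (d, w)
dt = F.softplus(torch.randn(batch, d_inner))       # discretization step after dt_proj + bias
A = -torch.rand(d_inner, d_state)                  # negative, like -exp(A_log)
B = torch.randn(batch, d_state)
C = torch.randn(batch, d_state)

# Conv cache: shift the window left, append the new token, apply the depthwise conv.
conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
conv_state[:, :, -1] = x_in
x = F.silu((conv_state * conv_w).sum(dim=-1))

# SSM cache: discretize A and B, advance the recurrence one step, read out y.
dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
dB = torch.einsum("bd,bn->bdn", dt, B)
ssm_state = ssm_state * dA + x.unsqueeze(-1) * dB
y = torch.einsum("bdn,bn->bd", ssm_state, C)
print(y.shape)  # torch.Size([2, 8])
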
diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py
index 05a5b99e..efa233e6 100644
--- a/server/text_generation_server/models/mamba.py
+++ b/server/text_generation_server/models/mamba.py
@@ -1,11 +1,7 @@
 import torch
 import torch.distributed
-
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
 from typing import Optional
-
-from text_generation_server.models import CausalLM
-from text_generation_server.models.causal_lm import CausalLMBatch
 from text_generation_server.models.custom_modeling.mamba_modeling import (
     MambaConfig,
 )
@@ -15,11 +11,10 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
-
 import time
 from text_generation_server.models.custom_modeling.mamba_modeling import MambaModel
 from text_generation_server.models import Model
-from typing import Any, List, Optional, Tuple, Type
+from typing import Any, List, Optional, Tuple, Type, Dict
 from text_generation_server.models.types import (
     Batch,
     Tokens,
@@ -27,15 +22,55 @@ from text_generation_server.models.types import (
     GeneratedText,
 )
 from text_generation_server.utils.tokens import batch_top_tokens, Sampling
+from dataclasses import dataclass
+from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
+from mamba_ssm.utils.generation import InferenceParams
+
+@dataclass
+class MambaBatch(Batch):
+    batch_id: int
+    requests: List[generate_pb2.Request]
+    requests_idx_mapping: Dict[int, int]

-class MambaCausalLMBatch(CausalLMBatch):
+    # Decoder values
+    input_ids: torch.Tensor
     past_input_ids: Optional[torch.Tensor]

-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.past_input_ids = None
+    # All tokens
+    all_input_ids: List[torch.Tensor]
+
+    # Lengths of all generations present in the batch
+    input_lengths: List[int]
+    prefix_offsets: List[int]
+    read_offsets: List[int]
+
+    # Generation helpers
+    next_token_choosers: List[NextTokenChooser]
+    stopping_criterias: List[StoppingCriteria]
+    top_n_tokens: List[int]
+    top_n_tokens_tensor: torch.Tensor
+
+    # Metadata used for padding
+    max_input_length: int
+    padding_right_offset: int
+
+    # Maximum number of tokens this batch will grow to
+    max_tokens: int
+
+    # Past metadata
+    keys_head_dim_last: bool = True
+
+    # Inference params
+    inference_params: Optional[Dict[str, Any]] = None
+
+    def to_pb(self) -> generate_pb2.CachedBatch:
+        return generate_pb2.CachedBatch(
+            id=self.batch_id,
+            request_ids=[r.id for r in self.requests],
+            size=len(self),
+            max_tokens=self.max_tokens,
+        )
+
     @classmethod
     def from_pb(
         cls,
@@ -43,11 +78,256 @@ class MambaCausalLMBatch(CausalLMBatch):
         tokenizer: PreTrainedTokenizerBase,
         dtype: torch.dtype,
         device: torch.device,
-    ) -> "CausalLMBatch":
-        batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
-        batch.keys_head_dim_last = False
-        return batch
+    ) -> "MambaBatch":
+        inputs = []
+        next_token_choosers = []
+        stopping_criterias = []
+        top_n_tokens = []
+        prefix_offsets = []
+        read_offsets = []
+        requests_idx_mapping = {}
+
+        # Parse batch
+        max_truncation = 0
+        padding_right_offset = 0
+        max_decode_tokens = 0
+        for i, r in enumerate(pb.requests):
+            requests_idx_mapping[r.id] = i
+            inputs.append(r.inputs)
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            stopping_criteria = StoppingCriteria.from_pb(
+                r.stopping_parameters, tokenizer
+            )
+            stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(r.top_n_tokens)
+            max_truncation = max(max_truncation, r.truncate)
+            max_decode_tokens += stopping_criteria.max_new_tokens
+            padding_right_offset = max(
+                padding_right_offset, stopping_criteria.max_new_tokens
+            )
+
+        tokenized_inputs = tokenizer(
+            inputs,
+            return_tensors="pt",
+            padding=True,
+            return_token_type_ids=False,
+            truncation=True,
+            max_length=max_truncation,
+        ).to(device)
+        for _ in pb.requests:
+            input_len = tokenized_inputs["input_ids"].shape[1]
+            prefix_offsets.append(input_len - 5)
+            read_offsets.append(input_len)
+
+        input_lengths = tokenized_inputs["attention_mask"].sum(1)
+        max_input_length = input_lengths.max()
+        input_ids = tokenized_inputs["input_ids"]
+        all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
+        top_n_tokens_tensor = torch.tensor(
+            top_n_tokens, device=device, dtype=torch.int64
+        )
+
+        max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
+
+        return cls(
+            batch_id=pb.id,
+            requests=pb.requests,
+            requests_idx_mapping=requests_idx_mapping,
+            input_ids=input_ids,
+            past_input_ids=None,
+            all_input_ids=list(all_input_ids),
+            input_lengths=input_lengths.tolist(),
+            prefix_offsets=prefix_offsets,
+            read_offsets=read_offsets,
+            next_token_choosers=next_token_choosers,
+            stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
+            max_input_length=max_input_length.item(),
+            padding_right_offset=padding_right_offset,
+            max_tokens=max_tokens,
+        )
+
+    def filter(self, request_ids: List[int]) -> Optional["MambaBatch"]:
+        if len(request_ids) == 0:
+            raise ValueError("Batch must have at least one request")
+        if len(request_ids) == len(self):
+            return self
+
+        keep_indices = []
+
+        # New values after filtering
+        requests_idx_mapping = {}
+        requests = []
+        input_lengths = []
+        prefix_offsets = []
+        read_offsets = []
+        all_input_ids = []
+        max_input_length = 0
+
+        next_token_choosers = []
+        stopping_criterias = []
+        top_n_tokens = []
+
+        total_remaining_decode_tokens = 0
+        new_padding_right_offset = 0
+
+        for i, request_id in enumerate(request_ids):
+            idx = self.requests_idx_mapping[request_id]
+            requests_idx_mapping[request_id] = i
+            keep_indices.append(idx)
+
+            requests.append(self.requests[idx])
+            prefix_offsets.append(self.prefix_offsets[idx])
+            read_offsets.append(self.read_offsets[idx])
+            all_input_ids.append(self.all_input_ids[idx])
+
+            request_input_length = self.input_lengths[idx]
+            input_lengths.append(request_input_length)
+            max_input_length = max(max_input_length, request_input_length)
+
+            next_token_choosers.append(self.next_token_choosers[idx])
+            stopping_criteria = self.stopping_criterias[idx]
+            stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(self.top_n_tokens[idx])
+            remaining_decode_tokens = (
+                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+            )
+            total_remaining_decode_tokens += remaining_decode_tokens
+            new_padding_right_offset = max(
+                new_padding_right_offset, remaining_decode_tokens
+            )
+
+        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
+        input_ids = self.input_ids[keep_indices]
+
+        top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
+        max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
+
+        self.requests = requests
+        self.requests_idx_mapping = requests_idx_mapping
+        self.input_ids = input_ids
+        self.all_input_ids = all_input_ids
+        self.input_lengths = input_lengths
+        self.prefix_offsets = prefix_offsets
+        self.read_offsets = read_offsets
+        self.next_token_choosers = next_token_choosers
+        self.stopping_criterias = stopping_criterias
+        self.top_n_tokens = top_n_tokens
+        self.top_n_tokens_tensor = top_n_tokens_tensor
+        self.max_input_length = max_input_length
+        self.padding_right_offset = new_padding_right_offset
+        self.max_tokens = max_tokens
+
+        return self
+
+    @classmethod
+    def concatenate(cls, batches: List["MambaBatch"]) -> "MambaBatch":
+        # Used for padding
+        total_batch_size = 0
+        max_input_length = 0
+        padding_right_offset = 0
+        for batch in batches:
+            total_batch_size += len(batch)
+            max_input_length = max(max_input_length, batch.max_input_length)
+            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
+
+        # Batch attributes
+        requests = []
+        requests_idx_mapping = {}
+        input_lengths = []
+        prefix_offsets = []
+        read_offsets = []
+        all_input_ids = []
+        next_token_choosers = []
+        stopping_criterias = []
+        top_n_tokens = []
+        max_tokens = 0
+
+        # Batch tensors
+        input_ids = None
+        attention_mask = None
+        position_ids = None
+        past_key_values = []
+        top_n_tokens_tensor = None
+
+        # Used for slicing correctly inside the tensors
+        # Equivalent to a cumsum on batch sizes
+        start_index = 0
+        for i, batch in enumerate(batches):
+            requests.extend(batch.requests)
+            input_lengths.extend(batch.input_lengths)
+            prefix_offsets.extend(batch.prefix_offsets)
+            read_offsets.extend(batch.read_offsets)
+            all_input_ids.extend(batch.all_input_ids)
+            next_token_choosers.extend(batch.next_token_choosers)
+            stopping_criterias.extend(batch.stopping_criterias)
+            top_n_tokens.extend(batch.top_n_tokens)
+
+            if i == 0:
+                requests_idx_mapping = batch.requests_idx_mapping
+            else:
+                # We need to offset the mapping for each batch by the cumulative batch size
+                for k, v in batch.requests_idx_mapping.items():
+                    requests_idx_mapping[k] = v + start_index
+
+            # Slicing end index for this batch
+            end_index = start_index + len(batch)
+
+            # We only concatenate batches that did at least one step
+            if batch.past_key_values is None:
+                raise ValueError("only concatenate prefilled batches")
+
+            # Create empty tensor
+            # input_ids is always of shape [batch_size, 1]
+            # We do not need to pad it
+            if input_ids is None:
+                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
+            # Copy to correct indices
+            input_ids[start_index:end_index] = batch.input_ids
+
+            # Create padded tensor
+            if attention_mask is None:
+                attention_mask = batch.attention_mask.new_zeros(
+                    (total_batch_size, max_input_length + padding_right_offset),
+                )
+
+            if top_n_tokens_tensor is None:
+                top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
+                    total_batch_size,
+                )
+            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
+
+            # Add eventual padding tokens that were added while concatenating
+            max_tokens += batch.max_tokens + (
+                max_input_length - batch.max_input_length
+            ) * len(batch)
+
+            start_index = end_index
+
+        return cls(
+            batch_id=batches[0].batch_id,
+            requests=requests,
+            requests_idx_mapping=requests_idx_mapping,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            all_input_ids=all_input_ids,
+            input_lengths=input_lengths,
+            prefix_offsets=prefix_offsets,
+            read_offsets=read_offsets,
+            next_token_choosers=next_token_choosers,
+            stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
+            max_input_length=max_input_length,
+            padding_right_offset=padding_right_offset,
+            keys_head_dim_last=batches[0].keys_head_dim_last,
+            max_tokens=max_tokens,
+        )
+
+    def __len__(self):
+        return len(self.requests)

 class Mamba(Model):
     def __init__(
         self,
@@ -99,8 +379,8 @@ class Mamba(Model):
         )

     @property
-    def batch_type(self) -> Type[CausalLMBatch]:
-        return MambaCausalLMBatch
+    def batch_type(self) -> Type[MambaBatch]:
+        return MambaBatch

     def warmup(self, batch) -> Optional[int]:
         # TODO: implement warmup for Mamba if needed
@@ -119,10 +399,50 @@ class Mamba(Model):
     def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
         start = time.time_ns()
-        input_ids = batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids
+        input_ids = batch.input_ids  # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids

-        logits, past_input_ids = self.model(input_ids)[:2]
+        batch_size = input_ids.shape[0]
+        max_seqlen = input_ids.shape[1]
+        dtype = input_ids.dtype
+        # Inference params
+        seqlen_og = max_seqlen
+        inf_cache = {}
+        lengths_per_sample = torch.ones(batch_size, dtype=torch.int32, device=input_ids.device) * max_seqlen
+
+        if batch.inference_params is None:
+            inference_params = InferenceParams(
+                max_seqlen=max_seqlen,
+                max_batch_size=batch_size,
+                seqlen_offset=seqlen_og,
+                key_value_memory_dict=inf_cache,
+                lengths_per_sample=lengths_per_sample,
+            )
+
+            # Allocate inference cache
+            for res_block in self.model.blocks:
+                block = res_block.mamba_block
+                conv_state = torch.zeros(
+                    batch_size,
+                    self.model.config.d_model * self.model.config.expand,
+                    self.model.config.d_conv,
+                    device=block.conv1d.weight.device,
+                    dtype=block.conv1d.weight.dtype,
+                )
+                ssm_state = torch.zeros(
+                    batch_size,
+                    self.model.config.d_model * self.model.config.expand,
+                    self.model.config.d_state,
+                    device=block.dt_proj.weight.device,
+                    dtype=block.dt_proj.weight.dtype,
+                )
+                inference_params.key_value_memory_dict[block.layer_idx] = (conv_state, ssm_state)
+            batch.inference_params = inference_params
+
+        # Forward pass
+        logits, past_input_ids, new_inference_params = self.model(input_ids, batch.inference_params)
+
+        batch.inference_params = new_inference_params
         # Results
         generations: List[Generation] = []
         stopped = True
@@ -272,6 +592,7 @@ class Mamba(Model):
             generations.append(generation)

             # Update values
+            batch.input_ids[i, 0] = 0  # next_token_id
             batch.all_input_ids[i] = all_input_ids
             batch.input_lengths[i] = new_input_length
             batch.prefix_offsets[i] = prefix_offset
@@ -284,6 +605,10 @@ class Mamba(Model):
             decode_ns = time.time_ns() - start_decode
             return generations, None, (forward_ns, decode_ns)

+        # Slice unused values from prefill
+        batch.input_ids = batch.input_ids[:, :1]
+
+        batch.past_input_ids = past_input_ids

         forward_ns = start_decode - start
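
# Aside (not part of the patch): a runnable sketch of the cache layout that
# generate_token builds above: one (conv_state, ssm_state) pair per layer, keyed by
# layer_idx, allocated at prefill and read back by _get_states_from_cache while decoding.
# CacheParams is a stand-in for mamba_ssm's InferenceParams; all sizes are assumptions.
from dataclasses import dataclass, field
from typing import Dict, Tuple
import torch

@dataclass
class CacheParams:
    max_seqlen: int
    max_batch_size: int
    seqlen_offset: int = 0
    key_value_memory_dict: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = field(default_factory=dict)

n_layer, batch_size, d_inner, d_conv, d_state = 2, 3, 1536, 4, 16
params = CacheParams(max_seqlen=32, max_batch_size=batch_size)
for layer_idx in range(n_layer):  # mirrors the allocation loop in generate_token
    conv_state = torch.zeros(batch_size, d_inner, d_conv)
    ssm_state = torch.zeros(batch_size, d_inner, d_state)
    params.key_value_memory_dict[layer_idx] = (conv_state, ssm_state)

# A decode step for layer 0 starts from the cached pair, as _get_states_from_cache does.
conv_state, ssm_state = params.key_value_memory_dict[0]
print(conv_state.shape, ssm_state.shape)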