mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
fix: start to add caching of previous states
This commit is contained in:
parent
5b6f9259c1
commit
2d674624a3
@ -2,8 +2,9 @@ import torch
|
|||||||
import torch.distributed
|
import torch.distributed
|
||||||
|
|
||||||
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
|
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
|
||||||
|
from mamba_ssm.utils.generation import InferenceParams
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from typing import Optional, List, Tuple, Any
|
from typing import Optional, List, Tuple, Any, Dict
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
@ -11,19 +12,27 @@ from text_generation_server.utils.layers import (
|
|||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
TensorParallelEmbedding,
|
TensorParallelEmbedding,
|
||||||
FastRMSNorm,
|
FastRMSNorm,
|
||||||
|
FastLinear,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
||||||
|
import math
|
||||||
|
|
||||||
class MambaConfig(PretrainedConfig):
|
class MambaConfig(PretrainedConfig):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
vocab_size=50280,
|
vocab_size=50280,
|
||||||
d_model=768,
|
d_model=768,
|
||||||
|
d_state=16,
|
||||||
n_layer=32,
|
n_layer=32,
|
||||||
layer_norm_epsilon=1e-5,
|
layer_norm_epsilon=1e-5,
|
||||||
tie_word_embeddings=False,
|
tie_word_embeddings=False,
|
||||||
pad_token_id=0,
|
pad_token_id=0,
|
||||||
bos_token_id=1,
|
bos_token_id=1,
|
||||||
eos_token_id=2,
|
eos_token_id=2,
|
||||||
|
expand=2,
|
||||||
|
dt_rank="auto",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
@ -32,6 +41,9 @@ class MambaConfig(PretrainedConfig):
|
|||||||
self.d_model = d_model
|
self.d_model = d_model
|
||||||
self.d_inner = d_model * 2
|
self.d_inner = d_model * 2
|
||||||
self.d_conv = 4
|
self.d_conv = 4
|
||||||
|
self.d_state = d_state
|
||||||
|
self.expand = expand
|
||||||
|
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
pad_token_id=pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
@ -44,41 +56,127 @@ class MambaConfig(PretrainedConfig):
|
|||||||
class MambaBlock(nn.Module):
|
class MambaBlock(nn.Module):
|
||||||
def __init__(self, prefix, config, weights):
|
def __init__(self, prefix, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.in_proj = TensorParallelColumnLinear.load(
|
self.layer_idx = int(prefix.split(".")[2])
|
||||||
config=config, prefix=f"{prefix}.in_proj", weights=weights, bias=False
|
self.in_proj = FastLinear.load(config, f"{prefix}.in_proj", weights, bias=False)
|
||||||
)
|
self.x_proj = FastLinear.load(config, f"{prefix}.x_proj", weights, bias=False)
|
||||||
# helper for loading weights
|
self.dt_proj = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=True)
|
||||||
self.load_weights(prefix, weights)
|
self.out_proj = FastLinear.load(config, f"{prefix}.out_proj", weights, bias=False)
|
||||||
|
self.conv1d = FastLinear.load(config, f"{prefix}.conv1d", weights, bias=True)
|
||||||
|
self.negA = -torch.exp(weights.get_tensor(f"{prefix}.A_log").float())
|
||||||
|
self.D = weights.get_tensor(f"{prefix}.D")
|
||||||
|
self.activation = "silu"
|
||||||
|
self.dt_rank = config.dt_rank
|
||||||
|
self.d_state = config.d_state
|
||||||
|
self.d_conv = config.d_conv
|
||||||
|
self.act = nn.SiLU()
|
||||||
|
|
||||||
def load_weights(self, prefix, weights):
|
# inference_params
|
||||||
weight_names = ["x_proj.weight", "dt_proj.weight", "dt_proj.bias",
|
def forward(self, hidden_states: torch.Tensor, inference_params=None):
|
||||||
"out_proj.weight", "in_proj.weight",
|
seqlen = hidden_states.shape[1]
|
||||||
"conv1d.weight", "conv1d.bias", "A_log", "D"]
|
|
||||||
for name in weight_names:
|
# TODO: use the inference_params to get the previous states when decoding
|
||||||
param_name = name.replace('.', '_')
|
conv_state, ssm_state = None, None
|
||||||
setattr(self, param_name, nn.Parameter(weights.get_tensor(f"{prefix}.{name}")))
|
if inference_params is not None:
|
||||||
self.out_proj_bias = None
|
if hidden_states.shape[1] == 1:
|
||||||
self.negA = -torch.exp(self.A_log.float())
|
print("Decoding")
|
||||||
|
conv_state, ssm_state = self._get_states_from_cache(inference_params, hidden_states.shape[0])
|
||||||
|
if inference_params.seqlen_offset > 0:
|
||||||
|
# The states are updated inplace
|
||||||
|
out, _conv_state, _ssm_state = self.step(hidden_states, conv_state, ssm_state)
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
return out, _conv_state, _ssm_state
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor):
|
|
||||||
projected_states = self.in_proj(hidden_states).transpose(1,2)
|
projected_states = self.in_proj(hidden_states).transpose(1,2)
|
||||||
# conv1d, ssm, and selective_scan are all fused into one kernel
|
|
||||||
attn_outputs = mamba_inner_fn(
|
x, z = projected_states.chunk(2, dim=1)
|
||||||
projected_states,
|
# Compute short convolution
|
||||||
self.conv1d_weight,
|
if conv_state is not None:
|
||||||
self.conv1d_bias,
|
# If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
|
||||||
self.x_proj_weight,
|
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
|
||||||
self.dt_proj_weight,
|
conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
|
||||||
self.out_proj_weight,
|
if causal_conv1d_fn is None:
|
||||||
self.out_proj_bias,
|
x = self.act(self.conv1d(x)[..., :seqlen])
|
||||||
|
else:
|
||||||
|
assert self.activation in ["silu", "swish"]
|
||||||
|
x = causal_conv1d_fn(
|
||||||
|
x=x,
|
||||||
|
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
||||||
|
bias=self.conv1d.bias,
|
||||||
|
activation=self.activation,
|
||||||
|
)
|
||||||
|
|
||||||
|
# We're careful here about the layout, to avoid extra transposes.
|
||||||
|
# We want dt to have d as the slowest moving dimension
|
||||||
|
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
||||||
|
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
|
||||||
|
dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
||||||
|
dt = self.dt_proj.weight @ dt.t()
|
||||||
|
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
|
||||||
|
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
||||||
|
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
||||||
|
assert self.activation in ["silu", "swish"]
|
||||||
|
y, last_ssm_state = selective_scan_fn(
|
||||||
|
x,
|
||||||
|
dt,
|
||||||
self.negA,
|
self.negA,
|
||||||
None,
|
B,
|
||||||
None,
|
C,
|
||||||
self.D.float(),
|
self.D.float(),
|
||||||
delta_bias=self.dt_proj_bias.float(),
|
z=z,
|
||||||
|
delta_bias=self.dt_proj.bias.float(),
|
||||||
delta_softplus=True,
|
delta_softplus=True,
|
||||||
|
return_last_state=True, # ssm_state is not None,
|
||||||
)
|
)
|
||||||
return attn_outputs
|
y = rearrange(y, "b d l -> b l d")
|
||||||
|
attn_outputs = self.out_proj(y)
|
||||||
|
|
||||||
|
return attn_outputs, conv_state, last_ssm_state
|
||||||
|
|
||||||
|
def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
    """Return the cached ``(conv_state, ssm_state)`` pair for this layer.

    The pair is looked up in ``inference_params.key_value_memory_dict`` under
    this block's ``layer_idx``.

    Args:
        inference_params: Object exposing a ``key_value_memory_dict`` mapping
            from layer index to a ``(conv_state, ssm_state)`` pair.
        batch_size: Accepted for interface compatibility; not used here.
        initialize_states: Accepted for interface compatibility; not used here.

    Returns:
        The ``(conv_state, ssm_state)`` tuple stored for this layer.

    Raises:
        AssertionError: If ``self.layer_idx`` is ``None``.
        KeyError: If nothing is cached under this layer index.
    """
    layer_key = self.layer_idx
    assert layer_key is not None
    cached_pair = inference_params.key_value_memory_dict[layer_key]
    conv_state, ssm_state = cached_pair
    return conv_state, ssm_state
|
||||||
|
|
||||||
|
def step(self, hidden_states, conv_state, ssm_state):
    """Run the block for a single decoding token using the cached states.

    Args:
        hidden_states: Input whose second dimension must be 1 (one token).
        conv_state: Cached convolution window, annotated ``(B D W)`` below.
            It is rolled and overwritten in place on the fallback path; the
            fused ``causal_conv1d_update`` kernel presumably updates it in
            place as well -- TODO confirm against the kernel's documentation.
        ssm_state: Cached SSM state; overwritten in place via ``copy_``.

    Returns:
        ``(out, conv_state, ssm_state)`` where ``out`` is the projected output
        with the length-1 token dimension restored, and the two states are the
        same tensor objects that were passed in.

    Raises:
        AssertionError: If more than one token is supplied.
    """
    in_dtype = hidden_states.dtype
    assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"

    projected = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
    x, z = projected.chunk(2, dim=-1)  # (B D)

    # Conv step
    conv_weight = rearrange(self.conv1d.weight, "d 1 w -> d w")
    conv_bias = self.conv1d.bias
    if causal_conv1d_update is not None:
        # Fused kernel: convolution, bias and activation in one call.
        x = causal_conv1d_update(
            x,
            conv_state,
            conv_weight,
            conv_bias,
            self.activation,
        )
    else:
        # Fallback: shift the window left by one and append the new input.
        shifted_window = torch.roll(conv_state, shifts=-1, dims=-1)
        conv_state.copy_(shifted_window)  # Update state (B D W)
        conv_state[:, :, -1] = x
        x = torch.sum(conv_state * conv_weight, dim=-1)  # (B D)
        if conv_bias is not None:
            x = x + conv_bias
        x = self.act(x).to(dtype=in_dtype)

    x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
    split_sizes = [self.dt_rank, self.d_state, self.d_state]
    dt, B, C = torch.split(x_db, split_sizes, dim=-1)
    # Don't add dt_bias here: it is added just before the softplus below.
    dt = F.linear(dt, self.dt_proj.weight)  # (B d_inner)

    # SSM step
    # Discretize A and B
    dt_bias = self.dt_proj.bias.to(dtype=dt.dtype)
    dt = F.softplus(dt + dt_bias)
    dA = torch.exp(torch.einsum("bd,dn->bdn", dt, self.negA))
    dB = torch.einsum("bd,bn->bdn", dt, B)
    next_state = ssm_state * dA + rearrange(x, "b d -> b d 1") * dB
    ssm_state.copy_(next_state)

    y = torch.einsum("bdn,bn->bd", ssm_state.to(in_dtype), C)
    y = y + self.D.to(in_dtype) * x  # skip connection scaled by D
    y = y * self.act(z)  # (B D)

    out = self.out_proj(y)
    return out.unsqueeze(1), conv_state, ssm_state
|
||||||
|
|
||||||
class ResidualBlock(nn.Module):
|
class ResidualBlock(nn.Module):
|
||||||
def __init__(self, layer_id, config, weights):
|
def __init__(self, layer_id, config, weights):
|
||||||
@ -89,30 +187,35 @@ class ResidualBlock(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
):
|
inference_params: Optional[Any] = None,
|
||||||
|
):
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states, _ = self.layer_norm(hidden_states.squeeze(0))
|
shape = hidden_states.shape
|
||||||
hidden_states = residual + self.mamba_block(hidden_states.unsqueeze(0))
|
hidden_states, _ = self.layer_norm(hidden_states.view(-1, shape[-1]))
|
||||||
return hidden_states
|
hidden_states, _conv_state, last_ssm_state = self.mamba_block(hidden_states.view(*shape), inference_params)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
return hidden_states, _conv_state, last_ssm_state
|
||||||
|
|
||||||
class MambaModel(nn.Module):
|
class MambaModel(nn.Module):
|
||||||
def __init__(self, config, weights):
|
def __init__(self, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.tp_rank = weights.process_group.rank()
|
|
||||||
self.tp_world_size = weights.process_group.size()
|
|
||||||
prefix = "backbone"
|
prefix = "backbone"
|
||||||
|
|
||||||
self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights)
|
self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights)
|
||||||
self.blocks = nn.ModuleList(
|
self.blocks = nn.ModuleList(
|
||||||
[ResidualBlock(f"{prefix}.layers.{i}", config, weights) for i in range(config.n_layer)]
|
[ResidualBlock(f"{prefix}.layers.{i}", config, weights) for i in range(config.n_layer)]
|
||||||
)
|
)
|
||||||
self.norm_f = FastRMSNorm.load(f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon)
|
self.norm_f = FastRMSNorm.load(f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon)
|
||||||
self.lm_head = TensorParallelColumnLinear.load(config, f"{prefix}.embedding", weights, False)
|
self.lm_head = FastLinear.load(config, f"{prefix}.embedding", weights, bias=False)
|
||||||
|
self.config = config
|
||||||
|
|
||||||
def forward(self, input_ids: torch.Tensor):
|
def forward(self, input_ids: torch.Tensor, inference_params=None):
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
print("Input ids: ", input_ids)
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
hidden_states = block(hidden_states)
|
hidden_states, _conv_state, last_ssm_state = block(hidden_states, inference_params)
|
||||||
|
# inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = (_conv_state, last_ssm_state)
|
||||||
|
|
||||||
final_hidden_states, _ = self.norm_f(hidden_states.squeeze(0))
|
|
||||||
return self.lm_head(final_hidden_states.unsqueeze(0)), input_ids
|
shape = hidden_states.shape
|
||||||
|
final_hidden_states, _ = self.norm_f(hidden_states.view(-1, shape[-1]))
|
||||||
|
return self.lm_head(final_hidden_states.view(*shape)), input_ids, inference_params
|
||||||
|
@ -1,11 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
|
|
||||||
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from text_generation_server.models import CausalLM
|
|
||||||
from text_generation_server.models.causal_lm import CausalLMBatch
|
|
||||||
from text_generation_server.models.custom_modeling.mamba_modeling import (
|
from text_generation_server.models.custom_modeling.mamba_modeling import (
|
||||||
MambaConfig,
|
MambaConfig,
|
||||||
)
|
)
|
||||||
@ -15,11 +11,10 @@ from text_generation_server.utils import (
|
|||||||
weight_files,
|
weight_files,
|
||||||
Weights,
|
Weights,
|
||||||
)
|
)
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from text_generation_server.models.custom_modeling.mamba_modeling import MambaModel
|
from text_generation_server.models.custom_modeling.mamba_modeling import MambaModel
|
||||||
from text_generation_server.models import Model
|
from text_generation_server.models import Model
|
||||||
from typing import Any, List, Optional, Tuple, Type
|
from typing import Any, List, Optional, Tuple, Type, Dict
|
||||||
from text_generation_server.models.types import (
|
from text_generation_server.models.types import (
|
||||||
Batch,
|
Batch,
|
||||||
Tokens,
|
Tokens,
|
||||||
@ -27,15 +22,55 @@ from text_generation_server.models.types import (
|
|||||||
GeneratedText,
|
GeneratedText,
|
||||||
)
|
)
|
||||||
from text_generation_server.utils.tokens import batch_top_tokens, Sampling
|
from text_generation_server.utils.tokens import batch_top_tokens, Sampling
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
|
||||||
|
from mamba_ssm.utils.generation import InferenceParams
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MambaBatch(Batch):
|
||||||
|
batch_id: int
|
||||||
|
requests: List[generate_pb2.Request]
|
||||||
|
requests_idx_mapping: Dict[int, int]
|
||||||
|
|
||||||
class MambaCausalLMBatch(CausalLMBatch):
|
# Decoder values
|
||||||
|
input_ids: torch.Tensor
|
||||||
past_input_ids: Optional[torch.Tensor]
|
past_input_ids: Optional[torch.Tensor]
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
# All tokens
|
||||||
super().__init__(*args, **kwargs)
|
all_input_ids: List[torch.Tensor]
|
||||||
self.past_input_ids = None
|
|
||||||
|
|
||||||
|
# Lengths of all generations present in the batch
|
||||||
|
input_lengths: List[int]
|
||||||
|
prefix_offsets: List[int]
|
||||||
|
read_offsets: List[int]
|
||||||
|
|
||||||
|
# Generation helpers
|
||||||
|
next_token_choosers: List[NextTokenChooser]
|
||||||
|
stopping_criterias: List[StoppingCriteria]
|
||||||
|
top_n_tokens: List[int]
|
||||||
|
top_n_tokens_tensor: torch.Tensor
|
||||||
|
|
||||||
|
# Metadata used for padding
|
||||||
|
max_input_length: int
|
||||||
|
padding_right_offset: int
|
||||||
|
|
||||||
|
# Maximum number of tokens this batch will grow to
|
||||||
|
max_tokens: int
|
||||||
|
|
||||||
|
# Past metadata
|
||||||
|
keys_head_dim_last: bool = True
|
||||||
|
|
||||||
|
# Inference params
|
||||||
|
inference_params: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
def to_pb(self) -> generate_pb2.CachedBatch:
    """Build the ``generate_pb2.CachedBatch`` message describing this batch.

    Returns:
        A ``CachedBatch`` carrying the batch id, the id of every request in
        the batch, the number of requests and ``max_tokens``.
    """
    request_ids = [request.id for request in self.requests]
    cached_batch = generate_pb2.CachedBatch(
        id=self.batch_id,
        request_ids=request_ids,
        size=len(self),
        max_tokens=self.max_tokens,
    )
    return cached_batch
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pb(
|
def from_pb(
|
||||||
cls,
|
cls,
|
||||||
@ -43,11 +78,256 @@ class MambaCausalLMBatch(CausalLMBatch):
|
|||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
) -> "CausalLMBatch":
|
) -> "MambaBatch":
|
||||||
batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
|
inputs = []
|
||||||
batch.keys_head_dim_last = False
|
next_token_choosers = []
|
||||||
return batch
|
stopping_criterias = []
|
||||||
|
top_n_tokens = []
|
||||||
|
prefix_offsets = []
|
||||||
|
read_offsets = []
|
||||||
|
requests_idx_mapping = {}
|
||||||
|
|
||||||
|
# Parse batch
|
||||||
|
max_truncation = 0
|
||||||
|
padding_right_offset = 0
|
||||||
|
max_decode_tokens = 0
|
||||||
|
for i, r in enumerate(pb.requests):
|
||||||
|
requests_idx_mapping[r.id] = i
|
||||||
|
inputs.append(r.inputs)
|
||||||
|
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
|
||||||
|
stopping_criteria = StoppingCriteria.from_pb(
|
||||||
|
r.stopping_parameters, tokenizer
|
||||||
|
)
|
||||||
|
stopping_criterias.append(stopping_criteria)
|
||||||
|
top_n_tokens.append(r.top_n_tokens)
|
||||||
|
max_truncation = max(max_truncation, r.truncate)
|
||||||
|
max_decode_tokens += stopping_criteria.max_new_tokens
|
||||||
|
padding_right_offset = max(
|
||||||
|
padding_right_offset, stopping_criteria.max_new_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
tokenized_inputs = tokenizer(
|
||||||
|
inputs,
|
||||||
|
return_tensors="pt",
|
||||||
|
padding=True,
|
||||||
|
return_token_type_ids=False,
|
||||||
|
truncation=True,
|
||||||
|
max_length=max_truncation,
|
||||||
|
).to(device)
|
||||||
|
for _ in pb.requests:
|
||||||
|
input_len = tokenized_inputs["input_ids"].shape[1]
|
||||||
|
prefix_offsets.append(input_len - 5)
|
||||||
|
read_offsets.append(input_len)
|
||||||
|
|
||||||
|
input_lengths = tokenized_inputs["attention_mask"].sum(1)
|
||||||
|
max_input_length = input_lengths.max()
|
||||||
|
input_ids = tokenized_inputs["input_ids"]
|
||||||
|
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
|
||||||
|
top_n_tokens_tensor = torch.tensor(
|
||||||
|
top_n_tokens, device=device, dtype=torch.int64
|
||||||
|
)
|
||||||
|
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
|
||||||
|
return cls(
|
||||||
|
batch_id=pb.id,
|
||||||
|
requests=pb.requests,
|
||||||
|
requests_idx_mapping=requests_idx_mapping,
|
||||||
|
input_ids=input_ids,
|
||||||
|
past_input_ids=None,
|
||||||
|
all_input_ids=list(all_input_ids),
|
||||||
|
input_lengths=input_lengths.tolist(),
|
||||||
|
prefix_offsets=prefix_offsets,
|
||||||
|
read_offsets=read_offsets,
|
||||||
|
next_token_choosers=next_token_choosers,
|
||||||
|
stopping_criterias=stopping_criterias,
|
||||||
|
top_n_tokens=top_n_tokens,
|
||||||
|
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||||
|
max_input_length=max_input_length.item(),
|
||||||
|
padding_right_offset=padding_right_offset,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
def filter(self, request_ids: List[int]) -> Optional["MambaBatch"]:
    """Shrink the batch in place so it only holds ``request_ids``, in that order.

    Per-request lists, the ``input_ids`` / ``top_n_tokens_tensor`` tensors and
    the cached per-layer ``(conv_state, ssm_state)`` tensors are all re-indexed
    with the same row selection so that row ``i`` keeps describing request
    ``request_ids[i]``.

    Args:
        request_ids: Ids of the requests to keep. Each must be a key of
            ``self.requests_idx_mapping``.

    Returns:
        ``self``, after filtering (unchanged when every request is kept).

    Raises:
        ValueError: If ``request_ids`` is empty.
        KeyError: If an id is not part of this batch.
    """
    if len(request_ids) == 0:
        raise ValueError("Batch must have at least one request")
    if len(request_ids) == len(self):
        return self

    keep_indices = []

    # New values after filtering
    requests_idx_mapping = {}
    requests = []
    input_lengths = []
    prefix_offsets = []
    read_offsets = []
    all_input_ids = []
    max_input_length = 0

    next_token_choosers = []
    stopping_criterias = []
    top_n_tokens = []

    total_remaining_decode_tokens = 0
    new_padding_right_offset = 0

    for i, request_id in enumerate(request_ids):
        idx = self.requests_idx_mapping[request_id]
        requests_idx_mapping[request_id] = i
        keep_indices.append(idx)

        requests.append(self.requests[idx])
        prefix_offsets.append(self.prefix_offsets[idx])
        read_offsets.append(self.read_offsets[idx])
        all_input_ids.append(self.all_input_ids[idx])

        request_input_length = self.input_lengths[idx]
        input_lengths.append(request_input_length)
        max_input_length = max(max_input_length, request_input_length)

        next_token_choosers.append(self.next_token_choosers[idx])
        stopping_criteria = self.stopping_criterias[idx]
        stopping_criterias.append(stopping_criteria)
        top_n_tokens.append(self.top_n_tokens[idx])
        remaining_decode_tokens = (
            stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
        )
        total_remaining_decode_tokens += remaining_decode_tokens
        new_padding_right_offset = max(
            new_padding_right_offset, remaining_decode_tokens
        )

    # Apply indices to the batched tensors. This batch type declares no
    # position_ids / attention_mask / past_key_values, so there is nothing
    # else to slice here.
    input_ids = self.input_ids[keep_indices]
    top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
    max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens

    # The cached recurrent states are allocated with the batch as their first
    # dimension; keep only the surviving rows so they stay aligned with
    # ``input_ids`` on the next decoding step.
    inference_params = getattr(self, "inference_params", None)
    if inference_params is not None:
        state_cache = inference_params.key_value_memory_dict
        for layer_idx, (conv_state, ssm_state) in list(state_cache.items()):
            state_cache[layer_idx] = (
                conv_state[keep_indices],
                ssm_state[keep_indices],
            )

    self.requests = requests
    self.requests_idx_mapping = requests_idx_mapping
    self.input_ids = input_ids
    self.all_input_ids = all_input_ids
    self.input_lengths = input_lengths
    self.prefix_offsets = prefix_offsets
    self.read_offsets = read_offsets
    self.next_token_choosers = next_token_choosers
    self.stopping_criterias = stopping_criterias
    self.top_n_tokens = top_n_tokens
    self.top_n_tokens_tensor = top_n_tokens_tensor
    self.max_input_length = max_input_length
    self.padding_right_offset = new_padding_right_offset
    self.max_tokens = max_tokens

    return self
|
||||||
|
|
||||||
|
@classmethod
def concatenate(cls, batches: List["MambaBatch"]) -> "MambaBatch":
    """Merge several prefilled batches into one new batch.

    Per-request lists are appended in batch order, ``input_ids`` and
    ``top_n_tokens_tensor`` are concatenated along the batch dimension, and
    the cached per-layer ``(conv_state, ssm_state)`` tensors are concatenated
    the same way so the merged batch can keep decoding from its cache.

    Only fields declared on this batch type are used; it has no
    ``attention_mask``, ``position_ids`` or ``past_key_values``.

    Args:
        batches: Batches to merge. Each must already have run at least one
            forward pass, i.e. have ``inference_params`` set.

    Returns:
        A new batch holding every request; it reuses the id of ``batches[0]``.

    Raises:
        ValueError: If ``batches`` is empty or a batch has no cached state.
    """
    import copy  # local import: only needed to clone the cache container

    if not batches:
        raise ValueError("cannot concatenate an empty list of batches")

    # Used for padding
    max_input_length = 0
    padding_right_offset = 0
    for batch in batches:
        # We only concatenate batches that did at least one step
        if batch.inference_params is None:
            raise ValueError("only concatenate prefilled batches")
        max_input_length = max(max_input_length, batch.max_input_length)
        padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

    # Batch attributes
    requests = []
    requests_idx_mapping = {}
    input_lengths = []
    prefix_offsets = []
    read_offsets = []
    all_input_ids = []
    next_token_choosers = []
    stopping_criterias = []
    top_n_tokens = []
    max_tokens = 0

    # Used for slicing correctly inside the tensors
    # Equivalent to a cumsum on batch sizes
    start_index = 0
    for batch in batches:
        requests.extend(batch.requests)
        input_lengths.extend(batch.input_lengths)
        prefix_offsets.extend(batch.prefix_offsets)
        read_offsets.extend(batch.read_offsets)
        all_input_ids.extend(batch.all_input_ids)
        next_token_choosers.extend(batch.next_token_choosers)
        stopping_criterias.extend(batch.stopping_criterias)
        top_n_tokens.extend(batch.top_n_tokens)

        # Offset every mapping by the cumulative batch size. Entries are
        # written into a fresh dict so no input batch's mapping is mutated.
        for request_id, row in batch.requests_idx_mapping.items():
            requests_idx_mapping[request_id] = row + start_index

        # Add eventual padding tokens that were added while concatenating
        max_tokens += batch.max_tokens + (
            max_input_length - batch.max_input_length
        ) * len(batch)

        start_index += len(batch)

    # Batch tensors. input_ids is always of shape [batch_size, 1] once a
    # batch has been prefilled, so no padding is needed.
    input_ids = torch.cat([batch.input_ids for batch in batches], dim=0)
    top_n_tokens_tensor = torch.cat(
        [batch.top_n_tokens_tensor for batch in batches], dim=0
    )

    # Merge the cached recurrent states layer by layer along the batch axis.
    first_params = batches[0].inference_params
    merged_cache = {}
    for layer_idx in first_params.key_value_memory_dict:
        conv_states, ssm_states = zip(
            *(
                batch.inference_params.key_value_memory_dict[layer_idx]
                for batch in batches
            )
        )
        merged_cache[layer_idx] = (
            torch.cat(conv_states, dim=0),
            torch.cat(ssm_states, dim=0),
        )
    # Shallow-copy so the first batch's parameter object is left untouched.
    # NOTE(review): remaining metadata (e.g. seqlen_offset, lengths_per_sample)
    # is inherited from the first batch -- confirm this is acceptable.
    inference_params = copy.copy(first_params)
    inference_params.key_value_memory_dict = merged_cache
    inference_params.max_batch_size = start_index

    return cls(
        batch_id=batches[0].batch_id,
        requests=requests,
        requests_idx_mapping=requests_idx_mapping,
        input_ids=input_ids,
        past_input_ids=None,
        all_input_ids=all_input_ids,
        input_lengths=input_lengths,
        prefix_offsets=prefix_offsets,
        read_offsets=read_offsets,
        next_token_choosers=next_token_choosers,
        stopping_criterias=stopping_criterias,
        top_n_tokens=top_n_tokens,
        top_n_tokens_tensor=top_n_tokens_tensor,
        max_input_length=max_input_length,
        padding_right_offset=padding_right_offset,
        max_tokens=max_tokens,
        keys_head_dim_last=batches[0].keys_head_dim_last,
        inference_params=inference_params,
    )
|
||||||
|
|
||||||
|
def __len__(self):
    """Return how many requests the batch currently holds."""
    request_count = len(self.requests)
    return request_count
|
||||||
|
|
||||||
class Mamba(Model):
|
class Mamba(Model):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -99,8 +379,8 @@ class Mamba(Model):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def batch_type(self) -> Type[CausalLMBatch]:
|
def batch_type(self) -> Type[MambaBatch]:
|
||||||
return MambaCausalLMBatch
|
return MambaBatch
|
||||||
|
|
||||||
def warmup(self, batch) -> Optional[int]:
|
def warmup(self, batch) -> Optional[int]:
|
||||||
# TODO: implement warmup for Mamba if needed
|
# TODO: implement warmup for Mamba if needed
|
||||||
@ -119,10 +399,50 @@ class Mamba(Model):
|
|||||||
def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
|
def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
|
||||||
start = time.time_ns()
|
start = time.time_ns()
|
||||||
|
|
||||||
input_ids = batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids
|
input_ids = batch.input_ids # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids
|
||||||
|
|
||||||
logits, past_input_ids = self.model(input_ids)[:2]
|
batch_size = input_ids.shape[0]
|
||||||
|
max_seqlen = input_ids.shape[1]
|
||||||
|
dtype = input_ids.dtype
|
||||||
|
|
||||||
|
# Inference params
|
||||||
|
seqlen_og = max_seqlen
|
||||||
|
inf_cache = {}
|
||||||
|
lengths_per_sample = torch.ones(batch_size, dtype=torch.int32, device=input_ids.device) * max_seqlen
|
||||||
|
|
||||||
|
if batch.inference_params is None:
|
||||||
|
inference_params = InferenceParams(
|
||||||
|
max_seqlen=max_seqlen,
|
||||||
|
max_batch_size=batch_size,
|
||||||
|
seqlen_offset=seqlen_og,
|
||||||
|
key_value_memory_dict=inf_cache,
|
||||||
|
lengths_per_sample=lengths_per_sample,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Allocate inference cache
|
||||||
|
for res_block in self.model.blocks:
|
||||||
|
block = res_block.mamba_block
|
||||||
|
conv_state = torch.zeros(
|
||||||
|
batch_size,
|
||||||
|
self.model.config.d_model * self.model.config.expand,
|
||||||
|
self.model.config.d_conv,
|
||||||
|
device=block.conv1d.weight.device,
|
||||||
|
dtype=block.conv1d.weight.dtype,
|
||||||
|
)
|
||||||
|
ssm_state = torch.zeros(
|
||||||
|
batch_size,
|
||||||
|
self.model.config.d_model * self.model.config.expand,
|
||||||
|
self.model.config.d_state,
|
||||||
|
device=block.dt_proj.weight.device,
|
||||||
|
dtype=block.dt_proj.weight.dtype,
|
||||||
|
)
|
||||||
|
inference_params.key_value_memory_dict[block.layer_idx] = (conv_state, ssm_state)
|
||||||
|
batch.inference_params = inference_params
|
||||||
|
|
||||||
|
# Forward pass
|
||||||
|
logits, past_input_ids, new_inference_params = self.model(input_ids, batch.inference_params)
|
||||||
|
|
||||||
|
batch.inference_params = new_inference_params
|
||||||
# Results
|
# Results
|
||||||
generations: List[Generation] = []
|
generations: List[Generation] = []
|
||||||
stopped = True
|
stopped = True
|
||||||
@ -272,6 +592,7 @@ class Mamba(Model):
|
|||||||
generations.append(generation)
|
generations.append(generation)
|
||||||
|
|
||||||
# Update values
|
# Update values
|
||||||
|
batch.input_ids[i, 0] = 0 # next_token_id
|
||||||
batch.all_input_ids[i] = all_input_ids
|
batch.all_input_ids[i] = all_input_ids
|
||||||
batch.input_lengths[i] = new_input_length
|
batch.input_lengths[i] = new_input_length
|
||||||
batch.prefix_offsets[i] = prefix_offset
|
batch.prefix_offsets[i] = prefix_offset
|
||||||
@ -284,6 +605,10 @@ class Mamba(Model):
|
|||||||
decode_ns = time.time_ns() - start_decode
|
decode_ns = time.time_ns() - start_decode
|
||||||
return generations, None, (forward_ns, decode_ns)
|
return generations, None, (forward_ns, decode_ns)
|
||||||
|
|
||||||
|
# Slice unused values from prefill
|
||||||
|
batch.input_ids = batch.input_ids[:, :1]
|
||||||
|
|
||||||
|
|
||||||
batch.past_input_ids = past_input_ids
|
batch.past_input_ids = past_input_ids
|
||||||
|
|
||||||
forward_ns = start_decode - start
|
forward_ns = start_decode - start
|
||||||
|
Loading…
Reference in New Issue
Block a user