fix: start to add caching of previous states

drbh 2024-02-01 05:00:51 +00:00
parent 5b6f9259c1
commit 2d674624a3
2 changed files with 487 additions and 59 deletions
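For orientation, a minimal sketch (not part of the commit) of the per-layer cache layout this change starts to introduce. The names and tensor shapes mirror the allocation in Mamba.generate_token further down; the toy sizes are illustrative only.

import torch

# Toy sizes for illustration; the real values come from MambaConfig.
batch_size, d_model, expand, d_conv, d_state, n_layer = 2, 768, 2, 4, 16, 32
d_inner = d_model * expand

# key_value_memory_dict maps layer_idx -> (conv_state, ssm_state):
#   conv_state: rolling input window for the causal conv1d, shape (batch, d_inner, d_conv)
#   ssm_state:  selective-scan hidden state, shape (batch, d_inner, d_state)
cache = {
    layer_idx: (
        torch.zeros(batch_size, d_inner, d_conv),
        torch.zeros(batch_size, d_inner, d_state),
    )
    for layer_idx in range(n_layer)
}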

View File

@@ -2,8 +2,9 @@ import torch
import torch.distributed
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
from mamba_ssm.utils.generation import InferenceParams
from torch import nn
from typing import Optional, List, Tuple, Any
from typing import Optional, List, Tuple, Any, Dict
from transformers.configuration_utils import PretrainedConfig
import torch.nn.functional as F
@@ -11,19 +12,27 @@ from text_generation_server.utils.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
FastRMSNorm,
FastLinear,
)
from einops import rearrange, repeat
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
import math
class MambaConfig(PretrainedConfig):
def __init__(
self,
vocab_size=50280,
d_model=768,
d_state=16,
n_layer=32,
layer_norm_epsilon=1e-5,
tie_word_embeddings=False,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
expand=2,
dt_rank="auto",
**kwargs,
):
self.vocab_size = vocab_size
@@ -32,6 +41,9 @@ class MambaConfig(PretrainedConfig):
self.d_model = d_model
self.d_inner = d_model * 2
self.d_conv = 4
self.d_state = d_state
self.expand = expand
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
super().__init__(
pad_token_id=pad_token_id,
@@ -44,41 +56,127 @@ class MambaConfig(PretrainedConfig):
class MambaBlock(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.in_proj = TensorParallelColumnLinear.load(
config=config, prefix=f"{prefix}.in_proj", weights=weights, bias=False
)
# helper for loading weights
self.load_weights(prefix, weights)
self.layer_idx = int(prefix.split(".")[2])
self.in_proj = FastLinear.load(config, f"{prefix}.in_proj", weights, bias=False)
self.x_proj = FastLinear.load(config, f"{prefix}.x_proj", weights, bias=False)
self.dt_proj = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=True)
self.out_proj = FastLinear.load(config, f"{prefix}.out_proj", weights, bias=False)
self.conv1d = FastLinear.load(config, f"{prefix}.conv1d", weights, bias=True)
self.negA = -torch.exp(weights.get_tensor(f"{prefix}.A_log").float())
self.D = weights.get_tensor(f"{prefix}.D")
self.activation = "silu"
self.dt_rank = config.dt_rank
self.d_state = config.d_state
self.d_conv = config.d_conv
self.act = nn.SiLU()
def load_weights(self, prefix, weights):
weight_names = ["x_proj.weight", "dt_proj.weight", "dt_proj.bias",
"out_proj.weight", "in_proj.weight",
"conv1d.weight", "conv1d.bias", "A_log", "D"]
for name in weight_names:
param_name = name.replace('.', '_')
setattr(self, param_name, nn.Parameter(weights.get_tensor(f"{prefix}.{name}")))
self.out_proj_bias = None
self.negA = -torch.exp(self.A_log.float())
# inference_params
def forward(self, hidden_states: torch.Tensor, inference_params=None):
seqlen = hidden_states.shape[1]
# TODO: use the inference_params to get the previous states when decoding
conv_state, ssm_state = None, None
if inference_params is not None:
if hidden_states.shape[1] == 1:
print("Decoding")
conv_state, ssm_state = self._get_states_from_cache(inference_params, hidden_states.shape[0])
if inference_params.seqlen_offset > 0:
# The states are updated inplace
out, _conv_state, _ssm_state = self.step(hidden_states, conv_state, ssm_state)
# import ipdb; ipdb.set_trace()
return out, _conv_state, _ssm_state
def forward(self, hidden_states: torch.Tensor):
projected_states = self.in_proj(hidden_states).transpose(1,2)
# conv1d, ssm, and selective_scan are all fused into one kernel
attn_outputs = mamba_inner_fn(
projected_states,
self.conv1d_weight,
self.conv1d_bias,
self.x_proj_weight,
self.dt_proj_weight,
self.out_proj_weight,
self.out_proj_bias,
x, z = projected_states.chunk(2, dim=1)
# Compute short convolution
if conv_state is not None:
# If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
if causal_conv1d_fn is None:
x = self.act(self.conv1d(x)[..., :seqlen])
else:
assert self.activation in ["silu", "swish"]
x = causal_conv1d_fn(
x=x,
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
bias=self.conv1d.bias,
activation=self.activation,
)
# We're careful here about the layout, to avoid extra transposes.
# We want dt to have d as the slowest moving dimension
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
dt = self.dt_proj.weight @ dt.t()
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
assert self.activation in ["silu", "swish"]
y, last_ssm_state = selective_scan_fn(
x,
dt,
self.negA,
None,
None,
B,
C,
self.D.float(),
delta_bias=self.dt_proj_bias.float(),
z=z,
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
return_last_state=True, # ssm_state is not None,
)
return attn_outputs
y = rearrange(y, "b d l -> b l d")
attn_outputs = self.out_proj(y)
return attn_outputs, conv_state, last_ssm_state
def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
assert self.layer_idx is not None
conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
return conv_state, ssm_state
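# step() handles a single decode token: it updates the cached convolution window
# (torch.roll fallback or causal_conv1d_update), then advances the SSM state in
# place with the discretized dA/dB update and projects the gated output.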
def step(self, hidden_states, conv_state, ssm_state):
dtype = hidden_states.dtype
assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
x, z = xz.chunk(2, dim=-1) # (B D)
# Conv step
if causal_conv1d_update is None:
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
conv_state[:, :, -1] = x
x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
if self.conv1d.bias is not None:
x = x + self.conv1d.bias
x = self.act(x).to(dtype=dtype)
else:
x = causal_conv1d_update(
x,
conv_state,
rearrange(self.conv1d.weight, "d 1 w -> d w"),
self.conv1d.bias,
self.activation,
)
x_db = self.x_proj(x) # (B dt_rank+2*d_state)
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
# Don't add dt_bias here
dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
# SSM step
# Discretize A and B
dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
dA = torch.exp(torch.einsum("bd,dn->bdn", dt, self.negA))
dB = torch.einsum("bd,bn->bdn", dt, B)
ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
y = y + self.D.to(dtype) * x
y = y * self.act(z) # (B D)
out = self.out_proj(y)
return out.unsqueeze(1), conv_state, ssm_state
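# Illustrative sketch (not part of the diff): the einsums in step() rewritten with
# plain broadcasting and toy shapes, to make the discretized state update explicit.
import torch
import torch.nn.functional as F

bsz, d_inner, d_state = 2, 8, 4
x = torch.randn(bsz, d_inner)               # conv output for the current token
dt = F.softplus(torch.randn(bsz, d_inner))  # per-channel step size after softplus
A = -torch.rand(d_inner, d_state)           # plays the role of negA above
B = torch.randn(bsz, d_state)
C = torch.randn(bsz, d_state)
ssm_state = torch.zeros(bsz, d_inner, d_state)

dA = torch.exp(dt[..., None] * A)               # einsum("bd,dn->bdn", dt, negA)
dB = dt[..., None] * B[:, None, :]              # einsum("bd,bn->bdn", dt, B)
ssm_state = ssm_state * dA + x[..., None] * dB  # step() does this in place via copy_
y = (ssm_state * C[:, None, :]).sum(-1)         # einsum("bdn,bn->bd", ssm_state, C)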
class ResidualBlock(nn.Module):
def __init__(self, layer_id, config, weights):
@@ -89,30 +187,35 @@ class ResidualBlock(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
):
inference_params: Optional[Any] = None,
):
residual = hidden_states
hidden_states, _ = self.layer_norm(hidden_states.squeeze(0))
hidden_states = residual + self.mamba_block(hidden_states.unsqueeze(0))
return hidden_states
shape = hidden_states.shape
hidden_states, _ = self.layer_norm(hidden_states.view(-1, shape[-1]))
hidden_states, _conv_state, last_ssm_state = self.mamba_block(hidden_states.view(*shape), inference_params)
hidden_states = residual + hidden_states
return hidden_states, _conv_state, last_ssm_state
class MambaModel(nn.Module):
def __init__(self, config, weights):
super().__init__()
self.tp_rank = weights.process_group.rank()
self.tp_world_size = weights.process_group.size()
prefix = "backbone"
self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights)
self.blocks = nn.ModuleList(
[ResidualBlock(f"{prefix}.layers.{i}", config, weights) for i in range(config.n_layer)]
)
self.norm_f = FastRMSNorm.load(f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon)
self.lm_head = TensorParallelColumnLinear.load(config, f"{prefix}.embedding", weights, False)
self.lm_head = FastLinear.load(config, f"{prefix}.embedding", weights, bias=False)
self.config = config
def forward(self, input_ids: torch.Tensor):
def forward(self, input_ids: torch.Tensor, inference_params=None):
hidden_states = self.embed_tokens(input_ids)
print("Input ids: ", input_ids)
for block in self.blocks:
hidden_states = block(hidden_states)
hidden_states, _conv_state, last_ssm_state = block(hidden_states, inference_params)
# inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = (_conv_state, last_ssm_state)
final_hidden_states, _ = self.norm_f(hidden_states.squeeze(0))
return self.lm_head(final_hidden_states.unsqueeze(0)), input_ids
shape = hidden_states.shape
final_hidden_states, _ = self.norm_f(hidden_states.view(-1, shape[-1]))
return self.lm_head(final_hidden_states.view(*shape)), input_ids, inference_params

View File

@@ -1,11 +1,7 @@
import torch
import torch.distributed
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from typing import Optional
from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.models.custom_modeling.mamba_modeling import (
MambaConfig,
)
@@ -15,11 +11,10 @@ from text_generation_server.utils import (
weight_files,
Weights,
)
import time
from text_generation_server.models.custom_modeling.mamba_modeling import MambaModel
from text_generation_server.models import Model
from typing import Any, List, Optional, Tuple, Type
from typing import Any, List, Optional, Tuple, Type, Dict
from text_generation_server.models.types import (
Batch,
Tokens,
@@ -27,15 +22,55 @@ from text_generation_server.models.types import (
GeneratedText,
)
from text_generation_server.utils.tokens import batch_top_tokens, Sampling
from dataclasses import dataclass
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
from mamba_ssm.utils.generation import InferenceParams
@dataclass
class MambaBatch(Batch):
batch_id: int
requests: List[generate_pb2.Request]
requests_idx_mapping: Dict[int, int]
class MambaCausalLMBatch(CausalLMBatch):
# Decoder values
input_ids: torch.Tensor
past_input_ids: Optional[torch.Tensor]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.past_input_ids = None
# All tokens
all_input_ids: List[torch.Tensor]
# Lengths of all generations present in the batch
input_lengths: List[int]
prefix_offsets: List[int]
read_offsets: List[int]
# Generation helpers
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
top_n_tokens: List[int]
top_n_tokens_tensor: torch.Tensor
# Metadata used for padding
max_input_length: int
padding_right_offset: int
# Maximum number of tokens this batch will grow to
max_tokens: int
# Past metadata
keys_head_dim_last: bool = True
# Inference params
inference_params: Optional[Dict[str, Any]] = None
def to_pb(self) -> generate_pb2.CachedBatch:
return generate_pb2.CachedBatch(
id=self.batch_id,
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=self.max_tokens,
)
@classmethod
def from_pb(
cls,
@@ -43,11 +78,256 @@ class MambaCausalLMBatch(CausalLMBatch):
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "CausalLMBatch":
batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
batch.keys_head_dim_last = False
return batch
) -> "MambaBatch":
inputs = []
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
prefix_offsets = []
read_offsets = []
requests_idx_mapping = {}
# Parse batch
max_truncation = 0
padding_right_offset = 0
max_decode_tokens = 0
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(r.inputs)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
max_truncation = max(max_truncation, r.truncate)
max_decode_tokens += stopping_criteria.max_new_tokens
padding_right_offset = max(
padding_right_offset, stopping_criteria.max_new_tokens
)
tokenized_inputs = tokenizer(
inputs,
return_tensors="pt",
padding=True,
return_token_type_ids=False,
truncation=True,
max_length=max_truncation,
).to(device)
for _ in pb.requests:
input_len = tokenized_inputs["input_ids"].shape[1]
prefix_offsets.append(input_len - 5)
read_offsets.append(input_len)
input_lengths = tokenized_inputs["attention_mask"].sum(1)
max_input_length = input_lengths.max()
input_ids = tokenized_inputs["input_ids"]
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
return cls(
batch_id=pb.id,
requests=pb.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
past_input_ids=None,
all_input_ids=list(all_input_ids),
input_lengths=input_lengths.tolist(),
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset,
max_tokens=max_tokens,
)
def filter(self, request_ids: List[int]) -> Optional["MambaBatch"]:
if len(request_ids) == 0:
raise ValueError("Batch must have at least one request")
if len(request_ids) == len(self):
return self
keep_indices = []
# New values after filtering
requests_idx_mapping = {}
requests = []
input_lengths = []
prefix_offsets = []
read_offsets = []
all_input_ids = []
max_input_length = 0
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
total_remaining_decode_tokens = 0
new_padding_right_offset = 0
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
requests_idx_mapping[request_id] = i
keep_indices.append(idx)
requests.append(self.requests[idx])
prefix_offsets.append(self.prefix_offsets[idx])
read_offsets.append(self.read_offsets[idx])
all_input_ids.append(self.all_input_ids[idx])
request_input_length = self.input_lengths[idx]
input_lengths.append(request_input_length)
max_input_length = max(max_input_length, request_input_length)
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(self.top_n_tokens[idx])
remaining_decode_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
total_remaining_decode_tokens += remaining_decode_tokens
new_padding_right_offset = max(
new_padding_right_offset, remaining_decode_tokens
)
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
input_ids = self.input_ids[keep_indices]
position_ids = self.position_ids[keep_indices]
top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
self.requests = requests
self.requests_idx_mapping = requests_idx_mapping
self.input_ids = input_ids
self.all_input_ids = all_input_ids
self.input_lengths = input_lengths
self.prefix_offsets = prefix_offsets
self.read_offsets = read_offsets
self.next_token_choosers = next_token_choosers
self.stopping_criterias = stopping_criterias
self.top_n_tokens = top_n_tokens
self.top_n_tokens_tensor = top_n_tokens_tensor
self.max_input_length = max_input_length
self.padding_right_offset = new_padding_right_offset
self.max_tokens = max_tokens
return self
@classmethod
def concatenate(cls, batches: List["MambaBatch"]) -> "MambaBatch":
# Used for padding
total_batch_size = 0
max_input_length = 0
padding_right_offset = 0
for batch in batches:
total_batch_size += len(batch)
max_input_length = max(max_input_length, batch.max_input_length)
padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
# Batch attributes
requests = []
requests_idx_mapping = {}
input_lengths = []
prefix_offsets = []
read_offsets = []
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
max_tokens = 0
# Batch tensors
input_ids = None
attention_mask = None
position_ids = None
past_key_values = []
top_n_tokens_tensor = None
# Used for slicing correctly inside the tensors
# Equivalent to a cumsum on batch sizes
start_index = 0
for i, batch in enumerate(batches):
requests.extend(batch.requests)
input_lengths.extend(batch.input_lengths)
prefix_offsets.extend(batch.prefix_offsets)
read_offsets.extend(batch.read_offsets)
all_input_ids.extend(batch.all_input_ids)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
top_n_tokens.extend(batch.top_n_tokens)
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
else:
# We need to offset the mapping for each batch by the cumulative batch size
for k, v in batch.requests_idx_mapping.items():
requests_idx_mapping[k] = v + start_index
# Slicing end index for this batch
end_index = start_index + len(batch)
# We only concatenate batches that did at least one step
if batch.past_key_values is None:
raise ValueError("only concatenate prefilled batches")
# Create empty tensor
# input_ids is always of shape [batch_size, 1]
# We do not need to pad it
if input_ids is None:
input_ids = batch.input_ids.new_empty((total_batch_size, 1))
# Copy to correct indices
input_ids[start_index:end_index] = batch.input_ids
# Create padded tensor
if attention_mask is None:
attention_mask = batch.attention_mask.new_zeros(
(total_batch_size, max_input_length + padding_right_offset),
)
if top_n_tokens_tensor is None:
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size,
)
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
# Add eventual padding tokens that were added while concatenating
max_tokens += batch.max_tokens + (
max_input_length - batch.max_input_length
) * len(batch)
start_index = end_index
return cls(
batch_id=batches[0].batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
all_input_ids=all_input_ids,
input_lengths=input_lengths,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length,
padding_right_offset=padding_right_offset,
keys_head_dim_last=batches[0].keys_head_dim_last,
max_tokens=max_tokens,
)
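# NOTE: inference_params (the per-layer conv/ssm caches) is not carried over or
# merged here yet; a concatenated batch starts with inference_params=None.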
def __len__(self):
return len(self.requests)
class Mamba(Model):
def __init__(
@@ -99,8 +379,8 @@ class Mamba(Model):
)
@property
def batch_type(self) -> Type[CausalLMBatch]:
return MambaCausalLMBatch
def batch_type(self) -> Type[MambaBatch]:
return MambaBatch
def warmup(self, batch) -> Optional[int]:
# TODO: implement warmup for Mamba if needed
@@ -119,10 +399,50 @@ class Mamba(Model):
def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
start = time.time_ns()
input_ids = batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids
input_ids = batch.input_ids # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids
logits, past_input_ids = self.model(input_ids)[:2]
batch_size = input_ids.shape[0]
max_seqlen = input_ids.shape[1]
dtype = input_ids.dtype
# Inference params
seqlen_og = max_seqlen
inf_cache = {}
lengths_per_sample = torch.ones(batch_size, dtype=torch.int32, device=input_ids.device) * max_seqlen
if batch.inference_params is None:
inference_params = InferenceParams(
max_seqlen=max_seqlen,
max_batch_size=batch_size,
seqlen_offset=seqlen_og,
key_value_memory_dict=inf_cache,
lengths_per_sample=lengths_per_sample,
)
# Allocate inference cache
for res_block in self.model.blocks:
block = res_block.mamba_block
conv_state = torch.zeros(
batch_size,
self.model.config.d_model * self.model.config.expand,
self.model.config.d_conv,
device=block.conv1d.weight.device,
dtype=block.conv1d.weight.dtype,
)
ssm_state = torch.zeros(
batch_size,
self.model.config.d_model * self.model.config.expand,
self.model.config.d_state,
device=block.dt_proj.weight.device,
dtype=block.dt_proj.weight.dtype,
)
inference_params.key_value_memory_dict[block.layer_idx] = (conv_state, ssm_state)
batch.inference_params = inference_params
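# The zero-initialized states above are keyed by layer_idx and reach each MambaBlock
# through inference_params; step() updates them in place during decoding.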
# Forward pass
logits, past_input_ids, new_inference_params = self.model(input_ids, batch.inference_params)
batch.inference_params = new_inference_params
# Results
generations: List[Generation] = []
stopped = True
@@ -272,6 +592,7 @@ class Mamba(Model):
generations.append(generation)
# Update values
batch.input_ids[i, 0] = 0 # next_token_id
batch.all_input_ids[i] = all_input_ids
batch.input_lengths[i] = new_input_length
batch.prefix_offsets[i] = prefix_offset
@@ -284,6 +605,10 @@ class Mamba(Model):
decode_ns = time.time_ns() - start_decode
return generations, None, (forward_ns, decode_ns)
# Slice unused values from prefill
batch.input_ids = batch.input_ids[:, :1]
batch.past_input_ids = past_input_ids
forward_ns = start_decode - start