import math
from dataclasses import dataclass
from typing import Optional, Tuple, Any

import torch
import torch.distributed
import torch.nn.functional as F
from torch import nn

from einops import rearrange
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
from transformers.configuration_utils import PretrainedConfig

from text_generation_server.utils.layers import (
    SpeculativeHead,
    TensorParallelEmbedding,
    FastRMSNorm,
    FastLinear,
)


@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""

    max_seqlen: int
    max_batch_size: int
    conv_states: torch.Tensor
    ssm_states: torch.Tensor
    seqlen_offset: int

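
# Note (added; an inference of shapes, not stated in this file): the blocks below index
# these caches per layer and pass them to causal_conv1d_update / selective_state_update,
# which implies buffers shaped roughly as
#   conv_states: (n_layer, batch_size, d_inner, d_conv)
#   ssm_states:  (n_layer, batch_size, d_inner, d_state)
# See the usage sketch at the end of this module for how they might be allocated.

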
class MambaConfig(PretrainedConfig):
    def __init__(
        self,
        vocab_size=50280,
        d_model=768,
        d_state=16,
        n_layer=32,
        layer_norm_epsilon=1e-5,
        tie_word_embeddings=False,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        expand=2,
        dt_rank="auto",
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.layer_norm_epsilon = layer_norm_epsilon
        self.d_model = d_model
        # Inner (expanded) dimension; fixed at 2 * d_model here, the expand argument
        # is stored below but not applied in this computation.
        self.d_inner = d_model * 2
        # Kernel width of the depthwise causal convolution.
        self.d_conv = 4
        self.d_state = d_state
        self.expand = expand
        # "auto" follows the Mamba convention of ceil(d_model / 16) for the delta projection rank.
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

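
# Worked example of the config arithmetic above (illustrative): with the defaults,
# dt_rank="auto" resolves to math.ceil(768 / 16) == 48 and d_inner == 2 * 768 == 1536.

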
class MambaBlock(nn.Module):
    def __init__(self, prefix, config, weights, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.in_proj = FastLinear.load(config, f"{prefix}.in_proj", weights, bias=False)
        self.x_proj = FastLinear.load(config, f"{prefix}.x_proj", weights, bias=False)
        self.dt_proj = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=True)
        self.dt_proj_no_bias = FastLinear.load(
            config, f"{prefix}.dt_proj", weights, bias=False
        )
        self.out_proj = FastLinear.load(
            config, f"{prefix}.out_proj", weights, bias=False
        )
        self.conv1d = FastLinear.load(config, f"{prefix}.conv1d", weights, bias=True)
        # The checkpoint stores A_log; the state matrix used by the scan kernels is A = -exp(A_log).
        self.negA = -torch.exp(weights.get_tensor(f"{prefix}.A_log").float())
        # D is the learned skip-connection scale applied inside the selective scan.
        self.D = weights.get_tensor(f"{prefix}.D")
        self.activation = "silu"
        self.dt_rank = config.dt_rank
        self.d_state = config.d_state
        self.d_conv = config.d_conv
        self.act = nn.SiLU()

    def forward(self, hidden_states: torch.Tensor, inference_params=None):
        # Decode path: once tokens have already been processed (seqlen_offset > 0),
        # run the cheap single-token update against the cached states.
        # inference_params is required in practice despite the None default.
        if inference_params.seqlen_offset > 0:
            conv_state = inference_params.conv_states[self.layer_id]
            ssm_state = inference_params.ssm_states[self.layer_id]
            out, conv_state, ssm_state = self.step(hidden_states, conv_state, ssm_state)
            return out, conv_state, ssm_state

        # Prefill path: process the whole prompt with the fused kernels.
        _, seqlen, _ = hidden_states.shape
        projected_states = self.in_proj(hidden_states).transpose(1, 2)
        # assert projected_states.shape == [batch_size, 2 * dstate, seqlen], f"{projected_states.shape} [{batch_size}, {dstate}, {seqlen}]"
        x, z = projected_states.chunk(2, dim=1)
        # Cache the last d_conv timesteps of x for later decode steps
        # (negative padding truncates when seqlen > d_conv).
        conv_state = F.pad(x, (self.d_conv - seqlen, 0))
        x = causal_conv1d_fn(
            x=x,
            weight=self.conv1d.weight.squeeze(1),
            bias=self.conv1d.bias,
            activation=self.activation,
        )

        # We're careful here about the layout, to avoid extra transposes.
        # We want dt to have d as the slowest moving dimension
        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
        x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
        dt, B, C = torch.split(
            x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1
        )
        dt = self.dt_proj.weight @ dt.t()
        dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
        B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
        C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
        y, last_state = selective_scan_fn(
            x,
            dt,
            self.negA,
            B,
            C,
            self.D.float(),
            z=z,
            delta_bias=self.dt_proj.bias.float(),
            delta_softplus=True,
            return_last_state=True,
        )
        y = rearrange(y, "b d l -> b l d")
        attn_outputs = self.out_proj(y)
        return attn_outputs, conv_state, last_state

    def step(self, hidden_states, conv_state, ssm_state):
        # Single-token decode: hidden_states is (batch, 1, d_model); the conv and ssm
        # caches are updated in place by the fused update kernels.
        xz = self.in_proj(hidden_states.squeeze(1))
        x, z = xz.chunk(2, dim=-1)  # (B D)
        x = causal_conv1d_update(
            x,
            conv_state,
            self.conv1d.weight.squeeze(1),
            self.conv1d.bias,
            self.activation,
        )
        x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
        dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
        # Bias-free delta projection; the bias is applied inside selective_state_update via dt_bias.
        dt = F.linear(dt, self.dt_proj.weight)
        A = self.negA
        y = selective_state_update(
            ssm_state,
            x,
            dt,
            A,
            B,
            C,
            self.D,
            z=z,
            dt_bias=self.dt_proj.bias,
            dt_softplus=True,
        )
        out = self.out_proj(y)
        return out.unsqueeze(1), conv_state.clone(), ssm_state.clone()

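
# For reference (illustrative; paraphrased from the unfused reference implementation in
# mamba_ssm, not part of this file), a single selective_state_update step computes roughly:
#   dt = softplus(dt + dt_bias)
#   dA = exp(dt[..., None] * A)                        # (batch, d_inner, d_state)
#   dB = dt[..., None] * B[:, None, :]                 # (batch, d_inner, d_state)
#   ssm_state = ssm_state * dA + dB * x[..., None]     # updated in place
#   y = einsum("bdn,bn->bd", ssm_state, C) + D * x
#   y = y * silu(z)

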
class ResidualBlock(nn.Module):
    def __init__(self, prefix, config, weights, layer_id):
        super().__init__()
        self.mamba_block = MambaBlock(
            prefix=f"{prefix}.mixer", config=config, weights=weights, layer_id=layer_id
        )
        self.layer_norm = FastRMSNorm.load(
            prefix=f"{prefix}.norm", weights=weights, eps=config.layer_norm_epsilon
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
        inference_params: Optional[Any] = None,
    ):
        # Pre-norm residual wiring: fold the previous block's output into the residual
        # stream, normalize, then run the Mamba mixer.
        residual = (hidden_states + residual) if residual is not None else hidden_states
        shape = residual.shape
        hidden_states, _ = self.layer_norm(residual.view(-1, shape[-1]))
        hidden_states, conv_state, last_ssm_state = self.mamba_block(
            hidden_states.view(*shape), inference_params
        )
        return hidden_states, residual, conv_state, last_ssm_state


class MambaModel(nn.Module):
    def __init__(self, config, weights):
        super().__init__()
        prefix = "backbone"
        self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights)
        self.blocks = nn.ModuleList(
            [
                ResidualBlock(f"{prefix}.layers.{i}", config, weights, layer_id=i)
                for i in range(config.n_layer)
            ]
        )
        self.norm_f = FastRMSNorm.load(
            f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon
        )
        # The LM head is loaded from the same tensor as the input embedding.
        self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights)
        self.config = config

    def forward(
        self, input_ids: torch.Tensor, inference_params=None, residual=None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        hidden_states = self.embed_tokens(input_ids)
        for i, block in enumerate(self.blocks):
            hidden_states, residual, conv_state, ssm_state = block(
                hidden_states, residual, inference_params
            )
            # Persist each layer's updated caches for the next forward pass.
            inference_params.conv_states[i].copy_(conv_state)
            inference_params.ssm_states[i].copy_(ssm_state)

        hidden_states = (
            hidden_states + residual if residual is not None else hidden_states
        )
        hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1)))
        hidden_states = hidden_states.view(residual.shape)
        logits, speculative_logits = self.lm_head(hidden_states)

        # update the offset for the next inference using these params
        inference_params.seqlen_offset += input_ids.size(1)
        return logits, speculative_logits
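
# Rough usage sketch (illustrative only; in the server this wiring lives in the model
# runner, and names such as `weights`, `device`, `dtype`, `model_id`, `input_ids`, and
# `next_token_id` are assumptions here, not defined in this file):
#
#   config = MambaConfig.from_pretrained(model_id)
#   model = MambaModel(config, weights)
#   batch_size, max_seqlen = input_ids.shape[0], 2048
#   inference_params = InferenceParams(
#       max_seqlen=max_seqlen,
#       max_batch_size=batch_size,
#       conv_states=torch.zeros(
#           config.n_layer, batch_size, config.d_inner, config.d_conv,
#           device=device, dtype=dtype,
#       ),
#       ssm_states=torch.zeros(
#           config.n_layer, batch_size, config.d_inner, config.d_state,
#           device=device, dtype=dtype,
#       ),
#       seqlen_offset=0,
#   )
#   # Prefill over the prompt, then decode one token at a time; once seqlen_offset > 0,
#   # each MambaBlock takes its step() path against the cached states.
#   logits, _ = model(input_ids, inference_params)
#   next_logits, _ = model(next_token_id.unsqueeze(1), inference_params)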