text-generation-inference/server/text_generation_server/models/custom_modeling/mamba_modeling.py

import torch
import torch.distributed
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
from torch import nn
from typing import Optional, List, Tuple, Any
from transformers.configuration_utils import PretrainedConfig
import torch.nn.functional as F
from text_generation_server.utils.layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    FastRMSNorm,
)

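# Minimal configuration for Mamba checkpoints. The mixer dimensions are derived in
# __init__ rather than passed in as arguments: `d_inner` is the 2x expanded channel
# width of the mixer and `d_conv` the depthwise conv kernel size.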
class MambaConfig(PretrainedConfig):
    def __init__(
        self,
        vocab_size=50280,
        d_model=768,
        n_layer=32,
        layer_norm_epsilon=1e-5,
        tie_word_embeddings=False,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.layer_norm_epsilon = layer_norm_epsilon
        self.d_model = d_model
        self.d_inner = d_model * 2
        self.d_conv = 4

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

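# A single Mamba mixer. The input projection is loaded as a tensor-parallel layer;
# the remaining checkpoint tensors are registered as flat parameters and consumed by
# the fused `mamba_inner_fn` kernel (causal depthwise conv1d, selective SSM scan, and
# output projection in one call).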
class MambaBlock(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        self.in_proj = TensorParallelColumnLinear.load(
            config=config, prefix=f"{prefix}.in_proj", weights=weights, bias=False
        )
        # helper for loading weights
        self.load_weights(prefix, weights)

    def load_weights(self, prefix, weights):
        weight_names = [
            "x_proj.weight", "dt_proj.weight", "dt_proj.bias",
            "out_proj.weight", "in_proj.weight",
            "conv1d.weight", "conv1d.bias", "A_log", "D",
        ]
        for name in weight_names:
            # register each tensor as a flat parameter, e.g. "conv1d.weight" -> self.conv1d_weight
            param_name = name.replace(".", "_")
            setattr(
                self, param_name, nn.Parameter(weights.get_tensor(f"{prefix}.{name}"))
            )
        self.out_proj_bias = None
        # A is parameterized as -exp(A_log); precompute the negated exponential once
        self.negA = -torch.exp(self.A_log.float())
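    # `in_proj` expands (batch, seq_len, d_model) to 2 * d_inner channels (the x and z
    # streams); after the transpose below they are handed to the fused kernel in
    # channels-first (batch, 2 * d_inner, seq_len) layout.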
    def forward(self, hidden_states: torch.Tensor):
        projected_states = self.in_proj(hidden_states).transpose(1, 2)

        # conv1d, ssm, and selective_scan are all fused into one kernel
        attn_outputs = mamba_inner_fn(
            projected_states,
            self.conv1d_weight,
            self.conv1d_bias,
            self.x_proj_weight,
            self.dt_proj_weight,
            self.out_proj_weight,
            self.out_proj_bias,
            self.negA,
            None,  # B is projected from x inside the fused kernel
            None,  # C is projected from x inside the fused kernel
            self.D.float(),
            delta_bias=self.dt_proj_bias.float(),
            delta_softplus=True,
        )
        return attn_outputs

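# Pre-norm residual wrapper: RMSNorm the hidden states, run the Mamba mixer on the
# normalized activations, and add the block input back in.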
class ResidualBlock(nn.Module):
    def __init__(self, layer_id, config, weights):
        super().__init__()
        self.mamba_block = MambaBlock(
            prefix=f"{layer_id}.mixer", config=config, weights=weights
        )
        self.layer_norm = FastRMSNorm.load(
            prefix=f"{layer_id}.norm", weights=weights, eps=config.layer_norm_epsilon
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        residual = hidden_states
        # flatten to (tokens, hidden) for the norm, then restore the original shape
        shape = hidden_states.shape
        hidden_states, _ = self.layer_norm(hidden_states.view(-1, shape[-1]))
        hidden_states = residual + self.mamba_block(hidden_states.view(*shape))
        return hidden_states

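# The full backbone: token embedding, a stack of `n_layer` residual Mamba blocks, a
# final RMSNorm, and an lm_head loaded from the embedding weights (tied embeddings).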
class MambaModel(nn.Module):
    def __init__(self, config, weights):
        super().__init__()
        self.tp_rank = weights.process_group.rank()
        self.tp_world_size = weights.process_group.size()
        prefix = "backbone"
        self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights)
        self.blocks = nn.ModuleList(
            [
                ResidualBlock(f"{prefix}.layers.{i}", config, weights)
                for i in range(config.n_layer)
            ]
        )
        self.norm_f = FastRMSNorm.load(
            f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon
        )
        self.lm_head = TensorParallelColumnLinear.load(
            config, f"{prefix}.embedding", weights, False
        )

    def forward(self, input_ids: torch.Tensor):
        hidden_states = self.embed_tokens(input_ids)
        for block in self.blocks:
            hidden_states = block(hidden_states)

        # flatten for the final norm, restore the shape, and project to vocab logits
        shape = hidden_states.shape
        final_hidden_states, _ = self.norm_f(hidden_states.view(-1, shape[-1]))
        return self.lm_head(final_hidden_states.view(*shape)), input_ids
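
# A minimal usage sketch (hypothetical; `weights` stands for the server's weight
# loader, which must expose `get_tensor` and a `process_group` as used above):
#
#     config = MambaConfig(vocab_size=50280, d_model=768, n_layer=32)
#     model = MambaModel(config, weights)
#     logits, _ = model(input_ids)  # input_ids: LongTensor of shape (batch, seq_len)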