import torch
import torch.distributed
import torch.nn.functional as F
from torch import nn
from typing import Optional, List, Tuple, Any

from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
from transformers.configuration_utils import PretrainedConfig

from text_generation_server.utils.layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    FastRMSNorm,
)


class MambaConfig(PretrainedConfig):
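    # Config carrying only the fields this implementation reads; d_inner and
    # d_conv are derived below (2 * d_model and a conv kernel of 4), matching
    # the reference Mamba defaults.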
    def __init__(
        self,
        vocab_size=50280,
        d_model=768,
        n_layer=32,
        layer_norm_epsilon=1e-5,
        tie_word_embeddings=False,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.layer_norm_epsilon = layer_norm_epsilon
        self.d_model = d_model
        self.d_inner = d_model * 2
        self.d_conv = 4

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


class MambaBlock(nn.Module):
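    # Core Mamba mixer: a tensor-parallel input projection followed by the
    # fused conv1d + selective-scan kernel from mamba_ssm.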
    def __init__(self, prefix, config, weights):
        super().__init__()
        self.in_proj = TensorParallelColumnLinear.load(
            config=config, prefix=f"{prefix}.in_proj", weights=weights, bias=False
        )
        # helper for loading weights
        self.load_weights(prefix, weights)

    def load_weights(self, prefix, weights):
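        # Copy the per-layer tensors straight out of the checkpoint as flat
        # attributes (dots in the names replaced with underscores) and
        # precompute -exp(A_log) once so forward() can reuse it.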
        weight_names = [
            "x_proj.weight",
            "dt_proj.weight",
            "dt_proj.bias",
            "out_proj.weight",
            "in_proj.weight",
            "conv1d.weight",
            "conv1d.bias",
            "A_log",
            "D",
        ]
        for name in weight_names:
            param_name = name.replace(".", "_")
            setattr(
                self,
                param_name,
                nn.Parameter(weights.get_tensor(f"{prefix}.{name}")),
            )
        self.out_proj_bias = None
        self.negA = -torch.exp(self.A_log.float())

    def forward(self, hidden_states: torch.Tensor):
        projected_states = self.in_proj(hidden_states).transpose(1, 2)
        # conv1d, ssm, and selective_scan are all fused into one kernel
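        # Argument order follows mamba_inner_fn's positional signature in the
        # mamba_ssm release this targets: the projected x/z states, conv1d
        # weight/bias, x_proj/dt_proj/out_proj weights, A (stored as
        # -exp(A_log)) and D; the two Nones are B and C, which the kernel
        # derives from the x_proj output.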
        attn_outputs = mamba_inner_fn(
            projected_states,
            self.conv1d_weight,
            self.conv1d_bias,
            self.x_proj_weight,
            self.dt_proj_weight,
            self.out_proj_weight,
            self.out_proj_bias,
            self.negA,
            None,
            None,
            self.D.float(),
            delta_bias=self.dt_proj_bias.float(),
            delta_softplus=True,
        )
        return attn_outputs


class ResidualBlock(nn.Module):
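    # Pre-norm residual wrapper: RMSNorm the hidden states, run the Mamba
    # mixer on the normalized copy, then add the result back onto the
    # untouched residual.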
    def __init__(self, layer_id, config, weights):
        super().__init__()
        self.mamba_block = MambaBlock(
            prefix=f"{layer_id}.mixer", config=config, weights=weights
        )
        self.layer_norm = FastRMSNorm.load(
            prefix=f"{layer_id}.norm", weights=weights, eps=config.layer_norm_epsilon
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        residual = hidden_states
        shape = hidden_states.shape
        hidden_states, _ = self.layer_norm(hidden_states.view(-1, shape[-1]))
        hidden_states = residual + self.mamba_block(hidden_states.view(*shape))
        return hidden_states


class MambaModel(nn.Module):
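    # Full model: token embedding, n_layer residual Mamba blocks, a final
    # RMSNorm, and an lm_head loaded from the same tensor as the embedding
    # (tied weights).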
    def __init__(self, config, weights):
        super().__init__()
        self.tp_rank = weights.process_group.rank()
        self.tp_world_size = weights.process_group.size()
        prefix = "backbone"

        self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights)
        self.blocks = nn.ModuleList(
            [
                ResidualBlock(f"{prefix}.layers.{i}", config, weights)
                for i in range(config.n_layer)
            ]
        )
        self.norm_f = FastRMSNorm.load(
            f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon
        )
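        # lm_head reuses the embedding tensor, so output logits share weights
        # with the input embedding.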
        self.lm_head = TensorParallelColumnLinear.load(
            config, f"{prefix}.embedding", weights, False
        )

    def forward(self, input_ids: torch.Tensor):
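        # No conv/ssm state cache is threaded through here: each call re-runs
        # the full input_ids sequence through every block.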
        hidden_states = self.embed_tokens(input_ids)
        for block in self.blocks:
            hidden_states = block(hidden_states)

        shape = hidden_states.shape
        final_hidden_states, _ = self.norm_f(hidden_states.view(-1, shape[-1]))
        return self.lm_head(final_hidden_states.view(*shape)), input_ids
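
# Usage sketch (hypothetical wiring; in the real server this is set up by the
# model-loading path): given a `weights` loader that exposes get_tensor(name)
# and a process_group, as used above, the model is built and called as
#
#   model = MambaModel(MambaConfig(), weights)
#   logits, _ = model.forward(input_ids)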