mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 15:32:08 +00:00

commit 35939a28c7 (parent bcf145733c)

    feat: mvp single inference and explore integration
@@ -1,57 +1,33 @@
 import torch
 import torch.distributed
 
-import math
 from torch import nn
 from typing import Optional, List, Tuple, Any
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_outputs import CausalLMOutputWithPast
+import torch.nn.functional as F
 
 from text_generation_server.utils.layers import (
-    TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    TensorParallelHead,
-    FastLinear,
-    FastRMSNorm,
 )
 
 
 class MambaConfig(PretrainedConfig):
     def __init__(
         self,
-        vocab_size=51200,
-        n_positions=2048,
-        n_embd=2560,
+        vocab_size=50280,
         n_layer=32,
-        n_inner=None,
-        n_head=32,
-        rotary_dim=32,
         layer_norm_epsilon=1e-5,
         tie_word_embeddings=False,
-        pad_vocab_size_multiple=64,
         pad_token_id=0,
         bos_token_id=1,
         eos_token_id=2,
-        no_bias=False,
-        rms_norm_eps=1e-8,
         **kwargs,
     ):
         self.vocab_size = vocab_size
-        self.n_positions = n_positions
-        self.n_embd = n_embd
         self.n_layer = n_layer
-        self.n_inner = n_inner
-        self.n_head = n_head
-        self.rotary_dim = rotary_dim
 
         self.layer_norm_epsilon = layer_norm_epsilon
-        self.tie_word_embeddings = tie_word_embeddings
-        self.pad_vocab_size_multiple = pad_vocab_size_multiple
-        self.pad_token_id = pad_token_id
-        self.bos_token_id = bos_token_id
-        self.eos_token_id = eos_token_id
-        self.no_bias = no_bias
-        self.rms_norm_eps = rms_norm_eps
 
         super().__init__(
             pad_token_id=pad_token_id,
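For orientation, a minimal sketch of constructing the trimmed config with the defaults shown in this hunk (the import path is the one used by the server file further below; no extra kwargs are assumed to be required):

    from text_generation_server.models.custom_modeling.mamba_modeling import MambaConfig

    config = MambaConfig()               # defaults from the hunk above
    assert config.vocab_size == 50280
    assert config.n_layer == 32
    assert config.layer_norm_epsilon == 1e-5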
@@ -61,18 +37,29 @@ class MambaConfig(PretrainedConfig):
             **kwargs,
         )
 
 
 class MambaBlock(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
-        # TODO: adjust how weights are loaded
 
-        # conv1d 768*2, 768*2, 4
-        self.conv1 = nn.Conv1d(768, 768, 4)
-        # add weight and bias to conv1
-        self.conv1.weight = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.weight").transpose(0, 1))
+        # TODO: use model config to set the dt_rank instead of hardcoding it
+        d_inner = 768 * 2
+        d_conv = 4
+        self.dt_rank = (768 + 15) // 16
+
+        # TODO: improve how we load the conv1d weights
+        # explore a transposed conv1d that avoids the need for
+        # a transpose during inference
+        self.conv1 = nn.Conv1d(
+            d_inner,
+            d_inner,
+            kernel_size=d_conv,
+            groups=d_inner,
+            padding=d_conv - 1,
+        )
+        self.conv1.weight = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.weight"))
         self.conv1.bias = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.bias"))
 
-        # TODO: load weights in correctly for other operations
         self.dt_proj = TensorParallelColumnLinear.load(
             config=config,
             prefix=f"{prefix}.dt_proj",
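The replacement conv1d is depthwise (groups=d_inner) with padding of d_conv - 1, and the forward pass added in the next hunk trims the output back to sequence_length with .narrow(-1, 0, sequence_length); together that makes the convolution causal along the time axis. A standalone sketch of that shape behaviour, with random tensors rather than checkpoint weights:

    import torch
    import torch.nn as nn

    d_inner, d_conv, seq_len = 768 * 2, 4, 10   # sizes hardcoded in the block above
    conv = nn.Conv1d(d_inner, d_inner, kernel_size=d_conv, groups=d_inner, padding=d_conv - 1)
    x = torch.randn(1, d_inner, seq_len)        # (batch, channels, time)
    y = conv(x)                                 # (1, d_inner, seq_len + d_conv - 1) due to padding
    y = y.narrow(-1, 0, seq_len)                # drop the trailing positions -> causal output
    assert y.shape == (1, d_inner, seq_len)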
@@ -91,45 +78,136 @@ class MambaBlock(nn.Module):
             weights=weights,
             bias=False,
         )
-        self.A_log = nn.Parameter(torch.randn(config.n_head, config.n_head, config.rotary_dim))
-        self.D = nn.Parameter(torch.randn(config.n_head, config.rotary_dim))
-
-    def forward(
-        self,
-        hidden_states,
-        past_kv_cache,
-        attention_mask=None,
+        self.out_proj = TensorParallelColumnLinear.load(
+            config=config,
+            prefix=f"{prefix}.out_proj",
+            weights=weights,
+            bias=False,
+        )
+
+        # TODO: improve how we load the weights
+        self.A_log = nn.Parameter(weights.get_tensor(f"{prefix}.A_log"))
+        self.D = nn.Parameter(weights.get_tensor(f"{prefix}.D"))
+
+    def selective_scan(
+        self, input_tensor, delta, a_tensor, b_tensor, c_tensor, d_tensor
     ):
-        hidden_states_in_proj = self.in_proj(hidden_states)
-        hidden_states_and_residual = torch.chunk(hidden_states_in_proj, 2, dim=-1)
-
-        hs, res = hidden_states_and_residual[0], hidden_states_and_residual[1]
+        batch_size, sequence_length, input_dim = input_tensor.shape
+        num_cols = a_tensor.shape[1]
+
+        # TODO: revisit this math to avoid the transposes when possible
+        # reshape and process delta
+        delta = delta.transpose(1, 2).view((batch_size, input_dim, sequence_length, 1))
+        exp_delta_a = (delta * a_tensor.view((1, input_dim, 1, num_cols))).exp()
+
+        # calc involving delta, b_tensor, and input_tensor
+        delta_b_input = (
+            delta
+            * b_tensor.view((batch_size, 1, sequence_length, num_cols))
+            * input_tensor.transpose(1, 2).view(
+                (batch_size, input_dim, sequence_length, 1)
+            )
+        )
+
+        # init output tensor
+        output_tensor = torch.zeros(
+            (batch_size, input_dim, num_cols),
+            dtype=exp_delta_a.dtype,
+            device=exp_delta_a.device,
+        )
+
+        # iterate over sequence_length
+        output_sequence = []
+        for i in range(sequence_length):
+            multiplier = exp_delta_a[:, :, i]
+            output_tensor = (multiplier * output_tensor) + delta_b_input[:, :, i]
+            y = output_tensor.matmul(c_tensor[:, i, :].unsqueeze(2)).squeeze(2)
+            output_sequence.append(y)
+
+        stacked_output = torch.stack(output_sequence, 1)
+        return stacked_output + input_tensor * d_tensor
+
+    def ssm(self, hidden_states):
+        _input_dim, num_cols = self.A_log.shape
+        negative_exponential_a = self.A_log.exp().neg()
+        d_matrix = self.D
+        projected_hidden_states = self.x_proj(hidden_states)
+
+        # narrow operations for delta, b, and c
+        delta = projected_hidden_states.narrow(-1, 0, self.dt_rank)
+        b_matrix = projected_hidden_states.narrow(-1, self.dt_rank, num_cols)
+        c_matrix = projected_hidden_states.narrow(-1, self.dt_rank + num_cols, num_cols)
+
+        # process delta
+        delta = self.dt_proj(delta)
+        delta = torch.log(torch.exp(delta) + 1)
+
+        # apply selective scan
+        selective_scan_output = self.selective_scan(
+            hidden_states, delta, negative_exponential_a, b_matrix, c_matrix, d_matrix
+        )
+        return selective_scan_output
+
+    def forward(self, hidden_states):
+        sequence_length = hidden_states.shape[1]
+        projected_states = self.in_proj(hidden_states)
+        split_states = torch.chunk(projected_states, 2, dim=-1)
+        transformed_states, residual_states = split_states
+
+        # TODO: avoid the transpose by using a transposed conv1d
+        # apply convolution and narrowing operation
+        conv_output = (
+            self.conv1(transformed_states.transpose(1, 2))
+            .narrow(-1, 0, sequence_length)
+            .transpose(1, 2)
+        )
+
+        # apply silu (Swish) activation function
+        activated_transformed = F.silu(conv_output)
+        activated_residual = F.silu(residual_states)
+
+        # Subsequent operations
+        output = self.ssm(activated_transformed)
+        combined_output = output * activated_residual
+
+        return self.out_proj(combined_output)
+
+
+# TODO: prefer a more optimized implementation of RMSNorm if possible
+class RMSNorm(nn.Module):
+    def __init__(self, num_features, eps=1e-8):
+        super().__init__()
+        self.num_features = num_features
+        self.eps = eps
+        self.scale = nn.Parameter(torch.ones(num_features))
+
+    def forward(self, x):
+        rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
+        x = x / rms
+        return self.scale * x
 
-import ipdb; ipdb.set_trace()
 
 
 class ResidualBlock(nn.Module):
     def __init__(self, layer_id, config, weights):
         super().__init__()
         self.layer_id = layer_id
-        self.mixer = MambaBlock(prefix=f"{layer_id}.mixer", config=config, weights=weights)
-        self.layer_norm = FastLinear.load(
-            config=config,
-            prefix=f"{layer_id}.norm",
-            weights=weights,
-            bias=False,
-        )
+        self.mamba_block = MambaBlock(
+            prefix=f"{layer_id}.mixer", config=config, weights=weights
+        )
+        self.layer_norm = RMSNorm(768, eps=config.layer_norm_epsilon)
+        self.layer_norm.scale = nn.Parameter(
+            weights.get_tensor(f"{layer_id}.norm.weight")
+        )
 
     def forward(
         self,
         hidden_states,
-        kv_cache,
-        attention_mask,
     ):
         residual = hidden_states
         hidden_states = self.layer_norm(hidden_states)
-        attn_outputs, past_kv_cache = self.mixer(hidden_states, kv_cache, attention_mask)
+        attn_outputs = self.mamba_block(hidden_states)
         hidden_states = residual + attn_outputs
-        return hidden_states, residual
+        return hidden_states
 
 
 class MambaModel(nn.Module):
     def __init__(self, config, weights):
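For reference, the loop in the new selective_scan realizes the discretized state-space recurrence h_t = exp(Δ_t A) · h_{t-1} + Δ_t B_t x_t with output y_t = C_t · h_t + D x_t, and the torch.log(torch.exp(delta) + 1) step in ssm is simply softplus. A compact standalone version of the same recurrence, using illustrative shapes and names rather than the exact tensors above:

    import torch

    def selective_scan_reference(x, delta, A, B, C, D):
        # x, delta: (batch, length, d); A: (d, n); B, C: (batch, length, n); D: (d,)
        batch, length, d = x.shape
        n = A.shape[1]
        h = torch.zeros(batch, d, n, dtype=x.dtype, device=x.device)
        ys = []
        for t in range(length):
            dt = delta[:, t].unsqueeze(-1)                    # (batch, d, 1)
            h = torch.exp(dt * A) * h + dt * B[:, t].unsqueeze(1) * x[:, t].unsqueeze(-1)
            ys.append((h * C[:, t].unsqueeze(1)).sum(-1))     # y_t = C_t · h_t, shape (batch, d)
        return torch.stack(ys, dim=1) + x * D                 # add the skip term D * x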
@@ -140,36 +218,43 @@ class MambaModel(nn.Module):
             prefix="backbone.embedding", weights=weights
         )
         self.blocks = nn.ModuleList(
-            [ResidualBlock(f"backbone.layers.{layer_id}", config, weights) for layer_id in range(config.n_layer)]
+            [
+                ResidualBlock(f"backbone.layers.{layer_id}", config, weights)
+                for layer_id in range(config.n_layer)
+            ]
         )
-        self.norm_f = FastRMSNorm.load(
-            prefix="backbone.norm_f",
-            weights=weights,
-            eps=config.rms_norm_eps
+
+        # TODO: avoid hardcoded sizes and improve how we load the weights
+        self.norm_f = RMSNorm(768, eps=config.layer_norm_epsilon)
+        self.norm_f.scale = nn.Parameter(weights.get_tensor(f"backbone.norm_f.weight"))
+        # use the same weights for the embedding and the final layer norm
+        self.lm_head = nn.Linear(768, config.vocab_size, bias=False)
+        self.lm_head.weight = nn.Parameter(
+            self.embed_tokens.weight[: config.vocab_size, :]
         )
-        print("🌈 model init done")
 
     def forward(
         self,
         input_ids: torch.LongTensor,
-        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
-        attention_mask: Optional[torch.ByteTensor] = None,
-        return_dict: Optional[bool] = None,
-        use_cache: Optional[bool] = None,
+        past_input_ids: Optional[List[Tuple[torch.FloatTensor]]] = None,
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
+        # TODO: dont use past_input_ids for the input_ids
+        # find a way to cache previous states/work
+        if past_input_ids is not None:
+            # append the contents to the input_ids
+            input_ids = torch.cat((past_input_ids, input_ids), dim=1)
+
         hidden_states = self.embed_tokens(input_ids)
-        seq_len = hidden_states.shape[1]
-        mask = None if seq_len <= 1 else attention_mask
-
-        past_key_values = [None] * len(self.blocks) if past_key_values is None else past_key_values
-
-        for index, block in enumerate(self.blocks):
-            hidden_states, new_key_values = block(hidden_states, past_key_values[index], mask)
-            past_key_values[index] = new_key_values
-
-        hidden_states = self.norm_f(hidden_states)
-        return hidden_states, past_key_values
+        for _, block in enumerate(self.blocks):
+            hidden_states = block(hidden_states)
+
+        final_hidden_states = self.norm_f(hidden_states)
+        after_lm_head = self.lm_head(final_hidden_states)
+        return after_lm_head, input_ids
 
 
+# TODO: revisit if we want to use CausalLM
 class MambaForCausalLM(torch.nn.Module):
     def __init__(self, config, weights):
         super().__init__()
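Because no recurrent state is cached yet, MambaModel.forward concatenates past_input_ids onto input_ids and re-processes the entire prefix on every call, so per-token cost grows with sequence length. A hypothetical greedy loop against this interim interface (model, prompt_ids and max_new_tokens are placeholders, not part of the diff):

    past_ids = None
    next_ids = prompt_ids                                      # (batch, prompt_len)
    for _ in range(max_new_tokens):
        logits, past_ids = model(next_ids, past_input_ids=past_ids)
        next_ids = logits[:, -1].argmax(dim=-1, keepdim=True)  # (batch, 1)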
@@ -178,13 +263,24 @@ class MambaForCausalLM(torch.nn.Module):
     def forward(
         self,
         input_ids: torch.LongTensor,
+        # TODO: dont abuse past_key_values for the input_ids
         past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
+        # below are unused since this model is attention free
         attention_mask: Optional[torch.ByteTensor] = None,
         return_dict: Optional[bool] = None,
         use_cache: Optional[bool] = None,
         labels: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         model_output = self.model(
-            input_ids, past_key_values, attention_mask, return_dict, use_cache
+            input_ids,
+            past_input_ids=past_key_values,
+        )
+        logits = model_output[0]
+        past_hidden_states = model_output[1]
+        return CausalLMOutputWithPast(
+            loss=None,
+            logits=logits,
+            past_key_values=past_hidden_states,
+            hidden_states=None,
+            attentions=None,
         )
-        print("🌈 model output done")
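At this level the running token ids are carried through CausalLMOutputWithPast.past_key_values, as the TODO notes. A hedged sketch of a single decode step through the wrapper (lm and input_ids are placeholders; attention_mask and use_cache are accepted but ignored):

    out = lm(input_ids=input_ids)
    next_id = out.logits[:, -1].argmax(dim=-1, keepdim=True)
    out = lm(input_ids=next_id, past_key_values=out.past_key_values)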
@@ -1,17 +1,37 @@
 import torch
 import torch.distributed
 
-from transformers import AutoConfig, AutoTokenizer
-from typing import Optional, List, Tuple
+from transformers import AutoTokenizer, PreTrainedTokenizerBase
+from typing import Optional
 
 from text_generation_server.models import CausalLM
-from text_generation_server.models.custom_modeling.mamba_modeling import MambaConfig, MambaForCausalLM
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.models.custom_modeling.mamba_modeling import (
+    MambaConfig,
+    MambaForCausalLM,
+)
+from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
     Weights,
 )
 
 
+class MambaCausalLMBatch(CausalLMBatch):
+    @classmethod
+    def from_pb(
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> "CausalLMBatch":
+        batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
+        batch.keys_head_dim_last = False
+        return batch
+
+
 class Mamba(CausalLM):
     def __init__(
         self,