feat: mvp single inference and explore integration

This commit is contained in:
drbh 2024-01-24 20:55:12 -05:00
parent bcf145733c
commit 35939a28c7
2 changed files with 196 additions and 80 deletions

View File

@ -1,57 +1,33 @@
import torch
import torch.distributed
import math
from torch import nn
from typing import Optional, List, Tuple, Any
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
import torch.nn.functional as F
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelHead,
FastLinear,
FastRMSNorm,
)
class MambaConfig(PretrainedConfig):
def __init__(
self,
vocab_size=51200,
n_positions=2048,
n_embd=2560,
vocab_size=50280,
n_layer=32,
n_inner=None,
n_head=32,
rotary_dim=32,
layer_norm_epsilon=1e-5,
tie_word_embeddings=False,
pad_vocab_size_multiple=64,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
no_bias=False,
rms_norm_eps=1e-8,
**kwargs,
):
self.vocab_size = vocab_size
self.n_positions = n_positions
self.n_embd = n_embd
self.n_layer = n_layer
self.n_inner = n_inner
self.n_head = n_head
self.rotary_dim = rotary_dim
self.layer_norm_epsilon = layer_norm_epsilon
self.tie_word_embeddings = tie_word_embeddings
self.pad_vocab_size_multiple = pad_vocab_size_multiple
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.no_bias = no_bias
self.rms_norm_eps = rms_norm_eps
super().__init__(
pad_token_id=pad_token_id,
@ -61,18 +37,29 @@ class MambaConfig(PretrainedConfig):
**kwargs,
)
class MambaBlock(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
# TODO: adjust how weights are loaded
# conv1d 768*2, 768*2, 4
self.conv1 = nn.Conv1d(768, 768, 4)
# add weight and bias to conv1
self.conv1.weight = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.weight").transpose(0, 1))
# TODO: use model config to set the dt_rank instead of hardcoding it
d_inner = 768 * 2
d_conv = 4
self.dt_rank = (768 + 15) // 16
# TODO: improve how we load the conv1d weights
# explore a transposed conv1d that avoids the need for
# a transpose during inference
self.conv1 = nn.Conv1d(
d_inner,
d_inner,
kernel_size=d_conv,
groups=d_inner,
padding=d_conv - 1,
)
self.conv1.weight = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.weight"))
self.conv1.bias = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.bias"))
# TODO: load weights in correctly for other operations
self.dt_proj = TensorParallelColumnLinear.load(
config=config,
prefix=f"{prefix}.dt_proj",
@ -91,45 +78,136 @@ class MambaBlock(nn.Module):
weights=weights,
bias=False,
)
self.A_log = nn.Parameter(torch.randn(config.n_head, config.n_head, config.rotary_dim))
self.D = nn.Parameter(torch.randn(config.n_head, config.rotary_dim))
self.out_proj = TensorParallelColumnLinear.load(
config=config,
prefix=f"{prefix}.out_proj",
weights=weights,
bias=False,
)
def forward(
self,
hidden_states,
past_kv_cache,
attention_mask=None,
# TODO: improve how we load the weights
self.A_log = nn.Parameter(weights.get_tensor(f"{prefix}.A_log"))
self.D = nn.Parameter(weights.get_tensor(f"{prefix}.D"))
def selective_scan(
self, input_tensor, delta, a_tensor, b_tensor, c_tensor, d_tensor
):
hidden_states_in_proj = self.in_proj(hidden_states)
hidden_states_and_residual = torch.chunk(hidden_states_in_proj, 2, dim=-1)
batch_size, sequence_length, input_dim = input_tensor.shape
num_cols = a_tensor.shape[1]
hs, res = hidden_states_and_residual[0], hidden_states_and_residual[1]
# TODO: revisit this math to avoid the transposes when possible
# reshape and process delta
delta = delta.transpose(1, 2).view((batch_size, input_dim, sequence_length, 1))
exp_delta_a = (delta * a_tensor.view((1, input_dim, 1, num_cols))).exp()
# calc involving delta, b_tensor, and input_tensor
delta_b_input = (
delta
* b_tensor.view((batch_size, 1, sequence_length, num_cols))
* input_tensor.transpose(1, 2).view(
(batch_size, input_dim, sequence_length, 1)
)
)
# init output tensor
output_tensor = torch.zeros(
(batch_size, input_dim, num_cols),
dtype=exp_delta_a.dtype,
device=exp_delta_a.device,
)
# iterate over sequence_length
output_sequence = []
for i in range(sequence_length):
multiplier = exp_delta_a[:, :, i]
output_tensor = (multiplier * output_tensor) + delta_b_input[:, :, i]
y = output_tensor.matmul(c_tensor[:, i, :].unsqueeze(2)).squeeze(2)
output_sequence.append(y)
stacked_output = torch.stack(output_sequence, 1)
return stacked_output + input_tensor * d_tensor
def ssm(self, hidden_states):
_input_dim, num_cols = self.A_log.shape
negative_exponential_a = self.A_log.exp().neg()
d_matrix = self.D
projected_hidden_states = self.x_proj(hidden_states)
# narrow operations for delta, b, and c
delta = projected_hidden_states.narrow(-1, 0, self.dt_rank)
b_matrix = projected_hidden_states.narrow(-1, self.dt_rank, num_cols)
c_matrix = projected_hidden_states.narrow(-1, self.dt_rank + num_cols, num_cols)
# process delta
delta = self.dt_proj(delta)
delta = torch.log(torch.exp(delta) + 1)
# apply selective scan
selective_scan_output = self.selective_scan(
hidden_states, delta, negative_exponential_a, b_matrix, c_matrix, d_matrix
)
return selective_scan_output
def forward(self, hidden_states):
sequence_length = hidden_states.shape[1]
projected_states = self.in_proj(hidden_states)
split_states = torch.chunk(projected_states, 2, dim=-1)
transformed_states, residual_states = split_states
# TODO: avoid the transpose by using a transposed conv1d
# apply convolution and narrowing operation
conv_output = (
self.conv1(transformed_states.transpose(1, 2))
.narrow(-1, 0, sequence_length)
.transpose(1, 2)
)
# apply silu (Swish) activation function
activated_transformed = F.silu(conv_output)
activated_residual = F.silu(residual_states)
# Subsequent operations
output = self.ssm(activated_transformed)
combined_output = output * activated_residual
return self.out_proj(combined_output)
# TODO: prefer a more optimized implementation of RMSNorm if possible
class RMSNorm(nn.Module):
def __init__(self, num_features, eps=1e-8):
super().__init__()
self.num_features = num_features
self.eps = eps
self.scale = nn.Parameter(torch.ones(num_features))
def forward(self, x):
rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
x = x / rms
return self.scale * x
import ipdb; ipdb.set_trace()
class ResidualBlock(nn.Module):
def __init__(self, layer_id, config, weights):
super().__init__()
self.layer_id = layer_id
self.mixer = MambaBlock(prefix=f"{layer_id}.mixer", config=config, weights=weights)
self.layer_norm = FastLinear.load(
config=config,
prefix=f"{layer_id}.norm",
weights=weights,
bias=False,
self.mamba_block = MambaBlock(
prefix=f"{layer_id}.mixer", config=config, weights=weights
)
self.layer_norm = RMSNorm(768, eps=config.layer_norm_epsilon)
self.layer_norm.scale = nn.Parameter(
weights.get_tensor(f"{layer_id}.norm.weight")
)
def forward(
self,
hidden_states,
kv_cache,
attention_mask,
):
):
residual = hidden_states
hidden_states = self.layer_norm(hidden_states)
attn_outputs, past_kv_cache = self.mixer(hidden_states, kv_cache, attention_mask)
attn_outputs = self.mamba_block(hidden_states)
hidden_states = residual + attn_outputs
return hidden_states, residual
return hidden_states
class MambaModel(nn.Module):
def __init__(self, config, weights):
@ -138,38 +216,45 @@ class MambaModel(nn.Module):
self.tp_world_size = weights.process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix="backbone.embedding", weights=weights
)
)
self.blocks = nn.ModuleList(
[ResidualBlock(f"backbone.layers.{layer_id}", config, weights) for layer_id in range(config.n_layer)]
[
ResidualBlock(f"backbone.layers.{layer_id}", config, weights)
for layer_id in range(config.n_layer)
]
)
self.norm_f = FastRMSNorm.load(
prefix="backbone.norm_f",
weights=weights,
eps=config.rms_norm_eps
# TODO: avoid hardcoded sizes and improve how we load the weights
self.norm_f = RMSNorm(768, eps=config.layer_norm_epsilon)
self.norm_f.scale = nn.Parameter(weights.get_tensor(f"backbone.norm_f.weight"))
# use the same weights for the embedding and the final layer norm
self.lm_head = nn.Linear(768, config.vocab_size, bias=False)
self.lm_head.weight = nn.Parameter(
self.embed_tokens.weight[: config.vocab_size, :]
)
print("🌈 model init done")
def forward(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.ByteTensor] = None,
return_dict: Optional[bool] = None,
use_cache: Optional[bool] = None,
past_input_ids: Optional[List[Tuple[torch.FloatTensor]]] = None,
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
# TODO: dont use past_input_ids for the input_ids
# find a way to cache previous states/work
if past_input_ids is not None:
# append the contents to the input_ids
input_ids = torch.cat((past_input_ids, input_ids), dim=1)
hidden_states = self.embed_tokens(input_ids)
seq_len = hidden_states.shape[1]
mask = None if seq_len <= 1 else attention_mask
past_key_values = [None] * len(self.blocks) if past_key_values is None else past_key_values
for _, block in enumerate(self.blocks):
hidden_states = block(hidden_states)
for index, block in enumerate(self.blocks):
hidden_states, new_key_values = block(hidden_states, past_key_values[index], mask)
past_key_values[index] = new_key_values
final_hidden_states = self.norm_f(hidden_states)
after_lm_head = self.lm_head(final_hidden_states)
return after_lm_head, input_ids
hidden_states = self.norm_f(hidden_states)
return hidden_states, past_key_values
# TODO: revisit if we want to use CausalLM
class MambaForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
@ -178,13 +263,24 @@ class MambaForCausalLM(torch.nn.Module):
def forward(
self,
input_ids: torch.LongTensor,
# TODO: dont abuse past_key_values for the input_ids
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
# below are unused since this model is attention free
attention_mask: Optional[torch.ByteTensor] = None,
return_dict: Optional[bool] = None,
use_cache: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
model_output = self.model(
input_ids, past_key_values, attention_mask, return_dict, use_cache
input_ids,
past_input_ids=past_key_values,
)
logits = model_output[0]
past_hidden_states = model_output[1]
return CausalLMOutputWithPast(
loss=None,
logits=logits,
past_key_values=past_hidden_states,
hidden_states=None,
attentions=None,
)
print("🌈 model output done")

View File

@ -1,17 +1,37 @@
import torch
import torch.distributed
from transformers import AutoConfig, AutoTokenizer
from typing import Optional, List, Tuple
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from typing import Optional
from text_generation_server.models import CausalLM
from text_generation_server.models.custom_modeling.mamba_modeling import MambaConfig, MambaForCausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.models.custom_modeling.mamba_modeling import (
MambaConfig,
MambaForCausalLM,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
class MambaCausalLMBatch(CausalLMBatch):
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "CausalLMBatch":
batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
batch.keys_head_dim_last = False
return batch
class Mamba(CausalLM):
def __init__(
self,
@ -59,4 +79,4 @@ class Mamba(CausalLM):
requires_padding=True,
dtype=dtype,
device=device,
)
)