feat: mvp single inference and explore integration

drbh 2024-01-24 20:55:12 -05:00
parent bcf145733c
commit 35939a28c7
2 changed files with 196 additions and 80 deletions

View File

@@ -1,57 +1,33 @@
 import torch
 import torch.distributed
-import math
 from torch import nn
 from typing import Optional, List, Tuple, Any
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_outputs import CausalLMOutputWithPast
+import torch.nn.functional as F
 from text_generation_server.utils.layers import (
-    TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    TensorParallelHead,
-    FastLinear,
-    FastRMSNorm,
 )


 class MambaConfig(PretrainedConfig):
     def __init__(
         self,
-        vocab_size=51200,
-        n_positions=2048,
-        n_embd=2560,
+        vocab_size=50280,
         n_layer=32,
-        n_inner=None,
-        n_head=32,
-        rotary_dim=32,
         layer_norm_epsilon=1e-5,
         tie_word_embeddings=False,
-        pad_vocab_size_multiple=64,
         pad_token_id=0,
         bos_token_id=1,
         eos_token_id=2,
-        no_bias=False,
-        rms_norm_eps=1e-8,
         **kwargs,
     ):
         self.vocab_size = vocab_size
-        self.n_positions = n_positions
-        self.n_embd = n_embd
         self.n_layer = n_layer
-        self.n_inner = n_inner
-        self.n_head = n_head
-        self.rotary_dim = rotary_dim
         self.layer_norm_epsilon = layer_norm_epsilon
-        self.tie_word_embeddings = tie_word_embeddings
-        self.pad_vocab_size_multiple = pad_vocab_size_multiple
-        self.pad_token_id = pad_token_id
-        self.bos_token_id = bos_token_id
-        self.eos_token_id = eos_token_id
-        self.no_bias = no_bias
-        self.rms_norm_eps = rms_norm_eps

         super().__init__(
             pad_token_id=pad_token_id,
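
For orientation, the config now keeps only the fields the Mamba blocks read; the new vocab_size default of 50280 appears to correspond to the padded GPT-NeoX tokenizer vocabulary used by the state-spaces Mamba checkpoints, while the hidden size (768) stays hardcoded in the blocks further down. A hypothetical instantiation, not part of the diff (the n_layer value for a ~130M checkpoint is an assumption):

from text_generation_server.models.custom_modeling.mamba_modeling import MambaConfig

# Hypothetical usage sketch: values other than vocab_size are assumptions
config = MambaConfig(
    vocab_size=50280,         # padded GPT-NeoX vocab
    n_layer=24,               # assumed depth for a ~130M model; the default above is 32
    layer_norm_epsilon=1e-5,
)
# d_model is not configurable yet: 768 is hardcoded in MambaBlock and MambaModel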
@@ -61,18 +37,29 @@ class MambaConfig(PretrainedConfig):
             **kwargs,
         )


 class MambaBlock(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
-        # TODO: adjust how weights are loaded
-        # conv1d 768*2, 768*2, 4
-        self.conv1 = nn.Conv1d(768, 768, 4)
-        # add weight and bias to conv1
-        self.conv1.weight = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.weight").transpose(0, 1))
+        # TODO: use model config to set the dt_rank instead of hardcoding it
+        d_inner = 768 * 2
+        d_conv = 4
+        self.dt_rank = (768 + 15) // 16
+
+        # TODO: improve how we load the conv1d weights
+        # explore a transposed conv1d that avoids the need for
+        # a transpose during inference
+        self.conv1 = nn.Conv1d(
+            d_inner,
+            d_inner,
+            kernel_size=d_conv,
+            groups=d_inner,
+            padding=d_conv - 1,
+        )
+        self.conv1.weight = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.weight"))
         self.conv1.bias = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.bias"))
-        # TODO: load weights in correctly for other operations
         self.dt_proj = TensorParallelColumnLinear.load(
             config=config,
             prefix=f"{prefix}.dt_proj",
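
Worth noting: the conv1d above is depthwise (groups=d_inner) and padded with d_conv - 1 zeros, and the forward pass in the next hunk keeps only the first sequence_length outputs, which makes the convolution causal. A minimal sketch of that pattern with toy sizes (illustrative only, not part of the diff):

import torch
import torch.nn as nn

d_inner, d_conv, seq_len = 6, 4, 5
# depthwise conv, padded so that once the trailing outputs are dropped,
# position t only ever sees inputs at positions <= t
conv = nn.Conv1d(d_inner, d_inner, kernel_size=d_conv, groups=d_inner, padding=d_conv - 1)

x = torch.randn(1, seq_len, d_inner)          # (batch, seq, channels)
y = conv(x.transpose(1, 2))                   # Conv1d expects (batch, channels, seq)
y = y.narrow(-1, 0, seq_len).transpose(1, 2)  # keep the first seq_len outputs -> causal
print(y.shape)                                # torch.Size([1, 5, 6])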
@@ -91,45 +78,136 @@ class MambaBlock(nn.Module):
             weights=weights,
             bias=False,
         )
-        self.A_log = nn.Parameter(torch.randn(config.n_head, config.n_head, config.rotary_dim))
-        self.D = nn.Parameter(torch.randn(config.n_head, config.rotary_dim))
-
-    def forward(
-        self,
-        hidden_states,
-        past_kv_cache,
-        attention_mask=None,
-    ):
-        hidden_states_in_proj = self.in_proj(hidden_states)
-        hidden_states_and_residual = torch.chunk(hidden_states_in_proj, 2, dim=-1)
-        hs, res = hidden_states_and_residual[0], hidden_states_and_residual[1]
+        self.out_proj = TensorParallelColumnLinear.load(
+            config=config,
+            prefix=f"{prefix}.out_proj",
+            weights=weights,
+            bias=False,
+        )
+        # TODO: improve how we load the weights
+        self.A_log = nn.Parameter(weights.get_tensor(f"{prefix}.A_log"))
+        self.D = nn.Parameter(weights.get_tensor(f"{prefix}.D"))
+
+    def selective_scan(
+        self, input_tensor, delta, a_tensor, b_tensor, c_tensor, d_tensor
+    ):
+        batch_size, sequence_length, input_dim = input_tensor.shape
+        num_cols = a_tensor.shape[1]
+        # TODO: revisit this math to avoid the transposes when possible
+        # reshape and process delta
+        delta = delta.transpose(1, 2).view((batch_size, input_dim, sequence_length, 1))
+        exp_delta_a = (delta * a_tensor.view((1, input_dim, 1, num_cols))).exp()
+        # calc involving delta, b_tensor, and input_tensor
+        delta_b_input = (
+            delta
+            * b_tensor.view((batch_size, 1, sequence_length, num_cols))
+            * input_tensor.transpose(1, 2).view(
+                (batch_size, input_dim, sequence_length, 1)
+            )
+        )
+        # init output tensor
+        output_tensor = torch.zeros(
+            (batch_size, input_dim, num_cols),
+            dtype=exp_delta_a.dtype,
+            device=exp_delta_a.device,
+        )
+        # iterate over sequence_length
+        output_sequence = []
+        for i in range(sequence_length):
+            multiplier = exp_delta_a[:, :, i]
+            output_tensor = (multiplier * output_tensor) + delta_b_input[:, :, i]
+            y = output_tensor.matmul(c_tensor[:, i, :].unsqueeze(2)).squeeze(2)
+            output_sequence.append(y)
+        stacked_output = torch.stack(output_sequence, 1)
+        return stacked_output + input_tensor * d_tensor
+
+    def ssm(self, hidden_states):
+        _input_dim, num_cols = self.A_log.shape
+        negative_exponential_a = self.A_log.exp().neg()
+        d_matrix = self.D
+        projected_hidden_states = self.x_proj(hidden_states)
+        # narrow operations for delta, b, and c
+        delta = projected_hidden_states.narrow(-1, 0, self.dt_rank)
+        b_matrix = projected_hidden_states.narrow(-1, self.dt_rank, num_cols)
+        c_matrix = projected_hidden_states.narrow(-1, self.dt_rank + num_cols, num_cols)
+        # process delta
+        delta = self.dt_proj(delta)
+        delta = torch.log(torch.exp(delta) + 1)
+        # apply selective scan
+        selective_scan_output = self.selective_scan(
+            hidden_states, delta, negative_exponential_a, b_matrix, c_matrix, d_matrix
+        )
+        return selective_scan_output
+
+    def forward(self, hidden_states):
+        sequence_length = hidden_states.shape[1]
+        projected_states = self.in_proj(hidden_states)
+        split_states = torch.chunk(projected_states, 2, dim=-1)
+        transformed_states, residual_states = split_states
+
+        # TODO: avoid the transpose by using a transposed conv1d
+        # apply convolution and narrowing operation
+        conv_output = (
+            self.conv1(transformed_states.transpose(1, 2))
+            .narrow(-1, 0, sequence_length)
+            .transpose(1, 2)
+        )
+
+        # apply silu (Swish) activation function
+        activated_transformed = F.silu(conv_output)
+        activated_residual = F.silu(residual_states)
+        # Subsequent operations
+        output = self.ssm(activated_transformed)
+        combined_output = output * activated_residual
+        return self.out_proj(combined_output)
+
+
+# TODO: prefer a more optimized implementation of RMSNorm if possible
+class RMSNorm(nn.Module):
+    def __init__(self, num_features, eps=1e-8):
+        super().__init__()
+        self.num_features = num_features
+        self.eps = eps
+        self.scale = nn.Parameter(torch.ones(num_features))
+
+    def forward(self, x):
+        rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
+        x = x / rms
+        return self.scale * x
+
+
+import ipdb; ipdb.set_trace()
+
+
 class ResidualBlock(nn.Module):
     def __init__(self, layer_id, config, weights):
         super().__init__()
         self.layer_id = layer_id
-        self.mixer = MambaBlock(prefix=f"{layer_id}.mixer", config=config, weights=weights)
-        self.layer_norm = FastLinear.load(
-            config=config,
-            prefix=f"{layer_id}.norm",
-            weights=weights,
-            bias=False,
-        )
+        self.mamba_block = MambaBlock(
+            prefix=f"{layer_id}.mixer", config=config, weights=weights
+        )
+        self.layer_norm = RMSNorm(768, eps=config.layer_norm_epsilon)
+        self.layer_norm.scale = nn.Parameter(
+            weights.get_tensor(f"{layer_id}.norm.weight")
+        )

     def forward(
         self,
         hidden_states,
-        kv_cache,
-        attention_mask,
     ):
         residual = hidden_states
         hidden_states = self.layer_norm(hidden_states)
-        attn_outputs, past_kv_cache = self.mixer(hidden_states, kv_cache, attention_mask)
+        attn_outputs = self.mamba_block(hidden_states)
         hidden_states = residual + attn_outputs
-        return hidden_states, residual
+        return hidden_states


 class MambaModel(nn.Module):
     def __init__(self, config, weights):
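
The selective_scan above is a direct loop over time steps of the discretized state-space recurrence, and torch.log(torch.exp(delta) + 1) in ssm is simply softplus. For reference, the same recurrence written per time step for a single example (an illustrative sketch; names and shapes here are assumptions, not the diff's API):

import torch

# h_t = exp(delta_t * A) * h_{t-1} + (delta_t * B_t) * x_t
# y_t = C_t . h_t + D * x_t
d_in, n, seq_len = 3, 4, 5
A = -torch.rand(d_in, n)           # negative, like -exp(A_log) above
D = torch.rand(d_in)
x = torch.randn(seq_len, d_in)
delta = torch.rand(seq_len, d_in)  # stands in for softplus(dt_proj(...))
B = torch.randn(seq_len, n)
C = torch.randn(seq_len, n)

h = torch.zeros(d_in, n)
ys = []
for t in range(seq_len):
    h = torch.exp(delta[t, :, None] * A) * h + (delta[t, :, None] * B[t]) * x[t, :, None]
    ys.append(h @ C[t] + D * x[t])  # (d_in,)
y = torch.stack(ys)                 # (seq_len, d_in)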
@@ -140,36 +218,43 @@ class MambaModel(nn.Module):
             prefix="backbone.embedding", weights=weights
         )
         self.blocks = nn.ModuleList(
-            [ResidualBlock(f"backbone.layers.{layer_id}", config, weights) for layer_id in range(config.n_layer)]
+            [
+                ResidualBlock(f"backbone.layers.{layer_id}", config, weights)
+                for layer_id in range(config.n_layer)
+            ]
         )
-        self.norm_f = FastRMSNorm.load(
-            prefix="backbone.norm_f",
-            weights=weights,
-            eps=config.rms_norm_eps
+
+        # TODO: avoid hardcoded sizes and improve how we load the weights
+        self.norm_f = RMSNorm(768, eps=config.layer_norm_epsilon)
+        self.norm_f.scale = nn.Parameter(weights.get_tensor(f"backbone.norm_f.weight"))
+
+        # use the same weights for the embedding and the final layer norm
+        self.lm_head = nn.Linear(768, config.vocab_size, bias=False)
+        self.lm_head.weight = nn.Parameter(
+            self.embed_tokens.weight[: config.vocab_size, :]
         )
+        print("🌈 model init done")

     def forward(
         self,
         input_ids: torch.LongTensor,
-        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
-        attention_mask: Optional[torch.ByteTensor] = None,
-        return_dict: Optional[bool] = None,
-        use_cache: Optional[bool] = None,
+        past_input_ids: Optional[List[Tuple[torch.FloatTensor]]] = None,
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
+        # TODO: dont use past_input_ids for the input_ids
+        # find a way to cache previous states/work
+        if past_input_ids is not None:
+            # append the contents to the input_ids
+            input_ids = torch.cat((past_input_ids, input_ids), dim=1)
+
         hidden_states = self.embed_tokens(input_ids)
-        seq_len = hidden_states.shape[1]
-        mask = None if seq_len <= 1 else attention_mask
-        past_key_values = [None] * len(self.blocks) if past_key_values is None else past_key_values
-        for index, block in enumerate(self.blocks):
-            hidden_states, new_key_values = block(hidden_states, past_key_values[index], mask)
-            past_key_values[index] = new_key_values
-        hidden_states = self.norm_f(hidden_states)
-        return hidden_states, past_key_values
+        for _, block in enumerate(self.blocks):
+            hidden_states = block(hidden_states)
+
+        final_hidden_states = self.norm_f(hidden_states)
+        after_lm_head = self.lm_head(final_hidden_states)
+        return after_lm_head, input_ids


+# TODO: revisit if we want to use CausalLM
 class MambaForCausalLM(torch.nn.Module):
     def __init__(self, config, weights):
         super().__init__()
@@ -178,13 +263,24 @@ class MambaForCausalLM(torch.nn.Module):
     def forward(
         self,
         input_ids: torch.LongTensor,
+        # TODO: dont abuse past_key_values for the input_ids
         past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
+        # below are unused since this model is attention free
         attention_mask: Optional[torch.ByteTensor] = None,
         return_dict: Optional[bool] = None,
         use_cache: Optional[bool] = None,
         labels: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         model_output = self.model(
-            input_ids, past_key_values, attention_mask, return_dict, use_cache
+            input_ids,
+            past_input_ids=past_key_values,
+        )
+        logits = model_output[0]
+        past_hidden_states = model_output[1]
+        return CausalLMOutputWithPast(
+            loss=None,
+            logits=logits,
+            past_key_values=past_hidden_states,
+            hidden_states=None,
+            attentions=None,
         )
-        print("🌈 model output done")

View File

@@ -1,17 +1,37 @@
 import torch
 import torch.distributed

-from transformers import AutoConfig, AutoTokenizer
-from typing import Optional, List, Tuple
+from transformers import AutoTokenizer, PreTrainedTokenizerBase
+from typing import Optional

 from text_generation_server.models import CausalLM
-from text_generation_server.models.custom_modeling.mamba_modeling import MambaConfig, MambaForCausalLM
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.models.custom_modeling.mamba_modeling import (
+    MambaConfig,
+    MambaForCausalLM,
+)
+from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
     Weights,
 )


+class MambaCausalLMBatch(CausalLMBatch):
+    @classmethod
+    def from_pb(
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> "CausalLMBatch":
+        batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
+        batch.keys_head_dim_last = False
+        return batch
+
+
 class Mamba(CausalLM):
     def __init__(
         self,
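
The hunk is truncated here. The custom batch above only matters if the Mamba model actually selects it; presumably the class goes on to point its batch type at MambaCausalLMBatch, for example by overriding the batch_type property the way other CausalLM-based models in this server do. A sketch of that assumed continuation (not shown in the diff):

from typing import Type

class Mamba(CausalLM):
    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return MambaCausalLMBatch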