feat: prefer custom model and produce correct output

This commit is contained in:
drbh 2024-01-25 17:07:37 -05:00
parent 35939a28c7
commit 1c32d53fc3
2 changed files with 266 additions and 61 deletions

View File

@ -17,6 +17,7 @@ class MambaConfig(PretrainedConfig):
def __init__(
self,
vocab_size=50280,
d_model=768,
n_layer=32,
layer_norm_epsilon=1e-5,
tie_word_embeddings=False,
@ -28,6 +29,9 @@ class MambaConfig(PretrainedConfig):
self.vocab_size = vocab_size
self.n_layer = n_layer
self.layer_norm_epsilon = layer_norm_epsilon
self.d_model = d_model
self.d_inner = d_model * 2
self.d_conv = 4
super().__init__(
pad_token_id=pad_token_id,
@ -43,19 +47,17 @@ class MambaBlock(nn.Module):
super().__init__()
# TODO: use model config to set the dt_rank instead of hardcoding it
d_inner = 768 * 2
d_conv = 4
self.dt_rank = (768 + 15) // 16
self.dt_rank = (config.d_model + 15) // 16
# TODO: improve how we load the conv1d weights
# explore a transposed conv1d that avoids the need for
# a transpose during inference
self.conv1 = nn.Conv1d(
d_inner,
d_inner,
kernel_size=d_conv,
groups=d_inner,
padding=d_conv - 1,
config.d_inner,
config.d_inner,
kernel_size=config.d_conv,
groups=config.d_inner,
padding=config.d_conv - 1,
)
self.conv1.weight = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.weight"))
self.conv1.bias = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.bias"))
@ -148,11 +150,32 @@ class MambaBlock(nn.Module):
)
return selective_scan_output
def forward(self, hidden_states):
def forward(self, index, hidden_states, past_transformed_state):
sequence_length = hidden_states.shape[1]
projected_states = self.in_proj(hidden_states)
split_states = torch.chunk(projected_states, 2, dim=-1)
transformed_states, residual_states = split_states
# minimal amount of new work on single hidden state (previous hidden state are cached)
only_last = hidden_states[:, -1, :]
projected_only_last = self.in_proj(only_last)
transformed_only_last, residual_only_last = torch.chunk(
projected_only_last, 2, dim=-1
)
if past_transformed_state is not None:
# build a new transformed_states tensor with past_transformed_state and transformed_only_last
new_transformed_states = torch.cat(
[past_transformed_state, transformed_only_last.unsqueeze(1)], dim=1
)
transformed_states = new_transformed_states
residual_states = residual_only_last
else:
# prefilling the cache with the last transformed state
projected_states = self.in_proj(hidden_states)
split_states = torch.chunk(projected_states, 2, dim=-1)
transformed_states, residual_states = split_states
# NOTE: we need the past hidden states to produce the correct output
# therefore we cannot simply compute the most recent and append it as we
# did for the transformed states
# TODO: avoid the transpose by using a transposed conv1d
# apply convolution and narrowing operation
@ -170,7 +193,8 @@ class MambaBlock(nn.Module):
output = self.ssm(activated_transformed)
combined_output = output * activated_residual
return self.out_proj(combined_output)
return self.out_proj(combined_output), transformed_states
# TODO: prefer a more optimized implementation of RMSNorm if possible
class RMSNorm(nn.Module):
@ -193,20 +217,24 @@ class ResidualBlock(nn.Module):
self.mamba_block = MambaBlock(
prefix=f"{layer_id}.mixer", config=config, weights=weights
)
self.layer_norm = RMSNorm(768, eps=config.layer_norm_epsilon)
self.layer_norm = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.layer_norm.scale = nn.Parameter(
weights.get_tensor(f"{layer_id}.norm.weight")
)
def forward(
self,
index,
hidden_states,
past_transformed_state,
):
residual = hidden_states
hidden_states = self.layer_norm(hidden_states)
attn_outputs = self.mamba_block(hidden_states)
attn_outputs, transformed_states = self.mamba_block(
index, hidden_states, past_transformed_state
)
hidden_states = residual + attn_outputs
return hidden_states
return hidden_states, transformed_states
class MambaModel(nn.Module):
@ -225,10 +253,10 @@ class MambaModel(nn.Module):
)
# TODO: avoid hardcoded sizes and improve how we load the weights
self.norm_f = RMSNorm(768, eps=config.layer_norm_epsilon)
self.norm_f = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.norm_f.scale = nn.Parameter(weights.get_tensor(f"backbone.norm_f.weight"))
# use the same weights for the embedding and the final layer norm
self.lm_head = nn.Linear(768, config.vocab_size, bias=False)
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
self.lm_head.weight = nn.Parameter(
self.embed_tokens.weight[: config.vocab_size, :]
)
@ -237,50 +265,26 @@ class MambaModel(nn.Module):
self,
input_ids: torch.LongTensor,
past_input_ids: Optional[List[Tuple[torch.FloatTensor]]] = None,
past_transformed_states: Optional[List[Tuple[torch.FloatTensor]]] = None,
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
# TODO: dont use past_input_ids for the input_ids
# find a way to cache previous states/work
# NOTE: we need all input_ids to compute the correct embeddings
if past_input_ids is not None:
# append the contents to the input_ids
input_ids = torch.cat((past_input_ids, input_ids), dim=1)
input_ids = past_input_ids
hidden_states = self.embed_tokens(input_ids)
for _, block in enumerate(self.blocks):
hidden_states = block(hidden_states)
past_transformed_states = (
[None] * len(self.blocks)
if past_transformed_states is None
else past_transformed_states
)
for index, block in enumerate(self.blocks):
hidden_states, transformed_states = block(
index, hidden_states, past_transformed_states[index]
)
past_transformed_states[index] = transformed_states
final_hidden_states = self.norm_f(hidden_states)
after_lm_head = self.lm_head(final_hidden_states)
return after_lm_head, input_ids
# TODO: revisit if we want to use CausalLM
class MambaForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
self.model = MambaModel(config, weights)
def forward(
self,
input_ids: torch.LongTensor,
# TODO: dont abuse past_key_values for the input_ids
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
# below are unused since this model is attention free
attention_mask: Optional[torch.ByteTensor] = None,
return_dict: Optional[bool] = None,
use_cache: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
model_output = self.model(
input_ids,
past_input_ids=past_key_values,
)
logits = model_output[0]
past_hidden_states = model_output[1]
return CausalLMOutputWithPast(
loss=None,
logits=logits,
past_key_values=past_hidden_states,
hidden_states=None,
attentions=None,
)
return after_lm_head, input_ids, past_transformed_states

View File

@ -8,7 +8,6 @@ from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.models.custom_modeling.mamba_modeling import (
MambaConfig,
MambaForCausalLM,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
@ -17,8 +16,27 @@ from text_generation_server.utils import (
Weights,
)
import time
from text_generation_server.models.custom_modeling.mamba_modeling import MambaModel
from text_generation_server.models import Model
from typing import Any, List, Optional, Tuple, Type
from text_generation_server.models.types import (
Batch,
Tokens,
Generation,
GeneratedText,
)
from text_generation_server.utils.tokens import batch_top_tokens, Sampling
class MambaCausalLMBatch(CausalLMBatch):
past_transformed_states: Optional[List[torch.Tensor]]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.past_input_ids = None
self.past_transformed_states = None
@classmethod
def from_pb(
cls,
@ -32,7 +50,7 @@ class MambaCausalLMBatch(CausalLMBatch):
return batch
class Mamba(CausalLM):
class Mamba(Model):
def __init__(
self,
model_id: str,
@ -71,12 +89,195 @@ class Mamba(CausalLM):
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
model = MambaForCausalLM(config, weights)
model = MambaModel(config, weights)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
super(Mamba, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
)
@property
def batch_type(self) -> Type[CausalLMBatch]:
return MambaCausalLMBatch
def forward(
self,
input_ids: torch.Tensor,
past: Optional[List[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
return self.model(
input_ids,
past=past,
)
def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
start = time.time_ns()
input_ids = batch.input_ids
past_input_ids = batch.past_input_ids
past_transformed_states = batch.past_transformed_states
model_output = self.model(
input_ids,
past_input_ids,
past_transformed_states,
)
logits = model_output[0]
past_input_ids = model_output[1]
past_transformed_states = model_output[2]
# Results
generations: List[Generation] = []
stopped = True
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens,
batch.top_n_tokens_tensor,
torch.log_softmax(logits[:, -1], -1),
)
start_decode = time.time_ns()
# Zipped iterator
iterator = zip(
batch.requests,
batch.input_lengths,
batch.prefix_offsets,
batch.read_offsets,
logits,
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
batch.top_n_tokens,
batch_top_token_ids,
batch_top_token_logprobs,
)
# For each member of the batch
for i, (
request,
input_length,
prefix_offset,
read_offset,
logits,
next_token_chooser,
stopping_criteria,
all_input_ids,
top_n_tokens,
top_token_ids,
top_token_logprobs,
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(
all_input_ids.view(1, -1), logits[-1:, :]
)
# add next token to past_input_ids
past_input_ids = torch.cat([past_input_ids, next_token_id], dim=1)
# Append next token to all tokens
all_input_ids = torch.cat([all_input_ids, next_token_id])
new_input_length = input_length + 1
# Generated token
next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze()
next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids[:, 0], prefix_offset, read_offset
)
# Evaluate stopping criteria
stop, reason = stopping_criteria(
next_token_id_squeezed,
next_token_text,
)
if not stop:
stopped = False
if stop:
# Decode generated tokens
output_text, _, _ = self.decode_token(
all_input_ids[:, 0],
prefix_offset=len(all_input_ids)
- stopping_criteria.current_tokens
- 1,
read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
skip_special_tokens=True,
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, reason, seed
)
else:
generated_text = None
# Prefill
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
# Remove generated token to only have prefill and add nan for first prompt token
prefill_logprobs = [float("nan")] + torch.log_softmax(
logits, -1
).gather(1, all_input_ids[1:]).squeeze(1)[-new_input_length:-1].tolist()
prefill_token_ids = all_input_ids[-new_input_length:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = Tokens(
prefill_token_ids,
prefill_logprobs,
prefill_texts,
is_special=[],
)
else:
prefill_tokens = None
generation = Generation(
batch.batch_id,
None,
Tokens(
[next_token_id_squeezed],
[next_token_logprob],
[next_token_text],
[next_token_id_squeezed.item() in self.all_special_ids],
),
generated_text,
None,
)
generations.append(generation)
# Update values
batch.input_ids = torch.cat(
[batch.input_ids, torch.tensor([[next_token_id_squeezed]])], dim=1
)
batch.all_input_ids[i] = all_input_ids
batch.input_lengths[i] = new_input_length
batch.prefix_offsets[i] = prefix_offset
batch.read_offsets[i] = read_offset
batch.max_input_length = max(batch.max_input_length, new_input_length)
# We finished all generations in the batch; there is no next batch
if stopped:
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, None, (forward_ns, decode_ns)
# Slice unused values from prefill
batch.input_ids = batch.input_ids[:, :1]
batch.past_input_ids = past_input_ids
batch.past_transformed_states = past_transformed_states
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, batch, (forward_ns, decode_ns)