mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-25 20:12:07 +00:00
feat: prefer custom model and produce correct output
This commit is contained in:
parent
35939a28c7
commit
1c32d53fc3
@ -17,6 +17,7 @@ class MambaConfig(PretrainedConfig):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=50280,
|
||||
d_model=768,
|
||||
n_layer=32,
|
||||
layer_norm_epsilon=1e-5,
|
||||
tie_word_embeddings=False,
|
||||
@ -28,6 +29,9 @@ class MambaConfig(PretrainedConfig):
|
||||
self.vocab_size = vocab_size
|
||||
self.n_layer = n_layer
|
||||
self.layer_norm_epsilon = layer_norm_epsilon
|
||||
self.d_model = d_model
|
||||
self.d_inner = d_model * 2
|
||||
self.d_conv = 4
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
@ -43,19 +47,17 @@ class MambaBlock(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
# TODO: use model config to set the dt_rank instead of hardcoding it
|
||||
d_inner = 768 * 2
|
||||
d_conv = 4
|
||||
self.dt_rank = (768 + 15) // 16
|
||||
self.dt_rank = (config.d_model + 15) // 16
|
||||
|
||||
# TODO: improve how we load the conv1d weights
|
||||
# explore a transposed conv1d that avoids the need for
|
||||
# a transpose during inference
|
||||
self.conv1 = nn.Conv1d(
|
||||
d_inner,
|
||||
d_inner,
|
||||
kernel_size=d_conv,
|
||||
groups=d_inner,
|
||||
padding=d_conv - 1,
|
||||
config.d_inner,
|
||||
config.d_inner,
|
||||
kernel_size=config.d_conv,
|
||||
groups=config.d_inner,
|
||||
padding=config.d_conv - 1,
|
||||
)
|
||||
self.conv1.weight = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.weight"))
|
||||
self.conv1.bias = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.bias"))
|
||||
@ -148,11 +150,32 @@ class MambaBlock(nn.Module):
|
||||
)
|
||||
return selective_scan_output
|
||||
|
||||
def forward(self, hidden_states):
|
||||
def forward(self, index, hidden_states, past_transformed_state):
|
||||
sequence_length = hidden_states.shape[1]
|
||||
projected_states = self.in_proj(hidden_states)
|
||||
split_states = torch.chunk(projected_states, 2, dim=-1)
|
||||
transformed_states, residual_states = split_states
|
||||
|
||||
# minimal amount of new work on single hidden state (previous hidden state are cached)
|
||||
only_last = hidden_states[:, -1, :]
|
||||
projected_only_last = self.in_proj(only_last)
|
||||
transformed_only_last, residual_only_last = torch.chunk(
|
||||
projected_only_last, 2, dim=-1
|
||||
)
|
||||
|
||||
if past_transformed_state is not None:
|
||||
# build a new transformed_states tensor with past_transformed_state and transformed_only_last
|
||||
new_transformed_states = torch.cat(
|
||||
[past_transformed_state, transformed_only_last.unsqueeze(1)], dim=1
|
||||
)
|
||||
transformed_states = new_transformed_states
|
||||
residual_states = residual_only_last
|
||||
else:
|
||||
# prefilling the cache with the last transformed state
|
||||
projected_states = self.in_proj(hidden_states)
|
||||
split_states = torch.chunk(projected_states, 2, dim=-1)
|
||||
transformed_states, residual_states = split_states
|
||||
|
||||
# NOTE: we need the past hidden states to produce the correct output
|
||||
# therefore we cannot simply compute the most recent and append it as we
|
||||
# did for the transformed states
|
||||
|
||||
# TODO: avoid the transpose by using a transposed conv1d
|
||||
# apply convolution and narrowing operation
|
||||
@ -170,7 +193,8 @@ class MambaBlock(nn.Module):
|
||||
output = self.ssm(activated_transformed)
|
||||
combined_output = output * activated_residual
|
||||
|
||||
return self.out_proj(combined_output)
|
||||
return self.out_proj(combined_output), transformed_states
|
||||
|
||||
|
||||
# TODO: prefer a more optimized implementation of RMSNorm if possible
|
||||
class RMSNorm(nn.Module):
|
||||
@ -193,20 +217,24 @@ class ResidualBlock(nn.Module):
|
||||
self.mamba_block = MambaBlock(
|
||||
prefix=f"{layer_id}.mixer", config=config, weights=weights
|
||||
)
|
||||
self.layer_norm = RMSNorm(768, eps=config.layer_norm_epsilon)
|
||||
self.layer_norm = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.layer_norm.scale = nn.Parameter(
|
||||
weights.get_tensor(f"{layer_id}.norm.weight")
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
index,
|
||||
hidden_states,
|
||||
past_transformed_state,
|
||||
):
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
attn_outputs = self.mamba_block(hidden_states)
|
||||
attn_outputs, transformed_states = self.mamba_block(
|
||||
index, hidden_states, past_transformed_state
|
||||
)
|
||||
hidden_states = residual + attn_outputs
|
||||
return hidden_states
|
||||
return hidden_states, transformed_states
|
||||
|
||||
|
||||
class MambaModel(nn.Module):
|
||||
@ -225,10 +253,10 @@ class MambaModel(nn.Module):
|
||||
)
|
||||
|
||||
# TODO: avoid hardcoded sizes and improve how we load the weights
|
||||
self.norm_f = RMSNorm(768, eps=config.layer_norm_epsilon)
|
||||
self.norm_f = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.norm_f.scale = nn.Parameter(weights.get_tensor(f"backbone.norm_f.weight"))
|
||||
# use the same weights for the embedding and the final layer norm
|
||||
self.lm_head = nn.Linear(768, config.vocab_size, bias=False)
|
||||
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
||||
self.lm_head.weight = nn.Parameter(
|
||||
self.embed_tokens.weight[: config.vocab_size, :]
|
||||
)
|
||||
@ -237,50 +265,26 @@ class MambaModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
past_input_ids: Optional[List[Tuple[torch.FloatTensor]]] = None,
|
||||
past_transformed_states: Optional[List[Tuple[torch.FloatTensor]]] = None,
|
||||
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||
# TODO: dont use past_input_ids for the input_ids
|
||||
# find a way to cache previous states/work
|
||||
# NOTE: we need all input_ids to compute the correct embeddings
|
||||
if past_input_ids is not None:
|
||||
# append the contents to the input_ids
|
||||
input_ids = torch.cat((past_input_ids, input_ids), dim=1)
|
||||
input_ids = past_input_ids
|
||||
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
|
||||
for _, block in enumerate(self.blocks):
|
||||
hidden_states = block(hidden_states)
|
||||
past_transformed_states = (
|
||||
[None] * len(self.blocks)
|
||||
if past_transformed_states is None
|
||||
else past_transformed_states
|
||||
)
|
||||
|
||||
for index, block in enumerate(self.blocks):
|
||||
hidden_states, transformed_states = block(
|
||||
index, hidden_states, past_transformed_states[index]
|
||||
)
|
||||
past_transformed_states[index] = transformed_states
|
||||
|
||||
final_hidden_states = self.norm_f(hidden_states)
|
||||
after_lm_head = self.lm_head(final_hidden_states)
|
||||
return after_lm_head, input_ids
|
||||
|
||||
|
||||
# TODO: revisit if we want to use CausalLM
|
||||
class MambaForCausalLM(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
self.model = MambaModel(config, weights)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
# TODO: dont abuse past_key_values for the input_ids
|
||||
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
|
||||
# below are unused since this model is attention free
|
||||
attention_mask: Optional[torch.ByteTensor] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||
model_output = self.model(
|
||||
input_ids,
|
||||
past_input_ids=past_key_values,
|
||||
)
|
||||
logits = model_output[0]
|
||||
past_hidden_states = model_output[1]
|
||||
return CausalLMOutputWithPast(
|
||||
loss=None,
|
||||
logits=logits,
|
||||
past_key_values=past_hidden_states,
|
||||
hidden_states=None,
|
||||
attentions=None,
|
||||
)
|
||||
return after_lm_head, input_ids, past_transformed_states
|
||||
|
@ -8,7 +8,6 @@ from text_generation_server.models import CausalLM
|
||||
from text_generation_server.models.causal_lm import CausalLMBatch
|
||||
from text_generation_server.models.custom_modeling.mamba_modeling import (
|
||||
MambaConfig,
|
||||
MambaForCausalLM,
|
||||
)
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.utils import (
|
||||
@ -17,8 +16,27 @@ from text_generation_server.utils import (
|
||||
Weights,
|
||||
)
|
||||
|
||||
import time
|
||||
from text_generation_server.models.custom_modeling.mamba_modeling import MambaModel
|
||||
from text_generation_server.models import Model
|
||||
from typing import Any, List, Optional, Tuple, Type
|
||||
from text_generation_server.models.types import (
|
||||
Batch,
|
||||
Tokens,
|
||||
Generation,
|
||||
GeneratedText,
|
||||
)
|
||||
from text_generation_server.utils.tokens import batch_top_tokens, Sampling
|
||||
|
||||
|
||||
class MambaCausalLMBatch(CausalLMBatch):
|
||||
past_transformed_states: Optional[List[torch.Tensor]]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.past_input_ids = None
|
||||
self.past_transformed_states = None
|
||||
|
||||
@classmethod
|
||||
def from_pb(
|
||||
cls,
|
||||
@ -32,7 +50,7 @@ class MambaCausalLMBatch(CausalLMBatch):
|
||||
return batch
|
||||
|
||||
|
||||
class Mamba(CausalLM):
|
||||
class Mamba(Model):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
@ -71,12 +89,195 @@ class Mamba(CausalLM):
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||
model = MambaForCausalLM(config, weights)
|
||||
model = MambaModel(config, weights)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(CausalLM, self).__init__(
|
||||
super(Mamba, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
@property
|
||||
def batch_type(self) -> Type[CausalLMBatch]:
|
||||
return MambaCausalLMBatch
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
past: Optional[List[torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return self.model(
|
||||
input_ids,
|
||||
past=past,
|
||||
)
|
||||
|
||||
def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
|
||||
start = time.time_ns()
|
||||
|
||||
input_ids = batch.input_ids
|
||||
past_input_ids = batch.past_input_ids
|
||||
past_transformed_states = batch.past_transformed_states
|
||||
|
||||
model_output = self.model(
|
||||
input_ids,
|
||||
past_input_ids,
|
||||
past_transformed_states,
|
||||
)
|
||||
|
||||
logits = model_output[0]
|
||||
past_input_ids = model_output[1]
|
||||
past_transformed_states = model_output[2]
|
||||
|
||||
# Results
|
||||
generations: List[Generation] = []
|
||||
stopped = True
|
||||
|
||||
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
|
||||
batch.top_n_tokens,
|
||||
batch.top_n_tokens_tensor,
|
||||
torch.log_softmax(logits[:, -1], -1),
|
||||
)
|
||||
|
||||
start_decode = time.time_ns()
|
||||
|
||||
# Zipped iterator
|
||||
iterator = zip(
|
||||
batch.requests,
|
||||
batch.input_lengths,
|
||||
batch.prefix_offsets,
|
||||
batch.read_offsets,
|
||||
logits,
|
||||
batch.next_token_choosers,
|
||||
batch.stopping_criterias,
|
||||
batch.all_input_ids,
|
||||
batch.top_n_tokens,
|
||||
batch_top_token_ids,
|
||||
batch_top_token_logprobs,
|
||||
)
|
||||
|
||||
# For each member of the batch
|
||||
for i, (
|
||||
request,
|
||||
input_length,
|
||||
prefix_offset,
|
||||
read_offset,
|
||||
logits,
|
||||
next_token_chooser,
|
||||
stopping_criteria,
|
||||
all_input_ids,
|
||||
top_n_tokens,
|
||||
top_token_ids,
|
||||
top_token_logprobs,
|
||||
) in enumerate(iterator):
|
||||
# Select next token
|
||||
next_token_id, logprobs = next_token_chooser(
|
||||
all_input_ids.view(1, -1), logits[-1:, :]
|
||||
)
|
||||
|
||||
# add next token to past_input_ids
|
||||
past_input_ids = torch.cat([past_input_ids, next_token_id], dim=1)
|
||||
|
||||
# Append next token to all tokens
|
||||
all_input_ids = torch.cat([all_input_ids, next_token_id])
|
||||
new_input_length = input_length + 1
|
||||
|
||||
# Generated token
|
||||
next_token_logprob = logprobs[-1, next_token_id]
|
||||
next_token_id_squeezed = next_token_id.squeeze()
|
||||
next_token_text, prefix_offset, read_offset = self.decode_token(
|
||||
all_input_ids[:, 0], prefix_offset, read_offset
|
||||
)
|
||||
|
||||
# Evaluate stopping criteria
|
||||
stop, reason = stopping_criteria(
|
||||
next_token_id_squeezed,
|
||||
next_token_text,
|
||||
)
|
||||
|
||||
if not stop:
|
||||
stopped = False
|
||||
|
||||
if stop:
|
||||
# Decode generated tokens
|
||||
output_text, _, _ = self.decode_token(
|
||||
all_input_ids[:, 0],
|
||||
prefix_offset=len(all_input_ids)
|
||||
- stopping_criteria.current_tokens
|
||||
- 1,
|
||||
read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
# Get seed
|
||||
if isinstance(next_token_chooser.choice, Sampling):
|
||||
seed = next_token_chooser.choice.seed
|
||||
else:
|
||||
seed = None
|
||||
|
||||
generated_text = GeneratedText(
|
||||
output_text, stopping_criteria.current_tokens, reason, seed
|
||||
)
|
||||
else:
|
||||
generated_text = None
|
||||
|
||||
# Prefill
|
||||
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
|
||||
# Remove generated token to only have prefill and add nan for first prompt token
|
||||
prefill_logprobs = [float("nan")] + torch.log_softmax(
|
||||
logits, -1
|
||||
).gather(1, all_input_ids[1:]).squeeze(1)[-new_input_length:-1].tolist()
|
||||
prefill_token_ids = all_input_ids[-new_input_length:-1]
|
||||
prefill_texts = self.tokenizer.batch_decode(
|
||||
prefill_token_ids,
|
||||
clean_up_tokenization_spaces=False,
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
prefill_tokens = Tokens(
|
||||
prefill_token_ids,
|
||||
prefill_logprobs,
|
||||
prefill_texts,
|
||||
is_special=[],
|
||||
)
|
||||
else:
|
||||
prefill_tokens = None
|
||||
|
||||
generation = Generation(
|
||||
batch.batch_id,
|
||||
None,
|
||||
Tokens(
|
||||
[next_token_id_squeezed],
|
||||
[next_token_logprob],
|
||||
[next_token_text],
|
||||
[next_token_id_squeezed.item() in self.all_special_ids],
|
||||
),
|
||||
generated_text,
|
||||
None,
|
||||
)
|
||||
|
||||
generations.append(generation)
|
||||
|
||||
# Update values
|
||||
batch.input_ids = torch.cat(
|
||||
[batch.input_ids, torch.tensor([[next_token_id_squeezed]])], dim=1
|
||||
)
|
||||
batch.all_input_ids[i] = all_input_ids
|
||||
batch.input_lengths[i] = new_input_length
|
||||
batch.prefix_offsets[i] = prefix_offset
|
||||
batch.read_offsets[i] = read_offset
|
||||
batch.max_input_length = max(batch.max_input_length, new_input_length)
|
||||
|
||||
# We finished all generations in the batch; there is no next batch
|
||||
if stopped:
|
||||
forward_ns = start_decode - start
|
||||
decode_ns = time.time_ns() - start_decode
|
||||
return generations, None, (forward_ns, decode_ns)
|
||||
|
||||
# Slice unused values from prefill
|
||||
batch.input_ids = batch.input_ids[:, :1]
|
||||
batch.past_input_ids = past_input_ids
|
||||
batch.past_transformed_states = past_transformed_states
|
||||
|
||||
forward_ns = start_decode - start
|
||||
decode_ns = time.time_ns() - start_decode
|
||||
return generations, batch, (forward_ns, decode_ns)
|
||||
|
Loading…
Reference in New Issue
Block a user