mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-25 12:02:08 +00:00
feat: add optimization and first pass of integration test
This commit is contained in:
parent
966f3ba35c
commit
5b6f9259c1
@ -0,0 +1,84 @@
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 5089,
|
||||
"logprob": null,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2748,
|
||||
"logprob": -9.7421875,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 187,
|
||||
"logprob": -2.4824219,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 187,
|
||||
"logprob": -2.4824219,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 50274,
|
||||
"logprob": -1.7880859,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 92,
|
||||
"logprob": -2.0703125,
|
||||
"special": false,
|
||||
"text": "{"
|
||||
},
|
||||
{
|
||||
"id": 187,
|
||||
"logprob": -0.04827881,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 50270,
|
||||
"logprob": -0.18896484,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"logprob": -1.5234375,
|
||||
"special": false,
|
||||
"text": "\""
|
||||
},
|
||||
{
|
||||
"id": 9629,
|
||||
"logprob": -2.8203125,
|
||||
"special": false,
|
||||
"text": "request"
|
||||
},
|
||||
{
|
||||
"id": 1381,
|
||||
"logprob": -0.78759766,
|
||||
"special": false,
|
||||
"text": "\":"
|
||||
},
|
||||
{
|
||||
"id": 551,
|
||||
"logprob": -0.49169922,
|
||||
"special": false,
|
||||
"text": " {"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\n {\n \"request\": {"
|
||||
}
|
@ -0,0 +1,99 @@
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2502,
|
||||
"logprob": null,
|
||||
"text": " red"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -2.5234375,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 8862,
|
||||
"logprob": -3.4746094,
|
||||
"text": " yellow"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.43579102,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 209,
|
||||
"logprob": -8.2421875,
|
||||
"text": " "
|
||||
}
|
||||
],
|
||||
"seed": 0,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 187,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 2764,
|
||||
"logprob": -0.37573242,
|
||||
"special": false,
|
||||
"text": "umber"
|
||||
},
|
||||
{
|
||||
"id": 285,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " and"
|
||||
},
|
||||
{
|
||||
"id": 3168,
|
||||
"logprob": -0.9013672,
|
||||
"special": false,
|
||||
"text": " white"
|
||||
},
|
||||
{
|
||||
"id": 28,
|
||||
"logprob": -1.2314453,
|
||||
"special": false,
|
||||
"text": ";"
|
||||
},
|
||||
{
|
||||
"id": 253,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 3295,
|
||||
"logprob": -1.2167969,
|
||||
"special": false,
|
||||
"text": " color"
|
||||
},
|
||||
{
|
||||
"id": 273,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 697,
|
||||
"logprob": -2.1015625,
|
||||
"special": false,
|
||||
"text": " its"
|
||||
},
|
||||
{
|
||||
"id": 17433,
|
||||
"logprob": -2.4296875,
|
||||
"special": false,
|
||||
"text": " unders"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "blue, red, yellow, \number and white; the color of its unders"
|
||||
}
|
62
integration-tests/models/test_fused_kernel_mamba.py
Normal file
62
integration-tests/models/test_fused_kernel_mamba.py
Normal file
@ -0,0 +1,62 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def fused_kernel_mamba_handle(launcher):
|
||||
with launcher("state-spaces/mamba-130m", num_shard=1) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def fused_kernel_mamba(fused_kernel_mamba_handle):
|
||||
await fused_kernel_mamba_handle.health(300)
|
||||
return fused_kernel_mamba_handle.client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_fused_kernel_mamba(fused_kernel_mamba, response_snapshot):
|
||||
response = await fused_kernel_mamba.generate(
|
||||
"Test request", max_new_tokens=10, decoder_input_details=True
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_fused_kernel_mamba_all_params(fused_kernel_mamba, response_snapshot):
|
||||
response = await fused_kernel_mamba.generate(
|
||||
"blue, red, yellow, ",
|
||||
max_new_tokens=10,
|
||||
repetition_penalty=1.2,
|
||||
return_full_text=True,
|
||||
stop_sequences=["test"],
|
||||
temperature=0.5,
|
||||
top_p=0.9,
|
||||
top_k=10,
|
||||
truncate=5,
|
||||
typical_p=0.9,
|
||||
watermark=True,
|
||||
decoder_input_details=True,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
# TODO: fix so the input is not included in the output
|
||||
assert response.generated_text == "blue, red, yellow, \number and white; the color of its unders"
|
||||
assert response == response_snapshot
|
||||
|
||||
# TODO: fix `Expected x0.dim() == 2 to be true, but got false.`
|
||||
# 94: `hidden_states, _ = self.layer_norm(hidden_states.squeeze(0))`
|
||||
# NOTE: the fast layer norm has strict requirements on the input shape
|
||||
# @pytest.mark.asyncio
|
||||
# @pytest.mark.private
|
||||
# async def test_fused_kernel_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
|
||||
# responses = await generate_load(fused_kernel_mamba, "Test request", max_new_tokens=10, n=4)
|
||||
|
||||
# assert len(responses) == 4
|
||||
# assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||
|
||||
# assert responses == response_snapshot
|
@ -5,12 +5,12 @@ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inne
|
||||
from torch import nn
|
||||
from typing import Optional, List, Tuple, Any
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
import torch.nn.functional as F
|
||||
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
FastRMSNorm,
|
||||
)
|
||||
|
||||
class MambaConfig(PretrainedConfig):
|
||||
@ -41,151 +41,78 @@ class MambaConfig(PretrainedConfig):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class MambaBlock(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.dt_rank = (config.d_model + 15) // 16
|
||||
self.x_proj_weight = weights.get_tensor(f"{prefix}.x_proj.weight")
|
||||
self.dt_proj_weight = weights.get_tensor(f"{prefix}.dt_proj.weight")
|
||||
self.dt_proj_bias = weights.get_tensor(f"{prefix}.dt_proj.bias")
|
||||
self.out_proj_weight = weights.get_tensor(f"{prefix}.out_proj.weight")
|
||||
self.out_proj_bias = None
|
||||
# TODO: avoid loading the same weights twice
|
||||
self.in_proj_weight = weights.get_tensor(f"{prefix}.in_proj.weight")
|
||||
self.in_proj_bias = None
|
||||
self.in_proj = TensorParallelColumnLinear.load(
|
||||
config=config,
|
||||
prefix=f"{prefix}.in_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
config=config, prefix=f"{prefix}.in_proj", weights=weights, bias=False
|
||||
)
|
||||
self.conv1d = nn.Conv1d(
|
||||
config.d_inner,
|
||||
config.d_inner,
|
||||
kernel_size=config.d_conv,
|
||||
groups=config.d_inner,
|
||||
padding=config.d_conv - 1,
|
||||
)
|
||||
self.conv1d.weight = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.weight"))
|
||||
self.conv1d.bias = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.bias"))
|
||||
self.A_log = nn.Parameter(weights.get_tensor(f"{prefix}.A_log"))
|
||||
self.D = nn.Parameter(weights.get_tensor(f"{prefix}.D"))
|
||||
# helper for loading weights
|
||||
self.load_weights(prefix, weights)
|
||||
|
||||
def forward(self, index, hidden_states, past_transformed_state):
|
||||
projected_states = self.in_proj(hidden_states)
|
||||
|
||||
A = -torch.exp(self.A_log.float())
|
||||
def load_weights(self, prefix, weights):
|
||||
weight_names = ["x_proj.weight", "dt_proj.weight", "dt_proj.bias",
|
||||
"out_proj.weight", "in_proj.weight",
|
||||
"conv1d.weight", "conv1d.bias", "A_log", "D"]
|
||||
for name in weight_names:
|
||||
param_name = name.replace('.', '_')
|
||||
setattr(self, param_name, nn.Parameter(weights.get_tensor(f"{prefix}.{name}")))
|
||||
self.out_proj_bias = None
|
||||
self.negA = -torch.exp(self.A_log.float())
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor):
|
||||
projected_states = self.in_proj(hidden_states).transpose(1,2)
|
||||
# conv1d, ssm, and selective_scan are all fused into one kernel
|
||||
attn_outputs = mamba_inner_fn(
|
||||
projected_states.transpose(1,2),
|
||||
self.conv1d.weight,
|
||||
self.conv1d.bias,
|
||||
projected_states,
|
||||
self.conv1d_weight,
|
||||
self.conv1d_bias,
|
||||
self.x_proj_weight,
|
||||
self.dt_proj_weight,
|
||||
self.out_proj_weight,
|
||||
self.out_proj_bias,
|
||||
A,
|
||||
self.negA,
|
||||
None,
|
||||
None,
|
||||
self.D.float(),
|
||||
delta_bias=self.dt_proj_bias.float(),
|
||||
delta_softplus=True,
|
||||
)
|
||||
|
||||
return attn_outputs, projected_states
|
||||
|
||||
|
||||
# TODO: prefer a more optimized implementation of RMSNorm if possible
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, num_features, eps=1e-8):
|
||||
super().__init__()
|
||||
self.num_features = num_features
|
||||
self.eps = eps
|
||||
self.scale = nn.Parameter(torch.ones(num_features))
|
||||
|
||||
def forward(self, x):
|
||||
rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
|
||||
x = x / rms
|
||||
return self.scale * x
|
||||
|
||||
return attn_outputs
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, layer_id, config, weights):
|
||||
super().__init__()
|
||||
self.layer_id = layer_id
|
||||
self.mamba_block = MambaBlock(
|
||||
prefix=f"{layer_id}.mixer", config=config, weights=weights
|
||||
)
|
||||
self.layer_norm = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.layer_norm.scale = nn.Parameter(
|
||||
weights.get_tensor(f"{layer_id}.norm.weight")
|
||||
)
|
||||
self.mamba_block = MambaBlock(prefix=f"{layer_id}.mixer", config=config, weights=weights)
|
||||
self.layer_norm = FastRMSNorm.load(prefix=f"{layer_id}.norm", weights=weights, eps=config.layer_norm_epsilon)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
index,
|
||||
hidden_states,
|
||||
past_transformed_state,
|
||||
hidden_states: torch.Tensor,
|
||||
):
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
attn_outputs, transformed_states = self.mamba_block(
|
||||
index, hidden_states, past_transformed_state
|
||||
)
|
||||
hidden_states = residual + attn_outputs
|
||||
return hidden_states, transformed_states
|
||||
|
||||
hidden_states, _ = self.layer_norm(hidden_states.squeeze(0))
|
||||
hidden_states = residual + self.mamba_block(hidden_states.unsqueeze(0))
|
||||
return hidden_states
|
||||
|
||||
class MambaModel(nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
self.tp_rank = weights.process_group.rank()
|
||||
self.tp_world_size = weights.process_group.size()
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix="backbone.embedding", weights=weights
|
||||
)
|
||||
prefix = "backbone"
|
||||
|
||||
self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights)
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
ResidualBlock(f"backbone.layers.{layer_id}", config, weights)
|
||||
for layer_id in range(config.n_layer)
|
||||
]
|
||||
[ResidualBlock(f"{prefix}.layers.{i}", config, weights) for i in range(config.n_layer)]
|
||||
)
|
||||
self.norm_f = FastRMSNorm.load(f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon)
|
||||
self.lm_head = TensorParallelColumnLinear.load(config, f"{prefix}.embedding", weights, False)
|
||||
|
||||
# TODO: avoid hardcoded sizes and improve how we load the weights
|
||||
self.norm_f = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.norm_f.scale = nn.Parameter(weights.get_tensor(f"backbone.norm_f.weight"))
|
||||
# use the same weights for the embedding and the final layer norm
|
||||
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
||||
self.lm_head.weight = nn.Parameter(
|
||||
self.embed_tokens.weight[: config.vocab_size, :]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
past_input_ids: Optional[List[Tuple[torch.FloatTensor]]] = None,
|
||||
past_transformed_states: Optional[List[Tuple[torch.FloatTensor]]] = None,
|
||||
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||
# NOTE: we need all input_ids to compute the correct embeddings
|
||||
if past_input_ids is not None:
|
||||
input_ids = past_input_ids
|
||||
|
||||
def forward(self, input_ids: torch.Tensor):
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
for block in self.blocks:
|
||||
hidden_states = block(hidden_states)
|
||||
|
||||
past_transformed_states = (
|
||||
[None] * len(self.blocks)
|
||||
if past_transformed_states is None
|
||||
else past_transformed_states
|
||||
)
|
||||
|
||||
for index, block in enumerate(self.blocks):
|
||||
hidden_states, transformed_states = block(
|
||||
index, hidden_states, past_transformed_states[index]
|
||||
)
|
||||
past_transformed_states[index] = transformed_states
|
||||
|
||||
final_hidden_states = self.norm_f(hidden_states)
|
||||
after_lm_head = self.lm_head(final_hidden_states)
|
||||
return after_lm_head, input_ids, past_transformed_states
|
||||
final_hidden_states, _ = self.norm_f(hidden_states.squeeze(0))
|
||||
return self.lm_head(final_hidden_states.unsqueeze(0)), input_ids
|
||||
|
@ -30,12 +30,11 @@ from text_generation_server.utils.tokens import batch_top_tokens, Sampling
|
||||
|
||||
|
||||
class MambaCausalLMBatch(CausalLMBatch):
|
||||
past_transformed_states: Optional[List[torch.Tensor]]
|
||||
past_input_ids: Optional[torch.Tensor]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.past_input_ids = None
|
||||
self.past_transformed_states = None
|
||||
|
||||
@classmethod
|
||||
def from_pb(
|
||||
@ -103,6 +102,10 @@ class Mamba(Model):
|
||||
def batch_type(self) -> Type[CausalLMBatch]:
|
||||
return MambaCausalLMBatch
|
||||
|
||||
def warmup(self, batch) -> Optional[int]:
|
||||
# TODO: implement warmup for Mamba if needed
|
||||
return None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -116,19 +119,9 @@ class Mamba(Model):
|
||||
def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
|
||||
start = time.time_ns()
|
||||
|
||||
input_ids = batch.input_ids
|
||||
past_input_ids = batch.past_input_ids
|
||||
past_transformed_states = batch.past_transformed_states
|
||||
input_ids = batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids
|
||||
|
||||
model_output = self.model(
|
||||
input_ids,
|
||||
past_input_ids,
|
||||
past_transformed_states,
|
||||
)
|
||||
|
||||
logits = model_output[0]
|
||||
past_input_ids = model_output[1]
|
||||
past_transformed_states = model_output[2]
|
||||
logits, past_input_ids = self.model(input_ids)[:2]
|
||||
|
||||
# Results
|
||||
generations: List[Generation] = []
|
||||
@ -176,9 +169,6 @@ class Mamba(Model):
|
||||
all_input_ids.view(1, -1), logits[-1:, :]
|
||||
)
|
||||
|
||||
# add next token to past_input_ids
|
||||
past_input_ids = torch.cat([past_input_ids, next_token_id], dim=1)
|
||||
|
||||
# Append next token to all tokens
|
||||
all_input_ids = torch.cat([all_input_ids, next_token_id])
|
||||
new_input_length = input_length + 1
|
||||
@ -199,73 +189,94 @@ class Mamba(Model):
|
||||
if not stop:
|
||||
stopped = False
|
||||
|
||||
if stop:
|
||||
# Decode generated tokens
|
||||
output_text, _, _ = self.decode_token(
|
||||
all_input_ids[:, 0],
|
||||
prefix_offset=len(all_input_ids)
|
||||
- stopping_criteria.current_tokens
|
||||
- 1,
|
||||
read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
# Get seed
|
||||
if isinstance(next_token_chooser.choice, Sampling):
|
||||
seed = next_token_chooser.choice.seed
|
||||
# Shard generations
|
||||
# All generations will be appended in the rust sharded client
|
||||
if i % self.world_size == self.rank:
|
||||
if stop:
|
||||
# Decode generated tokens
|
||||
output_text, _, _ = self.decode_token(
|
||||
all_input_ids[:, 0],
|
||||
prefix_offset=len(all_input_ids)
|
||||
- stopping_criteria.current_tokens
|
||||
- 1,
|
||||
read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
# Get seed
|
||||
if isinstance(next_token_chooser.choice, Sampling):
|
||||
seed = next_token_chooser.choice.seed
|
||||
else:
|
||||
seed = None
|
||||
|
||||
generated_text = GeneratedText(
|
||||
output_text, stopping_criteria.current_tokens, reason, seed
|
||||
)
|
||||
else:
|
||||
seed = None
|
||||
generated_text = None
|
||||
|
||||
generated_text = GeneratedText(
|
||||
output_text, stopping_criteria.current_tokens, reason, seed
|
||||
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
|
||||
# Remove generated token to only have prefill and add nan for first prompt token
|
||||
prefill_logprobs = [float("nan")] + torch.log_softmax(
|
||||
logits, -1
|
||||
).gather(1, all_input_ids[1:]).squeeze(1)[
|
||||
-new_input_length:-1
|
||||
].tolist()
|
||||
prefill_token_ids = all_input_ids[-new_input_length:-1]
|
||||
prefill_texts = self.tokenizer.batch_decode(
|
||||
prefill_token_ids,
|
||||
clean_up_tokenization_spaces=False,
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
prefill_tokens = Tokens(
|
||||
prefill_token_ids,
|
||||
prefill_logprobs,
|
||||
prefill_texts,
|
||||
is_special=[],
|
||||
)
|
||||
else:
|
||||
prefill_tokens = None
|
||||
past_input_ids = torch.cat([past_input_ids, next_token_id], dim=1)
|
||||
|
||||
|
||||
if top_n_tokens > 0:
|
||||
toptoken_texts = self.tokenizer.batch_decode(
|
||||
top_token_ids,
|
||||
clean_up_tokenization_spaces=False,
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
special_toptokens = [
|
||||
token_id in self.all_special_ids for token_id in top_token_ids
|
||||
]
|
||||
top_tokens = Tokens(
|
||||
top_token_ids,
|
||||
top_token_logprobs,
|
||||
toptoken_texts,
|
||||
special_toptokens,
|
||||
)
|
||||
else:
|
||||
top_tokens = None
|
||||
|
||||
generation = Generation(
|
||||
batch.batch_id,
|
||||
prefill_tokens,
|
||||
Tokens(
|
||||
[next_token_id_squeezed],
|
||||
[next_token_logprob],
|
||||
[next_token_text],
|
||||
[next_token_id_squeezed.item() in self.all_special_ids],
|
||||
),
|
||||
generated_text,
|
||||
top_tokens,
|
||||
)
|
||||
else:
|
||||
generated_text = None
|
||||
|
||||
# Prefill
|
||||
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
|
||||
# Remove generated token to only have prefill and add nan for first prompt token
|
||||
prefill_logprobs = [float("nan")] + torch.log_softmax(
|
||||
logits, -1
|
||||
).gather(1, all_input_ids[1:]).squeeze(1)[-new_input_length:-1].tolist()
|
||||
prefill_token_ids = all_input_ids[-new_input_length:-1]
|
||||
prefill_texts = self.tokenizer.batch_decode(
|
||||
prefill_token_ids,
|
||||
clean_up_tokenization_spaces=False,
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
prefill_tokens = Tokens(
|
||||
prefill_token_ids,
|
||||
prefill_logprobs,
|
||||
prefill_texts,
|
||||
is_special=[],
|
||||
)
|
||||
else:
|
||||
prefill_tokens = None
|
||||
generations.append(generation)
|
||||
|
||||
generation = Generation(
|
||||
batch.batch_id,
|
||||
None,
|
||||
Tokens(
|
||||
[next_token_id_squeezed],
|
||||
[next_token_logprob],
|
||||
[next_token_text],
|
||||
[next_token_id_squeezed.item() in self.all_special_ids],
|
||||
),
|
||||
generated_text,
|
||||
None,
|
||||
)
|
||||
|
||||
generations.append(generation)
|
||||
next_token_tensor = next_token_id_squeezed.view(1, 1)
|
||||
# Update values
|
||||
batch.input_ids = torch.cat(
|
||||
[batch.input_ids, next_token_tensor], dim=1
|
||||
)
|
||||
batch.all_input_ids[i] = all_input_ids
|
||||
batch.input_lengths[i] = new_input_length
|
||||
batch.prefix_offsets[i] = prefix_offset
|
||||
batch.read_offsets[i] = read_offset
|
||||
batch.max_input_length = max(batch.max_input_length, new_input_length)
|
||||
# Update values
|
||||
batch.all_input_ids[i] = all_input_ids
|
||||
batch.input_lengths[i] = new_input_length
|
||||
batch.prefix_offsets[i] = prefix_offset
|
||||
batch.read_offsets[i] = read_offset
|
||||
batch.max_input_length = max(batch.max_input_length, new_input_length)
|
||||
|
||||
# We finished all generations in the batch; there is no next batch
|
||||
if stopped:
|
||||
@ -273,10 +284,7 @@ class Mamba(Model):
|
||||
decode_ns = time.time_ns() - start_decode
|
||||
return generations, None, (forward_ns, decode_ns)
|
||||
|
||||
# Slice unused values from prefill
|
||||
batch.input_ids = batch.input_ids[:, :1]
|
||||
batch.past_input_ids = past_input_ids
|
||||
batch.past_transformed_states = past_transformed_states
|
||||
|
||||
forward_ns = start_decode - start
|
||||
decode_ns = time.time_ns() - start_decode
|
||||
|
Loading…
Reference in New Issue
Block a user