Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-25 20:12:07 +00:00
feat: add optimization and first pass of integration test
This commit is contained in:
parent 966f3ba35c
commit 5b6f9259c1
@@ -0,0 +1,84 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      {
        "id": 5089,
        "logprob": null,
        "text": "Test"
      },
      {
        "id": 2748,
        "logprob": -9.7421875,
        "text": " request"
      }
    ],
    "seed": null,
    "tokens": [
      {
        "id": 187,
        "logprob": -2.4824219,
        "special": false,
        "text": "\n"
      },
      {
        "id": 187,
        "logprob": -2.4824219,
        "special": false,
        "text": "\n"
      },
      {
        "id": 50274,
        "logprob": -1.7880859,
        "special": false,
        "text": " "
      },
      {
        "id": 92,
        "logprob": -2.0703125,
        "special": false,
        "text": "{"
      },
      {
        "id": 187,
        "logprob": -0.04827881,
        "special": false,
        "text": "\n"
      },
      {
        "id": 50270,
        "logprob": -0.18896484,
        "special": false,
        "text": " "
      },
      {
        "id": 3,
        "logprob": -1.5234375,
        "special": false,
        "text": "\""
      },
      {
        "id": 9629,
        "logprob": -2.8203125,
        "special": false,
        "text": "request"
      },
      {
        "id": 1381,
        "logprob": -0.78759766,
        "special": false,
        "text": "\":"
      },
      {
        "id": 551,
        "logprob": -0.49169922,
        "special": false,
        "text": " {"
      }
    ],
    "top_tokens": null
  },
  "generated_text": "\n\n {\n \"request\": {"
}
@@ -0,0 +1,99 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      {
        "id": 2502,
        "logprob": null,
        "text": " red"
      },
      {
        "id": 13,
        "logprob": -2.5234375,
        "text": ","
      },
      {
        "id": 8862,
        "logprob": -3.4746094,
        "text": " yellow"
      },
      {
        "id": 13,
        "logprob": -0.43579102,
        "text": ","
      },
      {
        "id": 209,
        "logprob": -8.2421875,
        "text": " "
      }
    ],
    "seed": 0,
    "tokens": [
      {
        "id": 187,
        "logprob": 0.0,
        "special": false,
        "text": "\n"
      },
      {
        "id": 2764,
        "logprob": -0.37573242,
        "special": false,
        "text": "umber"
      },
      {
        "id": 285,
        "logprob": 0.0,
        "special": false,
        "text": " and"
      },
      {
        "id": 3168,
        "logprob": -0.9013672,
        "special": false,
        "text": " white"
      },
      {
        "id": 28,
        "logprob": -1.2314453,
        "special": false,
        "text": ";"
      },
      {
        "id": 253,
        "logprob": 0.0,
        "special": false,
        "text": " the"
      },
      {
        "id": 3295,
        "logprob": -1.2167969,
        "special": false,
        "text": " color"
      },
      {
        "id": 273,
        "logprob": 0.0,
        "special": false,
        "text": " of"
      },
      {
        "id": 697,
        "logprob": -2.1015625,
        "special": false,
        "text": " its"
      },
      {
        "id": 17433,
        "logprob": -2.4296875,
        "special": false,
        "text": " unders"
      }
    ],
    "top_tokens": null
  },
  "generated_text": "blue, red, yellow, \number and white; the color of its unders"
}
integration-tests/models/test_fused_kernel_mamba.py (new file, 62 lines)
@@ -0,0 +1,62 @@
import pytest


@pytest.fixture(scope="module")
def fused_kernel_mamba_handle(launcher):
    with launcher("state-spaces/mamba-130m", num_shard=1) as handle:
        yield handle


@pytest.fixture(scope="module")
async def fused_kernel_mamba(fused_kernel_mamba_handle):
    await fused_kernel_mamba_handle.health(300)
    return fused_kernel_mamba_handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_fused_kernel_mamba(fused_kernel_mamba, response_snapshot):
    response = await fused_kernel_mamba.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_fused_kernel_mamba_all_params(fused_kernel_mamba, response_snapshot):
    response = await fused_kernel_mamba.generate(
        "blue, red, yellow, ",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 10
    # TODO: fix so the input is not included in the output
    assert response.generated_text == "blue, red, yellow, \number and white; the color of its unders"
    assert response == response_snapshot


# TODO: fix `Expected x0.dim() == 2 to be true, but got false.`
# 94: `hidden_states, _ = self.layer_norm(hidden_states.squeeze(0))`
# NOTE: the fast layer norm has strict requirements on the input shape
# @pytest.mark.asyncio
# @pytest.mark.private
# async def test_fused_kernel_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
#     responses = await generate_load(fused_kernel_mamba, "Test request", max_new_tokens=10, n=4)

#     assert len(responses) == 4
#     assert all([r.generated_text == responses[0].generated_text for r in responses])

#     assert responses == response_snapshot
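The two JSON documents above are the stored snapshots that the `response == response_snapshot` assertions in this test compare against. Conceptually (a minimal sketch with a hypothetical `snapshot_path` helper, not the repository's actual snapshot fixture), the check reduces to a field-by-field comparison of the serialized response against the stored JSON:

import json

def matches_snapshot(response_dict: dict, snapshot_path: str) -> bool:
    # Load the stored snapshot (e.g. one of the JSON documents above) and
    # compare it against the serialized client response.
    with open(snapshot_path) as f:
        expected = json.load(f)
    return response_dict == expected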
@@ -5,12 +5,12 @@ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
 from torch import nn
 from typing import Optional, List, Tuple, Any
 from transformers.configuration_utils import PretrainedConfig
-from transformers.modeling_outputs import CausalLMOutputWithPast
 import torch.nn.functional as F

 from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
+    FastRMSNorm,
 )

 class MambaConfig(PretrainedConfig):
@@ -41,151 +41,78 @@ class MambaConfig(PretrainedConfig):
             **kwargs,
         )


 class MambaBlock(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
-        self.dt_rank = (config.d_model + 15) // 16
-        self.x_proj_weight = weights.get_tensor(f"{prefix}.x_proj.weight")
-        self.dt_proj_weight = weights.get_tensor(f"{prefix}.dt_proj.weight")
-        self.dt_proj_bias = weights.get_tensor(f"{prefix}.dt_proj.bias")
-        self.out_proj_weight = weights.get_tensor(f"{prefix}.out_proj.weight")
-        self.out_proj_bias = None
-        # TODO: avoid loading the same weights twice
-        self.in_proj_weight = weights.get_tensor(f"{prefix}.in_proj.weight")
-        self.in_proj_bias = None
         self.in_proj = TensorParallelColumnLinear.load(
-            config=config,
-            prefix=f"{prefix}.in_proj",
-            weights=weights,
-            bias=False,
+            config=config, prefix=f"{prefix}.in_proj", weights=weights, bias=False
         )
-        self.conv1d = nn.Conv1d(
-            config.d_inner,
-            config.d_inner,
-            kernel_size=config.d_conv,
-            groups=config.d_inner,
-            padding=config.d_conv - 1,
-        )
-        self.conv1d.weight = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.weight"))
-        self.conv1d.bias = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.bias"))
-        self.A_log = nn.Parameter(weights.get_tensor(f"{prefix}.A_log"))
-        self.D = nn.Parameter(weights.get_tensor(f"{prefix}.D"))
+        # helper for loading weights
+        self.load_weights(prefix, weights)

-    def forward(self, index, hidden_states, past_transformed_state):
-        projected_states = self.in_proj(hidden_states)
-
-        A = -torch.exp(self.A_log.float())
+    def load_weights(self, prefix, weights):
+        weight_names = ["x_proj.weight", "dt_proj.weight", "dt_proj.bias",
+                        "out_proj.weight", "in_proj.weight",
+                        "conv1d.weight", "conv1d.bias", "A_log", "D"]
+        for name in weight_names:
+            param_name = name.replace('.', '_')
+            setattr(self, param_name, nn.Parameter(weights.get_tensor(f"{prefix}.{name}")))
+        self.out_proj_bias = None
+        self.negA = -torch.exp(self.A_log.float())
+
+    def forward(self, hidden_states: torch.Tensor):
+        projected_states = self.in_proj(hidden_states).transpose(1,2)
         # conv1d, ssm, and selective_scan are all fused into one kernel
         attn_outputs = mamba_inner_fn(
-            projected_states.transpose(1,2),
-            self.conv1d.weight,
-            self.conv1d.bias,
+            projected_states,
+            self.conv1d_weight,
+            self.conv1d_bias,
             self.x_proj_weight,
             self.dt_proj_weight,
             self.out_proj_weight,
             self.out_proj_bias,
-            A,
+            self.negA,
             None,
             None,
             self.D.float(),
             delta_bias=self.dt_proj_bias.float(),
             delta_softplus=True,
         )
-        return attn_outputs, projected_states
+        return attn_outputs


-# TODO: prefer a more optimized implementation of RMSNorm if possible
-class RMSNorm(nn.Module):
-    def __init__(self, num_features, eps=1e-8):
-        super().__init__()
-        self.num_features = num_features
-        self.eps = eps
-        self.scale = nn.Parameter(torch.ones(num_features))
-
-    def forward(self, x):
-        rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
-        x = x / rms
-        return self.scale * x
-
-
 class ResidualBlock(nn.Module):
     def __init__(self, layer_id, config, weights):
         super().__init__()
-        self.layer_id = layer_id
-        self.mamba_block = MambaBlock(
-            prefix=f"{layer_id}.mixer", config=config, weights=weights
-        )
-        self.layer_norm = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.layer_norm.scale = nn.Parameter(
-            weights.get_tensor(f"{layer_id}.norm.weight")
-        )
+        self.mamba_block = MambaBlock(prefix=f"{layer_id}.mixer", config=config, weights=weights)
+        self.layer_norm = FastRMSNorm.load(prefix=f"{layer_id}.norm", weights=weights, eps=config.layer_norm_epsilon)

     def forward(
         self,
-        index,
-        hidden_states,
-        past_transformed_state,
+        hidden_states: torch.Tensor,
     ):
         residual = hidden_states
-        hidden_states = self.layer_norm(hidden_states)
-        attn_outputs, transformed_states = self.mamba_block(
-            index, hidden_states, past_transformed_state
-        )
-        hidden_states = residual + attn_outputs
-        return hidden_states, transformed_states
+        hidden_states, _ = self.layer_norm(hidden_states.squeeze(0))
+        hidden_states = residual + self.mamba_block(hidden_states.unsqueeze(0))
+        return hidden_states


 class MambaModel(nn.Module):
     def __init__(self, config, weights):
         super().__init__()
         self.tp_rank = weights.process_group.rank()
         self.tp_world_size = weights.process_group.size()
-        self.embed_tokens = TensorParallelEmbedding(
-            prefix="backbone.embedding", weights=weights
-        )
+        prefix = "backbone"
+        self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights)
         self.blocks = nn.ModuleList(
-            [
-                ResidualBlock(f"backbone.layers.{layer_id}", config, weights)
-                for layer_id in range(config.n_layer)
-            ]
+            [ResidualBlock(f"{prefix}.layers.{i}", config, weights) for i in range(config.n_layer)]
         )
+        self.norm_f = FastRMSNorm.load(f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon)
+        self.lm_head = TensorParallelColumnLinear.load(config, f"{prefix}.embedding", weights, False)

-        # TODO: avoid hardcoded sizes and improve how we load the weights
-        self.norm_f = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.norm_f.scale = nn.Parameter(weights.get_tensor(f"backbone.norm_f.weight"))
-        # use the same weights for the embedding and the final layer norm
-        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
-        self.lm_head.weight = nn.Parameter(
-            self.embed_tokens.weight[: config.vocab_size, :]
-        )
-
-    def forward(
-        self,
-        input_ids: torch.LongTensor,
-        past_input_ids: Optional[List[Tuple[torch.FloatTensor]]] = None,
-        past_transformed_states: Optional[List[Tuple[torch.FloatTensor]]] = None,
-    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
-        # NOTE: we need all input_ids to compute the correct embeddings
-        if past_input_ids is not None:
-            input_ids = past_input_ids
-
+    def forward(self, input_ids: torch.Tensor):
         hidden_states = self.embed_tokens(input_ids)
+        for block in self.blocks:
+            hidden_states = block(hidden_states)

-        past_transformed_states = (
-            [None] * len(self.blocks)
-            if past_transformed_states is None
-            else past_transformed_states
-        )
-
-        for index, block in enumerate(self.blocks):
-            hidden_states, transformed_states = block(
-                index, hidden_states, past_transformed_states[index]
-            )
-            past_transformed_states[index] = transformed_states
-
-        final_hidden_states = self.norm_f(hidden_states)
-        after_lm_head = self.lm_head(final_hidden_states)
-        return after_lm_head, input_ids, past_transformed_states
+        final_hidden_states, _ = self.norm_f(hidden_states.squeeze(0))
+        return self.lm_head(final_hidden_states.unsqueeze(0)), input_ids
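The optimization above replaces the hand-rolled RMSNorm and the separate conv1d / projection / scan modules with `FastRMSNorm` and the single fused `mamba_inner_fn` kernel from mamba_ssm. For intuition, here is a rough eager-mode sketch of what that one fused call combines: input projection, depthwise causal conv1d, data-dependent SSM parameters, a sequential selective scan, gating, and the output projection. Parameter names and shapes in this sketch are illustrative assumptions, not the kernel's actual signature.

import torch
import torch.nn.functional as F

def naive_mamba_mixer(hidden_states, in_proj_w, conv_w, conv_b, x_proj_w,
                      dt_proj_w, dt_proj_b, A_log, D, out_proj_w):
    # hidden_states: (batch, seq_len, d_model); assumed weight shapes:
    # in_proj_w (2*d_inner, d_model), conv_w (d_inner, 1, k), x_proj_w
    # (dt_rank + 2*d_state, d_inner), dt_proj_w (d_inner, dt_rank),
    # A_log/D (d_inner, d_state)/(d_inner,), out_proj_w (d_model, d_inner).
    xz = hidden_states @ in_proj_w.T                    # (b, l, 2 * d_inner)
    x, z = xz.chunk(2, dim=-1)
    b, l, d_inner = x.shape
    # depthwise causal convolution over the sequence dimension
    x = F.conv1d(x.transpose(1, 2), conv_w, conv_b,
                 padding=conv_w.shape[-1] - 1, groups=d_inner)[..., :l]
    x = F.silu(x).transpose(1, 2)                       # (b, l, d_inner)
    # data-dependent SSM parameters (delta, B, C)
    dbl = x @ x_proj_w.T
    dt_rank = dt_proj_w.shape[-1]
    d_state = (dbl.shape[-1] - dt_rank) // 2
    dt, B, C = dbl.split([dt_rank, d_state, d_state], dim=-1)
    dt = F.softplus(dt @ dt_proj_w.T + dt_proj_b)       # (b, l, d_inner)
    A = -torch.exp(A_log.float())                       # (d_inner, d_state)
    # naive sequential selective scan: this loop is what the fused kernel accelerates
    state = x.new_zeros(b, d_inner, d_state)
    ys = []
    for t in range(l):
        dA = torch.exp(dt[:, t].unsqueeze(-1) * A)
        dBx = dt[:, t].unsqueeze(-1) * B[:, t].unsqueeze(1) * x[:, t].unsqueeze(-1)
        state = dA * state + dBx
        ys.append((state @ C[:, t].unsqueeze(-1)).squeeze(-1) + D * x[:, t])
    y = torch.stack(ys, dim=1) * F.silu(z)              # gate with the z branch
    return y @ out_proj_w.T                             # (b, l, d_model)

Fusing these steps into one kernel avoids materializing the intermediate activations and the Python-level scan loop, which is the "optimization" named in the commit title.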
@@ -30,12 +30,11 @@ from text_generation_server.utils.tokens import batch_top_tokens, Sampling


 class MambaCausalLMBatch(CausalLMBatch):
-    past_transformed_states: Optional[List[torch.Tensor]]
+    past_input_ids: Optional[torch.Tensor]

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.past_input_ids = None
-        self.past_transformed_states = None

     @classmethod
     def from_pb(
@@ -103,6 +102,10 @@ class Mamba(Model):
     def batch_type(self) -> Type[CausalLMBatch]:
         return MambaCausalLMBatch

+    def warmup(self, batch) -> Optional[int]:
+        # TODO: implement warmup for Mamba if needed
+        return None
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -116,19 +119,9 @@ class Mamba(Model):
     def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
         start = time.time_ns()

-        input_ids = batch.input_ids
-        past_input_ids = batch.past_input_ids
-        past_transformed_states = batch.past_transformed_states
-
-        model_output = self.model(
-            input_ids,
-            past_input_ids,
-            past_transformed_states,
-        )
-
-        logits = model_output[0]
-        past_input_ids = model_output[1]
-        past_transformed_states = model_output[2]
+        input_ids = batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids
+
+        logits, past_input_ids = self.model(input_ids)[:2]

         # Results
         generations: List[Generation] = []
@@ -176,9 +169,6 @@ class Mamba(Model):
                 all_input_ids.view(1, -1), logits[-1:, :]
             )

-            # add next token to past_input_ids
-            past_input_ids = torch.cat([past_input_ids, next_token_id], dim=1)
-
             # Append next token to all tokens
             all_input_ids = torch.cat([all_input_ids, next_token_id])
             new_input_length = input_length + 1
@@ -199,73 +189,94 @@ class Mamba(Model):
             if not stop:
                 stopped = False

-            if stop:
-                # Decode generated tokens
-                output_text, _, _ = self.decode_token(
-                    all_input_ids[:, 0],
-                    prefix_offset=len(all_input_ids)
-                    - stopping_criteria.current_tokens
-                    - 1,
-                    read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
-                    skip_special_tokens=True,
-                )
-                # Get seed
-                if isinstance(next_token_chooser.choice, Sampling):
-                    seed = next_token_chooser.choice.seed
-                else:
-                    seed = None
-
-                generated_text = GeneratedText(
-                    output_text, stopping_criteria.current_tokens, reason, seed
-                )
-            else:
-                generated_text = None
-
-            # Prefill
-            if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
-                # Remove generated token to only have prefill and add nan for first prompt token
-                prefill_logprobs = [float("nan")] + torch.log_softmax(
-                    logits, -1
-                ).gather(1, all_input_ids[1:]).squeeze(1)[-new_input_length:-1].tolist()
-                prefill_token_ids = all_input_ids[-new_input_length:-1]
-                prefill_texts = self.tokenizer.batch_decode(
-                    prefill_token_ids,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
-                prefill_tokens = Tokens(
-                    prefill_token_ids,
-                    prefill_logprobs,
-                    prefill_texts,
-                    is_special=[],
-                )
-            else:
-                prefill_tokens = None
-
-            generation = Generation(
-                batch.batch_id,
-                None,
-                Tokens(
-                    [next_token_id_squeezed],
-                    [next_token_logprob],
-                    [next_token_text],
-                    [next_token_id_squeezed.item() in self.all_special_ids],
-                ),
-                generated_text,
-                None,
-            )
-
-            generations.append(generation)
-            next_token_tensor = next_token_id_squeezed.view(1, 1)
-            # Update values
-            batch.input_ids = torch.cat(
-                [batch.input_ids, next_token_tensor], dim=1
-            )
-            batch.all_input_ids[i] = all_input_ids
-            batch.input_lengths[i] = new_input_length
-            batch.prefix_offsets[i] = prefix_offset
-            batch.read_offsets[i] = read_offset
-            batch.max_input_length = max(batch.max_input_length, new_input_length)
+            # Shard generations
+            # All generations will be appended in the rust sharded client
+            if i % self.world_size == self.rank:
+                if stop:
+                    # Decode generated tokens
+                    output_text, _, _ = self.decode_token(
+                        all_input_ids[:, 0],
+                        prefix_offset=len(all_input_ids)
+                        - stopping_criteria.current_tokens
+                        - 1,
+                        read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
+                        skip_special_tokens=True,
+                    )
+                    # Get seed
+                    if isinstance(next_token_chooser.choice, Sampling):
+                        seed = next_token_chooser.choice.seed
+                    else:
+                        seed = None
+
+                    generated_text = GeneratedText(
+                        output_text, stopping_criteria.current_tokens, reason, seed
+                    )
+                else:
+                    generated_text = None
+
+                if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
+                    # Remove generated token to only have prefill and add nan for first prompt token
+                    prefill_logprobs = [float("nan")] + torch.log_softmax(
+                        logits, -1
+                    ).gather(1, all_input_ids[1:]).squeeze(1)[
+                        -new_input_length:-1
+                    ].tolist()
+                    prefill_token_ids = all_input_ids[-new_input_length:-1]
+                    prefill_texts = self.tokenizer.batch_decode(
+                        prefill_token_ids,
+                        clean_up_tokenization_spaces=False,
+                        skip_special_tokens=False,
+                    )
+                    prefill_tokens = Tokens(
+                        prefill_token_ids,
+                        prefill_logprobs,
+                        prefill_texts,
+                        is_special=[],
+                    )
+                else:
+                    prefill_tokens = None
+                past_input_ids = torch.cat([past_input_ids, next_token_id], dim=1)
+
+                if top_n_tokens > 0:
+                    toptoken_texts = self.tokenizer.batch_decode(
+                        top_token_ids,
+                        clean_up_tokenization_spaces=False,
+                        skip_special_tokens=False,
+                    )
+                    special_toptokens = [
+                        token_id in self.all_special_ids for token_id in top_token_ids
+                    ]
+                    top_tokens = Tokens(
+                        top_token_ids,
+                        top_token_logprobs,
+                        toptoken_texts,
+                        special_toptokens,
+                    )
+                else:
+                    top_tokens = None
+
+                generation = Generation(
+                    batch.batch_id,
+                    prefill_tokens,
+                    Tokens(
+                        [next_token_id_squeezed],
+                        [next_token_logprob],
+                        [next_token_text],
+                        [next_token_id_squeezed.item() in self.all_special_ids],
+                    ),
+                    generated_text,
+                    top_tokens,
+                )
+
+                generations.append(generation)
+
+            # Update values
+            batch.all_input_ids[i] = all_input_ids
+            batch.input_lengths[i] = new_input_length
+            batch.prefix_offsets[i] = prefix_offset
+            batch.read_offsets[i] = read_offset
+            batch.max_input_length = max(batch.max_input_length, new_input_length)

         # We finished all generations in the batch; there is no next batch
         if stopped:
@@ -273,10 +284,7 @@ class Mamba(Model):
             decode_ns = time.time_ns() - start_decode
             return generations, None, (forward_ns, decode_ns)

-        # Slice unused values from prefill
-        batch.input_ids = batch.input_ids[:, :1]
         batch.past_input_ids = past_input_ids
-        batch.past_transformed_states = past_transformed_states

         forward_ns = start_decode - start
         decode_ns = time.time_ns() - start_decode
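Taken together, the server-side changes mean the Mamba path does not yet cache any SSM or convolution state between steps: each generate_token call re-runs the model over the full token history carried in `past_input_ids`, and the freshly sampled token is appended for the next call. A stripped-down sketch of that loop, with assumed names and shapes rather than the server's actual batching code:

import torch

def greedy_decode(model, input_ids: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
    # input_ids: (1, prompt_len); `model` is assumed to return (logits, input_ids)
    # like MambaModel.forward above, with logits of shape (1, seq_len, vocab).
    past_input_ids = input_ids
    for _ in range(max_new_tokens):
        logits, past_input_ids = model(past_input_ids)[:2]
        next_token_id = logits[:, -1:, :].argmax(dim=-1)  # (1, 1), greedy pick
        # no cached state yet, so the whole sequence is fed back in next time
        past_input_ids = torch.cat([past_input_ids, next_token_id], dim=1)
    return past_input_ids

Caching the per-layer state between steps would avoid this repeated recomputation; the warmup TODO and the commented-out load test above point at that follow-up work.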