feat: add optimization and first pass of integration test

drbh 2024-01-30 18:53:28 +00:00
parent 966f3ba35c
commit 5b6f9259c1
5 changed files with 372 additions and 192 deletions


@@ -0,0 +1,84 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 5089,
"logprob": null,
"text": "Test"
},
{
"id": 2748,
"logprob": -9.7421875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 187,
"logprob": -2.4824219,
"special": false,
"text": "\n"
},
{
"id": 187,
"logprob": -2.4824219,
"special": false,
"text": "\n"
},
{
"id": 50274,
"logprob": -1.7880859,
"special": false,
"text": " "
},
{
"id": 92,
"logprob": -2.0703125,
"special": false,
"text": "{"
},
{
"id": 187,
"logprob": -0.04827881,
"special": false,
"text": "\n"
},
{
"id": 50270,
"logprob": -0.18896484,
"special": false,
"text": " "
},
{
"id": 3,
"logprob": -1.5234375,
"special": false,
"text": "\""
},
{
"id": 9629,
"logprob": -2.8203125,
"special": false,
"text": "request"
},
{
"id": 1381,
"logprob": -0.78759766,
"special": false,
"text": "\":"
},
{
"id": 551,
"logprob": -0.49169922,
"special": false,
"text": " {"
}
],
"top_tokens": null
},
"generated_text": "\n\n {\n \"request\": {"
}
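
The snapshot above records a default greedy run (`seed` is null, no sampling parameters): `prefill` holds the prompt tokens, with `logprob` null on the first token since nothing precedes it, and `tokens` holds the ten generated tokens. The logprobs are natural-log probabilities, so they can be sanity-checked by exponentiating; a minimal sketch, assuming the snapshot is saved as `test_fused_kernel_mamba.json` (path hypothetical):

import json, math

# Recompute probabilities from the snapshot's natural-log logprobs.
with open("test_fused_kernel_mamba.json") as f:  # hypothetical path
    snap = json.load(f)

for tok in snap["details"]["tokens"]:
    print(tok["id"], repr(tok["text"]), round(math.exp(tok["logprob"]), 3))
# e.g. the first "\n" (id 187, logprob -2.4824219) comes out near p = 0.084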


@@ -0,0 +1,99 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2502,
"logprob": null,
"text": " red"
},
{
"id": 13,
"logprob": -2.5234375,
"text": ","
},
{
"id": 8862,
"logprob": -3.4746094,
"text": " yellow"
},
{
"id": 13,
"logprob": -0.43579102,
"text": ","
},
{
"id": 209,
"logprob": -8.2421875,
"text": " "
}
],
"seed": 0,
"tokens": [
{
"id": 187,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 2764,
"logprob": -0.37573242,
"special": false,
"text": "umber"
},
{
"id": 285,
"logprob": 0.0,
"special": false,
"text": " and"
},
{
"id": 3168,
"logprob": -0.9013672,
"special": false,
"text": " white"
},
{
"id": 28,
"logprob": -1.2314453,
"special": false,
"text": ";"
},
{
"id": 253,
"logprob": 0.0,
"special": false,
"text": " the"
},
{
"id": 3295,
"logprob": -1.2167969,
"special": false,
"text": " color"
},
{
"id": 273,
"logprob": 0.0,
"special": false,
"text": " of"
},
{
"id": 697,
"logprob": -2.1015625,
"special": false,
"text": " its"
},
{
"id": 17433,
"logprob": -2.4296875,
"special": false,
"text": " unders"
}
],
"top_tokens": null
},
"generated_text": "blue, red, yellow, \number and white; the color of its unders"
}
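
Unlike the first snapshot, this one is a sampled run (`seed` is 0, with the sampling parameters set in the test below), so several tokens report `logprob: 0.0`, i.e. probability 1 under the warped distribution. Only the last five prompt tokens appear in `prefill` because the test passes `truncate=5`, while `generated_text` still starts with the full prompt because `return_full_text` is set; the odd-looking `\number` is just the generated `"\n"` token followed by `"umber"`. The generated tail can be rebuilt from the token list; a small sketch (path hypothetical):

import json

with open("test_fused_kernel_mamba_all_params.json") as f:  # hypothetical path
    snap = json.load(f)

tail = "".join(t["text"] for t in snap["details"]["tokens"])
assert tail == "\number and white; the color of its unders"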


@@ -0,0 +1,62 @@
import pytest
@pytest.fixture(scope="module")
def fused_kernel_mamba_handle(launcher):
with launcher("state-spaces/mamba-130m", num_shard=1) as handle:
yield handle
@pytest.fixture(scope="module")
async def fused_kernel_mamba(fused_kernel_mamba_handle):
await fused_kernel_mamba_handle.health(300)
return fused_kernel_mamba_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_fused_kernel_mamba(fused_kernel_mamba, response_snapshot):
response = await fused_kernel_mamba.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_fused_kernel_mamba_all_params(fused_kernel_mamba, response_snapshot):
response = await fused_kernel_mamba.generate(
"blue, red, yellow, ",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["test"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 10
# TODO: fix so the input is not included in the output
assert response.generated_text == "blue, red, yellow, \number and white; the color of its unders"
assert response == response_snapshot
# TODO: fix `Expected x0.dim() == 2 to be true, but got false.`
# 94: `hidden_states, _ = self.layer_norm(hidden_states.squeeze(0))`
# NOTE: the fast layer norm has strict requirements on the input shape
# @pytest.mark.asyncio
# @pytest.mark.private
# async def test_fused_kernel_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
# responses = await generate_load(fused_kernel_mamba, "Test request", max_new_tokens=10, n=4)
# assert len(responses) == 4
# assert all([r.generated_text == responses[0].generated_text for r in responses])
# assert responses == response_snapshot
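
The load test is disabled because the fast RMS norm kernel used in `ResidualBlock` only accepts 2-D input, which is why the model squeezes the batch dimension before the norm and unsqueezes it afterwards. A self-contained toy (not the real `FastRMSNorm`) that mimics the constraint behind the error quoted above:

import torch

def rms_norm_2d(x: torch.Tensor, scale: torch.Tensor, eps: float = 1e-5):
    # Mimic the fused kernel's contract: strictly (seq_len, hidden) input.
    assert x.dim() == 2, "Expected x0.dim() == 2 to be true, but got false."
    rms = torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True) + eps)
    return scale * x / rms

h = torch.randn(1, 8, 16)  # (batch=1, seq_len, hidden)
out = rms_norm_2d(h.squeeze(0), torch.ones(16)).unsqueeze(0)  # ok
# rms_norm_2d(h, torch.ones(16)) would trip the assertion, like the load test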


@@ -5,12 +5,12 @@ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
from torch import nn
from typing import Optional, List, Tuple, Any
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
import torch.nn.functional as F
from text_generation_server.utils.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
FastRMSNorm,
)
class MambaConfig(PretrainedConfig):
@@ -41,151 +41,78 @@ class MambaConfig(PretrainedConfig):
**kwargs,
)
class MambaBlock(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.dt_rank = (config.d_model + 15) // 16
self.x_proj_weight = weights.get_tensor(f"{prefix}.x_proj.weight")
self.dt_proj_weight = weights.get_tensor(f"{prefix}.dt_proj.weight")
self.dt_proj_bias = weights.get_tensor(f"{prefix}.dt_proj.bias")
self.out_proj_weight = weights.get_tensor(f"{prefix}.out_proj.weight")
self.out_proj_bias = None
# TODO: avoid loading the same weights twice
self.in_proj_weight = weights.get_tensor(f"{prefix}.in_proj.weight")
self.in_proj_bias = None
self.in_proj = TensorParallelColumnLinear.load(
config=config,
prefix=f"{prefix}.in_proj",
weights=weights,
bias=False,
config=config, prefix=f"{prefix}.in_proj", weights=weights, bias=False
)
self.conv1d = nn.Conv1d(
config.d_inner,
config.d_inner,
kernel_size=config.d_conv,
groups=config.d_inner,
padding=config.d_conv - 1,
)
self.conv1d.weight = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.weight"))
self.conv1d.bias = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.bias"))
self.A_log = nn.Parameter(weights.get_tensor(f"{prefix}.A_log"))
self.D = nn.Parameter(weights.get_tensor(f"{prefix}.D"))
# helper for loading weights
self.load_weights(prefix, weights)
def forward(self, index, hidden_states, past_transformed_state):
projected_states = self.in_proj(hidden_states)
A = -torch.exp(self.A_log.float())
def load_weights(self, prefix, weights):
weight_names = ["x_proj.weight", "dt_proj.weight", "dt_proj.bias",
"out_proj.weight", "in_proj.weight",
"conv1d.weight", "conv1d.bias", "A_log", "D"]
for name in weight_names:
param_name = name.replace('.', '_')
setattr(self, param_name, nn.Parameter(weights.get_tensor(f"{prefix}.{name}")))
self.out_proj_bias = None
self.negA = -torch.exp(self.A_log.float())
def forward(self, hidden_states: torch.Tensor):
projected_states = self.in_proj(hidden_states).transpose(1,2)
# conv1d, ssm, and selective_scan are all fused into one kernel
attn_outputs = mamba_inner_fn(
projected_states.transpose(1,2),
self.conv1d.weight,
self.conv1d.bias,
projected_states,
self.conv1d_weight,
self.conv1d_bias,
self.x_proj_weight,
self.dt_proj_weight,
self.out_proj_weight,
self.out_proj_bias,
A,
self.negA,
None,
None,
self.D.float(),
delta_bias=self.dt_proj_bias.float(),
delta_softplus=True,
)
return attn_outputs, projected_states
# TODO: prefer a more optimized implementation of RMSNorm if possible
class RMSNorm(nn.Module):
def __init__(self, num_features, eps=1e-8):
super().__init__()
self.num_features = num_features
self.eps = eps
self.scale = nn.Parameter(torch.ones(num_features))
def forward(self, x):
rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
x = x / rms
return self.scale * x
return attn_outputs
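
For reference, `mamba_inner_fn` fuses the causal conv1d, the input-dependent projections, and the selective scan into one kernel. Per channel, the scan computes h_t = exp(Δ_t A) · h_{t-1} + Δ_t B_t x_t and y_t = C_t h_t + D x_t. A slow pure-PyTorch sketch of just the scan step, with shapes assumed and the conv, `z` gating, and `delta_bias`/`delta_softplus` preprocessing omitted:

import torch

def selective_scan_ref(u, delta, A, B, C, D):
    # u, delta: (batch, d_inner, seq)   A: (d_inner, d_state)
    # B, C: (batch, d_state, seq)       D: (d_inner,)
    bsz, d, seqlen = u.shape
    h = u.new_zeros(bsz, d, A.shape[1])
    ys = []
    for t in range(seqlen):
        dt = delta[:, :, t, None]                              # (b, d, 1)
        h = torch.exp(dt * A) * h + dt * B[:, None, :, t] * u[:, :, t, None]
        ys.append((h * C[:, None, :, t]).sum(-1))              # (b, d)
    return torch.stack(ys, dim=-1) + D[None, :, None] * u      # (b, d, seq)
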
class ResidualBlock(nn.Module):
def __init__(self, layer_id, config, weights):
super().__init__()
self.layer_id = layer_id
self.mamba_block = MambaBlock(
prefix=f"{layer_id}.mixer", config=config, weights=weights
)
self.layer_norm = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.layer_norm.scale = nn.Parameter(
weights.get_tensor(f"{layer_id}.norm.weight")
)
self.mamba_block = MambaBlock(prefix=f"{layer_id}.mixer", config=config, weights=weights)
self.layer_norm = FastRMSNorm.load(prefix=f"{layer_id}.norm", weights=weights, eps=config.layer_norm_epsilon)
def forward(
self,
index,
hidden_states,
past_transformed_state,
hidden_states: torch.Tensor,
):
residual = hidden_states
hidden_states = self.layer_norm(hidden_states)
attn_outputs, transformed_states = self.mamba_block(
index, hidden_states, past_transformed_state
)
hidden_states = residual + attn_outputs
return hidden_states, transformed_states
hidden_states, _ = self.layer_norm(hidden_states.squeeze(0))
hidden_states = residual + self.mamba_block(hidden_states.unsqueeze(0))
return hidden_states
class MambaModel(nn.Module):
def __init__(self, config, weights):
super().__init__()
self.tp_rank = weights.process_group.rank()
self.tp_world_size = weights.process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix="backbone.embedding", weights=weights
)
prefix = "backbone"
self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights)
self.blocks = nn.ModuleList(
[
ResidualBlock(f"backbone.layers.{layer_id}", config, weights)
for layer_id in range(config.n_layer)
]
[ResidualBlock(f"{prefix}.layers.{i}", config, weights) for i in range(config.n_layer)]
)
self.norm_f = FastRMSNorm.load(f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon)
self.lm_head = TensorParallelColumnLinear.load(config, f"{prefix}.embedding", weights, False)
# TODO: avoid hardcoded sizes and improve how we load the weights
self.norm_f = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.norm_f.scale = nn.Parameter(weights.get_tensor(f"backbone.norm_f.weight"))
# use the same weights for the embedding and the final layer norm
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
self.lm_head.weight = nn.Parameter(
self.embed_tokens.weight[: config.vocab_size, :]
)
def forward(
self,
input_ids: torch.LongTensor,
past_input_ids: Optional[List[Tuple[torch.FloatTensor]]] = None,
past_transformed_states: Optional[List[Tuple[torch.FloatTensor]]] = None,
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
# NOTE: we need all input_ids to compute the correct embeddings
if past_input_ids is not None:
input_ids = past_input_ids
def forward(self, input_ids: torch.Tensor):
hidden_states = self.embed_tokens(input_ids)
for block in self.blocks:
hidden_states = block(hidden_states)
past_transformed_states = (
[None] * len(self.blocks)
if past_transformed_states is None
else past_transformed_states
)
for index, block in enumerate(self.blocks):
hidden_states, transformed_states = block(
index, hidden_states, past_transformed_states[index]
)
past_transformed_states[index] = transformed_states
final_hidden_states = self.norm_f(hidden_states)
after_lm_head = self.lm_head(final_hidden_states)
return after_lm_head, input_ids, past_transformed_states
final_hidden_states, _ = self.norm_f(hidden_states.squeeze(0))
return self.lm_head(final_hidden_states.unsqueeze(0)), input_ids
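
Note the weight tying at the end of `MambaModel.__init__`: the old code sliced `embed_tokens.weight` into a plain `nn.Linear`, and the new code loads `lm_head` from the same `backbone.embedding` tensor, so the input embedding and the output projection share one weight matrix. A generic illustration of the idea (toy sizes, not the TGI loader):

import torch
from torch import nn

emb = nn.Embedding(1000, 64)               # (vocab, hidden)
lm_head = nn.Linear(64, 1000, bias=False)
lm_head.weight = emb.weight                # tie: one tensor, two roles

ids = torch.tensor([[5, 7, 9]])
logits = lm_head(emb(ids))                 # (1, 3, 1000)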


@@ -30,12 +30,11 @@ from text_generation_server.utils.tokens import batch_top_tokens, Sampling
class MambaCausalLMBatch(CausalLMBatch):
past_transformed_states: Optional[List[torch.Tensor]]
past_input_ids: Optional[torch.Tensor]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.past_input_ids = None
self.past_transformed_states = None
@classmethod
def from_pb(
@@ -103,6 +102,10 @@
def batch_type(self) -> Type[CausalLMBatch]:
return MambaCausalLMBatch
def warmup(self, batch) -> Optional[int]:
# TODO: implement warmup for Mamba if needed
return None
def forward(
self,
input_ids: torch.Tensor,
@@ -116,19 +119,9 @@
def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
start = time.time_ns()
input_ids = batch.input_ids
past_input_ids = batch.past_input_ids
past_transformed_states = batch.past_transformed_states
input_ids = batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids
model_output = self.model(
input_ids,
past_input_ids,
past_transformed_states,
)
logits = model_output[0]
past_input_ids = model_output[1]
past_transformed_states = model_output[2]
logits, past_input_ids = self.model(input_ids)[:2]
# Results
generations: List[Generation] = []
@@ -176,9 +169,6 @@
all_input_ids.view(1, -1), logits[-1:, :]
)
# add next token to past_input_ids
past_input_ids = torch.cat([past_input_ids, next_token_id], dim=1)
# Append next token to all tokens
all_input_ids = torch.cat([all_input_ids, next_token_id])
new_input_length = input_length + 1
@@ -199,6 +189,9 @@
if not stop:
stopped = False
# Shard generations
# All generations will be appended in the rust sharded client
if i % self.world_size == self.rank:
if stop:
# Decode generated tokens
output_text, _, _ = self.decode_token(
@@ -221,12 +214,13 @@
else:
generated_text = None
# Prefill
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
# Remove generated token to only have prefill and add nan for first prompt token
prefill_logprobs = [float("nan")] + torch.log_softmax(
logits, -1
).gather(1, all_input_ids[1:]).squeeze(1)[-new_input_length:-1].tolist()
).gather(1, all_input_ids[1:]).squeeze(1)[
-new_input_length:-1
].tolist()
prefill_token_ids = all_input_ids[-new_input_length:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
@@ -241,10 +235,30 @@
)
else:
prefill_tokens = None
past_input_ids = torch.cat([past_input_ids, next_token_id], dim=1)
if top_n_tokens > 0:
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = Tokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
special_toptokens,
)
else:
top_tokens = None
generation = Generation(
batch.batch_id,
None,
prefill_tokens,
Tokens(
[next_token_id_squeezed],
[next_token_logprob],
@@ -252,15 +266,12 @@
[next_token_id_squeezed.item() in self.all_special_ids],
),
generated_text,
None,
top_tokens,
)
generations.append(generation)
next_token_tensor = next_token_id_squeezed.view(1, 1)
# Update values
batch.input_ids = torch.cat(
[batch.input_ids, next_token_tensor], dim=1
)
batch.all_input_ids[i] = all_input_ids
batch.input_lengths[i] = new_input_length
batch.prefix_offsets[i] = prefix_offset
@@ -273,10 +284,7 @@
decode_ns = time.time_ns() - start_decode
return generations, None, (forward_ns, decode_ns)
# Slice unused values from prefill
batch.input_ids = batch.input_ids[:, :1]
batch.past_input_ids = past_input_ids
batch.past_transformed_states = past_transformed_states
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
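
Since this first pass keeps no recurrent state between steps, `past_input_ids` simply accumulates the whole sequence and every `generate_token` call re-runs the model over it. Stripped of batching and sampling, the decode path is equivalent to a full-recompute greedy loop; a simplified sketch with `model`, `input_ids`, and `max_new_tokens` assumed:

import torch

ids = input_ids                                    # (1, prompt_len)
for _ in range(max_new_tokens):
    logits, ids = model(ids)                       # recompute over all tokens
    next_id = logits[:, -1, :].argmax(-1, keepdim=True)
    ids = torch.cat([ids, next_id], dim=1)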