From 5b6f9259c185b96a9e69844208768a76e59bf20b Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 30 Jan 2024 18:53:28 +0000 Subject: [PATCH] feat: add optimization and first pass of integration test --- .../test_fused_kernel_mamba.json | 84 +++++++++ .../test_fused_kernel_mamba_all_params.json | 99 ++++++++++ .../models/test_fused_kernel_mamba.py | 62 +++++++ .../models/custom_modeling/mamba_modeling.py | 147 ++++----------- server/text_generation_server/models/mamba.py | 172 +++++++++--------- 5 files changed, 372 insertions(+), 192 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_fused_kernel_mamba/test_fused_kernel_mamba.json create mode 100644 integration-tests/models/__snapshots__/test_fused_kernel_mamba/test_fused_kernel_mamba_all_params.json create mode 100644 integration-tests/models/test_fused_kernel_mamba.py diff --git a/integration-tests/models/__snapshots__/test_fused_kernel_mamba/test_fused_kernel_mamba.json b/integration-tests/models/__snapshots__/test_fused_kernel_mamba/test_fused_kernel_mamba.json new file mode 100644 index 00000000..ae6ee35e --- /dev/null +++ b/integration-tests/models/__snapshots__/test_fused_kernel_mamba/test_fused_kernel_mamba.json @@ -0,0 +1,84 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 5089, + "logprob": null, + "text": "Test" + }, + { + "id": 2748, + "logprob": -9.7421875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 187, + "logprob": -2.4824219, + "special": false, + "text": "\n" + }, + { + "id": 187, + "logprob": -2.4824219, + "special": false, + "text": "\n" + }, + { + "id": 50274, + "logprob": -1.7880859, + "special": false, + "text": " " + }, + { + "id": 92, + "logprob": -2.0703125, + "special": false, + "text": "{" + }, + { + "id": 187, + "logprob": -0.04827881, + "special": false, + "text": "\n" + }, + { + "id": 50270, + "logprob": -0.18896484, + "special": false, + "text": " " + }, + { + "id": 3, + "logprob": -1.5234375, + "special": false, + "text": "\"" + }, + { + "id": 9629, + "logprob": -2.8203125, + "special": false, + "text": "request" + }, + { + "id": 1381, + "logprob": -0.78759766, + "special": false, + "text": "\":" + }, + { + "id": 551, + "logprob": -0.49169922, + "special": false, + "text": " {" + } + ], + "top_tokens": null + }, + "generated_text": "\n\n {\n \"request\": {" +} diff --git a/integration-tests/models/__snapshots__/test_fused_kernel_mamba/test_fused_kernel_mamba_all_params.json b/integration-tests/models/__snapshots__/test_fused_kernel_mamba/test_fused_kernel_mamba_all_params.json new file mode 100644 index 00000000..0ab7cf11 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_fused_kernel_mamba/test_fused_kernel_mamba_all_params.json @@ -0,0 +1,99 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2502, + "logprob": null, + "text": " red" + }, + { + "id": 13, + "logprob": -2.5234375, + "text": "," + }, + { + "id": 8862, + "logprob": -3.4746094, + "text": " yellow" + }, + { + "id": 13, + "logprob": -0.43579102, + "text": "," + }, + { + "id": 209, + "logprob": -8.2421875, + "text": " " + } + ], + "seed": 0, + "tokens": [ + { + "id": 187, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 2764, + "logprob": -0.37573242, + "special": false, + "text": "umber" + }, + { + "id": 285, + "logprob": 0.0, + "special": false, + "text": " and" + }, + { + "id": 3168, + "logprob": 
-0.9013672, + "special": false, + "text": " white" + }, + { + "id": 28, + "logprob": -1.2314453, + "special": false, + "text": ";" + }, + { + "id": 253, + "logprob": 0.0, + "special": false, + "text": " the" + }, + { + "id": 3295, + "logprob": -1.2167969, + "special": false, + "text": " color" + }, + { + "id": 273, + "logprob": 0.0, + "special": false, + "text": " of" + }, + { + "id": 697, + "logprob": -2.1015625, + "special": false, + "text": " its" + }, + { + "id": 17433, + "logprob": -2.4296875, + "special": false, + "text": " unders" + } + ], + "top_tokens": null + }, + "generated_text": "blue, red, yellow, \number and white; the color of its unders" +} diff --git a/integration-tests/models/test_fused_kernel_mamba.py b/integration-tests/models/test_fused_kernel_mamba.py new file mode 100644 index 00000000..0a449332 --- /dev/null +++ b/integration-tests/models/test_fused_kernel_mamba.py @@ -0,0 +1,62 @@ +import pytest + + +@pytest.fixture(scope="module") +def fused_kernel_mamba_handle(launcher): + with launcher("state-spaces/mamba-130m", num_shard=1) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def fused_kernel_mamba(fused_kernel_mamba_handle): + await fused_kernel_mamba_handle.health(300) + return fused_kernel_mamba_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_fused_kernel_mamba(fused_kernel_mamba, response_snapshot): + response = await fused_kernel_mamba.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_fused_kernel_mamba_all_params(fused_kernel_mamba, response_snapshot): + response = await fused_kernel_mamba.generate( + "blue, red, yellow, ", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + # TODO: fix so the input is not included in the output + assert response.generated_text == "blue, red, yellow, \number and white; the color of its unders" + assert response == response_snapshot + +# TODO: fix `Expected x0.dim() == 2 to be true, but got false.` +# 94: `hidden_states, _ = self.layer_norm(hidden_states.squeeze(0))` +# NOTE: the fast layer norm has strict requirements on the input shape +# @pytest.mark.asyncio +# @pytest.mark.private +# async def test_fused_kernel_mamba_load(fused_kernel_mamba, generate_load, response_snapshot): +# responses = await generate_load(fused_kernel_mamba, "Test request", max_new_tokens=10, n=4) + +# assert len(responses) == 4 +# assert all([r.generated_text == responses[0].generated_text for r in responses]) + +# assert responses == response_snapshot diff --git a/server/text_generation_server/models/custom_modeling/mamba_modeling.py b/server/text_generation_server/models/custom_modeling/mamba_modeling.py index 8df5f4df..ca5f9765 100644 --- a/server/text_generation_server/models/custom_modeling/mamba_modeling.py +++ b/server/text_generation_server/models/custom_modeling/mamba_modeling.py @@ -5,12 +5,12 @@ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inne from torch import nn from typing import Optional, List, Tuple, Any from transformers.configuration_utils import PretrainedConfig -from transformers.modeling_outputs import CausalLMOutputWithPast import 
torch.nn.functional as F from text_generation_server.utils.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, + FastRMSNorm, ) class MambaConfig(PretrainedConfig): @@ -41,151 +41,78 @@ class MambaConfig(PretrainedConfig): **kwargs, ) - class MambaBlock(nn.Module): def __init__(self, prefix, config, weights): super().__init__() - self.dt_rank = (config.d_model + 15) // 16 - self.x_proj_weight = weights.get_tensor(f"{prefix}.x_proj.weight") - self.dt_proj_weight = weights.get_tensor(f"{prefix}.dt_proj.weight") - self.dt_proj_bias = weights.get_tensor(f"{prefix}.dt_proj.bias") - self.out_proj_weight = weights.get_tensor(f"{prefix}.out_proj.weight") - self.out_proj_bias = None - # TODO: avoid loading the same weights twice - self.in_proj_weight = weights.get_tensor(f"{prefix}.in_proj.weight") - self.in_proj_bias = None self.in_proj = TensorParallelColumnLinear.load( - config=config, - prefix=f"{prefix}.in_proj", - weights=weights, - bias=False, + config=config, prefix=f"{prefix}.in_proj", weights=weights, bias=False ) - self.conv1d = nn.Conv1d( - config.d_inner, - config.d_inner, - kernel_size=config.d_conv, - groups=config.d_inner, - padding=config.d_conv - 1, - ) - self.conv1d.weight = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.weight")) - self.conv1d.bias = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.bias")) - self.A_log = nn.Parameter(weights.get_tensor(f"{prefix}.A_log")) - self.D = nn.Parameter(weights.get_tensor(f"{prefix}.D")) + # helper for loading weights + self.load_weights(prefix, weights) - def forward(self, index, hidden_states, past_transformed_state): - projected_states = self.in_proj(hidden_states) - - A = -torch.exp(self.A_log.float()) + def load_weights(self, prefix, weights): + weight_names = ["x_proj.weight", "dt_proj.weight", "dt_proj.bias", + "out_proj.weight", "in_proj.weight", + "conv1d.weight", "conv1d.bias", "A_log", "D"] + for name in weight_names: + param_name = name.replace('.', '_') + setattr(self, param_name, nn.Parameter(weights.get_tensor(f"{prefix}.{name}"))) + self.out_proj_bias = None + self.negA = -torch.exp(self.A_log.float()) + def forward(self, hidden_states: torch.Tensor): + projected_states = self.in_proj(hidden_states).transpose(1,2) # conv1d, ssm, and selective_scan are all fused into one kernel attn_outputs = mamba_inner_fn( - projected_states.transpose(1,2), - self.conv1d.weight, - self.conv1d.bias, + projected_states, + self.conv1d_weight, + self.conv1d_bias, self.x_proj_weight, self.dt_proj_weight, self.out_proj_weight, self.out_proj_bias, - A, + self.negA, None, None, self.D.float(), delta_bias=self.dt_proj_bias.float(), delta_softplus=True, ) - - return attn_outputs, projected_states - - -# TODO: prefer a more optimized implementation of RMSNorm if possible -class RMSNorm(nn.Module): - def __init__(self, num_features, eps=1e-8): - super().__init__() - self.num_features = num_features - self.eps = eps - self.scale = nn.Parameter(torch.ones(num_features)) - - def forward(self, x): - rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps) - x = x / rms - return self.scale * x - + return attn_outputs class ResidualBlock(nn.Module): def __init__(self, layer_id, config, weights): super().__init__() - self.layer_id = layer_id - self.mamba_block = MambaBlock( - prefix=f"{layer_id}.mixer", config=config, weights=weights - ) - self.layer_norm = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) - self.layer_norm.scale = nn.Parameter( - weights.get_tensor(f"{layer_id}.norm.weight") - ) + self.mamba_block = 
MambaBlock(prefix=f"{layer_id}.mixer", config=config, weights=weights) + self.layer_norm = FastRMSNorm.load(prefix=f"{layer_id}.norm", weights=weights, eps=config.layer_norm_epsilon) def forward( self, - index, - hidden_states, - past_transformed_state, + hidden_states: torch.Tensor, ): residual = hidden_states - hidden_states = self.layer_norm(hidden_states) - attn_outputs, transformed_states = self.mamba_block( - index, hidden_states, past_transformed_state - ) - hidden_states = residual + attn_outputs - return hidden_states, transformed_states - + hidden_states, _ = self.layer_norm(hidden_states.squeeze(0)) + hidden_states = residual + self.mamba_block(hidden_states.unsqueeze(0)) + return hidden_states class MambaModel(nn.Module): def __init__(self, config, weights): super().__init__() self.tp_rank = weights.process_group.rank() self.tp_world_size = weights.process_group.size() - self.embed_tokens = TensorParallelEmbedding( - prefix="backbone.embedding", weights=weights - ) + prefix = "backbone" + + self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights) self.blocks = nn.ModuleList( - [ - ResidualBlock(f"backbone.layers.{layer_id}", config, weights) - for layer_id in range(config.n_layer) - ] + [ResidualBlock(f"{prefix}.layers.{i}", config, weights) for i in range(config.n_layer)] ) + self.norm_f = FastRMSNorm.load(f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon) + self.lm_head = TensorParallelColumnLinear.load(config, f"{prefix}.embedding", weights, False) - # TODO: avoid hardcoded sizes and improve how we load the weights - self.norm_f = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) - self.norm_f.scale = nn.Parameter(weights.get_tensor(f"backbone.norm_f.weight")) - # use the same weights for the embedding and the final layer norm - self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) - self.lm_head.weight = nn.Parameter( - self.embed_tokens.weight[: config.vocab_size, :] - ) - - def forward( - self, - input_ids: torch.LongTensor, - past_input_ids: Optional[List[Tuple[torch.FloatTensor]]] = None, - past_transformed_states: Optional[List[Tuple[torch.FloatTensor]]] = None, - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: - # NOTE: we need all input_ids to compute the correct embeddings - if past_input_ids is not None: - input_ids = past_input_ids - + def forward(self, input_ids: torch.Tensor): hidden_states = self.embed_tokens(input_ids) + for block in self.blocks: + hidden_states = block(hidden_states) - past_transformed_states = ( - [None] * len(self.blocks) - if past_transformed_states is None - else past_transformed_states - ) - - for index, block in enumerate(self.blocks): - hidden_states, transformed_states = block( - index, hidden_states, past_transformed_states[index] - ) - past_transformed_states[index] = transformed_states - - final_hidden_states = self.norm_f(hidden_states) - after_lm_head = self.lm_head(final_hidden_states) - return after_lm_head, input_ids, past_transformed_states + final_hidden_states, _ = self.norm_f(hidden_states.squeeze(0)) + return self.lm_head(final_hidden_states.unsqueeze(0)), input_ids diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index 8b890426..05a5b99e 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -30,12 +30,11 @@ from text_generation_server.utils.tokens import batch_top_tokens, Sampling class MambaCausalLMBatch(CausalLMBatch): - 
past_transformed_states: Optional[List[torch.Tensor]] + past_input_ids: Optional[torch.Tensor] def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.past_input_ids = None - self.past_transformed_states = None @classmethod def from_pb( @@ -103,6 +102,10 @@ class Mamba(Model): def batch_type(self) -> Type[CausalLMBatch]: return MambaCausalLMBatch + def warmup(self, batch) -> Optional[int]: + # TODO: implement warmup for Mamba if needed + return None + def forward( self, input_ids: torch.Tensor, @@ -116,19 +119,9 @@ class Mamba(Model): def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]: start = time.time_ns() - input_ids = batch.input_ids - past_input_ids = batch.past_input_ids - past_transformed_states = batch.past_transformed_states + input_ids = batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids - model_output = self.model( - input_ids, - past_input_ids, - past_transformed_states, - ) - - logits = model_output[0] - past_input_ids = model_output[1] - past_transformed_states = model_output[2] + logits, past_input_ids = self.model(input_ids)[:2] # Results generations: List[Generation] = [] @@ -176,9 +169,6 @@ class Mamba(Model): all_input_ids.view(1, -1), logits[-1:, :] ) - # add next token to past_input_ids - past_input_ids = torch.cat([past_input_ids, next_token_id], dim=1) - # Append next token to all tokens all_input_ids = torch.cat([all_input_ids, next_token_id]) new_input_length = input_length + 1 @@ -199,73 +189,94 @@ class Mamba(Model): if not stop: stopped = False - if stop: - # Decode generated tokens - output_text, _, _ = self.decode_token( - all_input_ids[:, 0], - prefix_offset=len(all_input_ids) - - stopping_criteria.current_tokens - - 1, - read_offset=len(all_input_ids) - stopping_criteria.current_tokens, - skip_special_tokens=True, - ) - # Get seed - if isinstance(next_token_chooser.choice, Sampling): - seed = next_token_chooser.choice.seed + # Shard generations + # All generations will be appended in the rust sharded client + if i % self.world_size == self.rank: + if stop: + # Decode generated tokens + output_text, _, _ = self.decode_token( + all_input_ids[:, 0], + prefix_offset=len(all_input_ids) + - stopping_criteria.current_tokens + - 1, + read_offset=len(all_input_ids) - stopping_criteria.current_tokens, + skip_special_tokens=True, + ) + # Get seed + if isinstance(next_token_chooser.choice, Sampling): + seed = next_token_chooser.choice.seed + else: + seed = None + + generated_text = GeneratedText( + output_text, stopping_criteria.current_tokens, reason, seed + ) else: - seed = None + generated_text = None - generated_text = GeneratedText( - output_text, stopping_criteria.current_tokens, reason, seed + if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: + # Remove generated token to only have prefill and add nan for first prompt token + prefill_logprobs = [float("nan")] + torch.log_softmax( + logits, -1 + ).gather(1, all_input_ids[1:]).squeeze(1)[ + -new_input_length:-1 + ].tolist() + prefill_token_ids = all_input_ids[-new_input_length:-1] + prefill_texts = self.tokenizer.batch_decode( + prefill_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + prefill_tokens = Tokens( + prefill_token_ids, + prefill_logprobs, + prefill_texts, + is_special=[], + ) + else: + prefill_tokens = None + past_input_ids = torch.cat([past_input_ids, next_token_id], dim=1) + + + if top_n_tokens > 0: + toptoken_texts = self.tokenizer.batch_decode( + top_token_ids, + 
clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + special_toptokens = [ + token_id in self.all_special_ids for token_id in top_token_ids + ] + top_tokens = Tokens( + top_token_ids, + top_token_logprobs, + toptoken_texts, + special_toptokens, + ) + else: + top_tokens = None + + generation = Generation( + batch.batch_id, + prefill_tokens, + Tokens( + [next_token_id_squeezed], + [next_token_logprob], + [next_token_text], + [next_token_id_squeezed.item() in self.all_special_ids], + ), + generated_text, + top_tokens, ) - else: - generated_text = None - # Prefill - if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: - # Remove generated token to only have prefill and add nan for first prompt token - prefill_logprobs = [float("nan")] + torch.log_softmax( - logits, -1 - ).gather(1, all_input_ids[1:]).squeeze(1)[-new_input_length:-1].tolist() - prefill_token_ids = all_input_ids[-new_input_length:-1] - prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - prefill_tokens = Tokens( - prefill_token_ids, - prefill_logprobs, - prefill_texts, - is_special=[], - ) - else: - prefill_tokens = None + generations.append(generation) - generation = Generation( - batch.batch_id, - None, - Tokens( - [next_token_id_squeezed], - [next_token_logprob], - [next_token_text], - [next_token_id_squeezed.item() in self.all_special_ids], - ), - generated_text, - None, - ) - - generations.append(generation) - next_token_tensor = next_token_id_squeezed.view(1, 1) - # Update values - batch.input_ids = torch.cat( - [batch.input_ids, next_token_tensor], dim=1 - ) - batch.all_input_ids[i] = all_input_ids - batch.input_lengths[i] = new_input_length - batch.prefix_offsets[i] = prefix_offset - batch.read_offsets[i] = read_offset - batch.max_input_length = max(batch.max_input_length, new_input_length) + # Update values + batch.all_input_ids[i] = all_input_ids + batch.input_lengths[i] = new_input_length + batch.prefix_offsets[i] = prefix_offset + batch.read_offsets[i] = read_offset + batch.max_input_length = max(batch.max_input_length, new_input_length) # We finished all generations in the batch; there is no next batch if stopped: @@ -273,10 +284,7 @@ class Mamba(Model): decode_ns = time.time_ns() - start_decode return generations, None, (forward_ns, decode_ns) - # Slice unused values from prefill - batch.input_ids = batch.input_ids[:, :1] batch.past_input_ids = past_input_ids - batch.past_transformed_states = past_transformed_states forward_ns = start_decode - start decode_ns = time.time_ns() - start_decode
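
Note on the FastRMSNorm shape requirement referenced in the skipped load test above: the fused norm expects a 2D (tokens, hidden) input, which is why ResidualBlock.forward squeezes the batch dimension before the norm and unsqueezes it again for the residual add. Below is a minimal, hedged sketch of that calling pattern. RMSNormSketch is a hypothetical pure-PyTorch stand-in, not the real FastRMSNorm from text_generation_server.utils.layers (which is backed by a fused kernel); the eps default and the exact contents of the returned tuple are assumptions, chosen only to mirror the `hidden_states, _ = self.layer_norm(...)` call site in the patch.

import torch
from torch import nn


class RMSNormSketch(nn.Module):
    """Hypothetical stand-in for FastRMSNorm, for illustration only."""

    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor):
        # Mirror the fused kernel's strict `x0.dim() == 2` requirement.
        assert x.dim() == 2, "expected a 2D (tokens, hidden) input"
        variance = x.pow(2).mean(-1, keepdim=True)
        normed = x * torch.rsqrt(variance + self.eps)
        # Return (normed, residual) so the `hidden_states, _ = norm(...)` unpacking works.
        return self.weight * normed, x


hidden_states = torch.randn(1, 4, 768)               # (batch=1, seq_len, d_model)
norm = RMSNormSketch(768)
normed, _ = norm(hidden_states.squeeze(0))           # drop the batch dim before the norm
hidden_states = hidden_states + normed.unsqueeze(0)  # restore it for the residual add

A likely reason the commented-out test_fused_kernel_mamba_load fails with "Expected x0.dim() == 2 to be true, but got false." is that squeeze(0) is a no-op once the batch dimension is larger than one, so the norm receives a 3D tensor; flattening to (batch * seq_len, d_model) before the norm and restoring the shape afterwards would preserve the 2D contract. That is only a sketch of one possible fix under the assumptions stated above.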