diff --git a/integration-tests/models/__snapshots__/test_mamba/test_mamba.json b/integration-tests/models/__snapshots__/test_mamba/test_mamba.json index 4435f215..eaba5078 100644 --- a/integration-tests/models/__snapshots__/test_mamba/test_mamba.json +++ b/integration-tests/models/__snapshots__/test_mamba/test_mamba.json @@ -8,61 +8,61 @@ "tokens": [ { "id": 187, - "logprob": -0.3552246, + "logprob": -0.37890625, "special": false, "text": "\n" }, { "id": 187, - "logprob": -0.38378906, + "logprob": -0.26953125, "special": false, "text": "\n" }, { "id": 30763, - "logprob": -1.140625, + "logprob": -1.1953125, "special": false, "text": "Deep" }, { "id": 4715, - "logprob": -0.5551758, + "logprob": -0.53515625, "special": false, "text": " learning" }, { "id": 310, - "logprob": -0.59033203, + "logprob": -0.625, "special": false, "text": " is" }, { "id": 247, - "logprob": -0.70654297, + "logprob": -0.6796875, "special": false, "text": " a" }, { "id": 747, - "logprob": -2.0410156, + "logprob": -2.0, "special": false, "text": " new" }, { "id": 1511, - "logprob": -2.3789062, + "logprob": -2.3125, "special": false, "text": " type" }, { "id": 273, - "logprob": -0.0026435852, + "logprob": -0.0028533936, "special": false, "text": " of" }, { "id": 5145, - "logprob": -1.2841797, + "logprob": -1.265625, "special": false, "text": " machine" } diff --git a/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json b/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json index 052c1c69..85e9a9e0 100644 --- a/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json +++ b/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json @@ -11,22 +11,22 @@ }, { "id": 13, - "logprob": -2.5234375, + "logprob": -2.734375, "text": "," }, { "id": 8862, - "logprob": -3.4433594, + "logprob": -3.6875, "text": " yellow" }, { "id": 13, - "logprob": -0.43017578, + "logprob": -0.40234375, "text": "," }, { "id": 209, - "logprob": -8.21875, + "logprob": -8.25, "text": " " } ], @@ -40,60 +40,60 @@ }, { "id": 395, - "logprob": -0.46411133, + "logprob": -0.3125, "special": false, "text": "and" }, { - "id": 13735, - "logprob": -2.1132812, - "special": false, - "text": " orange" - }, - { - "id": 313, - "logprob": -1.2128906, - "special": false, - "text": " (" - }, - { - "id": 249, - "logprob": -2.3671875, - "special": false, - "text": "in" - }, - { - "id": 253, + "id": 4797, "logprob": 0.0, "special": false, - "text": " the" + "text": " blue" }, { - "id": 1340, - "logprob": -1.640625, + "id": 9830, + "logprob": -1.65625, "special": false, - "text": " order" + "text": " colors" }, { - "id": 597, - "logprob": -0.5488281, - "special": false, - "text": " they" - }, - { - "id": 3176, - "logprob": -0.48608398, - "special": false, - "text": " appear" - }, - { - "id": 275, + "id": 15, "logprob": 0.0, "special": false, - "text": " in" + "text": "." + }, + { + "id": 329, + "logprob": -2.4375, + "special": false, + "text": " A" + }, + { + "id": 1180, + "logprob": -1.953125, + "special": false, + "text": " number" + }, + { + "id": 273, + "logprob": 0.0, + "special": false, + "text": " of" + }, + { + "id": 1027, + "logprob": -1.5546875, + "special": false, + "text": " different" + }, + { + "id": 3295, + "logprob": -0.97265625, + "special": false, + "text": " color" } ], "top_tokens": null }, - "generated_text": "blue, red, yellow, \nand orange (in the order they appear in" + "generated_text": "blue, red, yellow, \nand blue colors. 
A number of different color" } diff --git a/integration-tests/models/__snapshots__/test_mamba/test_mamba_load.json b/integration-tests/models/__snapshots__/test_mamba/test_mamba_load.json index 014210b2..4921c14b 100644 --- a/integration-tests/models/__snapshots__/test_mamba/test_mamba_load.json +++ b/integration-tests/models/__snapshots__/test_mamba/test_mamba_load.json @@ -12,22 +12,22 @@ }, { "id": 310, - "logprob": -0.8125, + "logprob": -0.83984375, "text": " is" }, { "id": 18147, - "logprob": -12.828125, + "logprob": -12.8125, "text": " Deep" }, { "id": 20727, - "logprob": -3.0, + "logprob": -2.84375, "text": " Learning" }, { "id": 32, - "logprob": -1.1484375, + "logprob": -1.25, "text": "?" } ], @@ -35,61 +35,61 @@ "tokens": [ { "id": 187, - "logprob": -0.3552246, + "logprob": -0.37890625, "special": false, "text": "\n" }, { "id": 187, - "logprob": -0.38378906, + "logprob": -0.4296875, "special": false, "text": "\n" }, { "id": 30763, - "logprob": -1.1279297, + "logprob": -1.078125, "special": false, "text": "Deep" }, { "id": 4715, - "logprob": -0.5595703, + "logprob": -0.515625, "special": false, "text": " learning" }, { "id": 310, - "logprob": -0.60253906, + "logprob": -0.6015625, "special": false, "text": " is" }, { "id": 247, - "logprob": -0.7050781, + "logprob": -0.65625, "special": false, "text": " a" }, { "id": 747, - "logprob": -2.0488281, + "logprob": -2.109375, "special": false, "text": " new" }, { "id": 1511, - "logprob": -2.3808594, + "logprob": -2.328125, "special": false, "text": " type" }, { "id": 273, - "logprob": -0.0026416779, + "logprob": -0.0032653809, "special": false, "text": " of" }, { "id": 5145, - "logprob": -1.2851562, + "logprob": -1.28125, "special": false, "text": " machine" } @@ -111,22 +111,22 @@ }, { "id": 310, - "logprob": -0.78027344, + "logprob": -0.80078125, "text": " is" }, { "id": 18147, - "logprob": -12.8203125, + "logprob": -13.25, "text": " Deep" }, { "id": 20727, - "logprob": -2.9902344, + "logprob": -2.828125, "text": " Learning" }, { "id": 32, - "logprob": -1.1523438, + "logprob": -1.1953125, "text": "?" } ], @@ -134,61 +134,61 @@ "tokens": [ { "id": 187, - "logprob": -0.35351562, + "logprob": -0.296875, "special": false, "text": "\n" }, { "id": 187, - "logprob": -0.38256836, + "logprob": -0.3359375, "special": false, "text": "\n" }, { "id": 30763, - "logprob": -1.1269531, + "logprob": -1.2578125, "special": false, "text": "Deep" }, { "id": 4715, - "logprob": -0.54541016, + "logprob": -0.5546875, "special": false, "text": " learning" }, { "id": 310, - "logprob": -0.59765625, + "logprob": -0.62890625, "special": false, "text": " is" }, { "id": 247, - "logprob": -0.7001953, + "logprob": -0.64453125, "special": false, "text": " a" }, { "id": 747, - "logprob": -2.0585938, + "logprob": -2.078125, "special": false, "text": " new" }, { "id": 1511, - "logprob": -2.3789062, + "logprob": -2.28125, "special": false, "text": " type" }, { "id": 273, - "logprob": -0.0027446747, + "logprob": -0.0030670166, "special": false, "text": " of" }, { "id": 5145, - "logprob": -1.2851562, + "logprob": -1.3125, "special": false, "text": " machine" } @@ -210,22 +210,22 @@ }, { "id": 310, - "logprob": -0.78027344, + "logprob": -0.80078125, "text": " is" }, { "id": 18147, - "logprob": -12.8203125, + "logprob": -13.25, "text": " Deep" }, { "id": 20727, - "logprob": -2.9902344, + "logprob": -2.828125, "text": " Learning" }, { "id": 32, - "logprob": -1.1523438, + "logprob": -1.1953125, "text": "?" 
} ], @@ -233,61 +233,61 @@ "tokens": [ { "id": 187, - "logprob": -0.35351562, + "logprob": -0.296875, "special": false, "text": "\n" }, { "id": 187, - "logprob": -0.38256836, + "logprob": -0.3359375, "special": false, "text": "\n" }, { "id": 30763, - "logprob": -1.1269531, + "logprob": -1.2578125, "special": false, "text": "Deep" }, { "id": 4715, - "logprob": -0.54541016, + "logprob": -0.5546875, "special": false, "text": " learning" }, { "id": 310, - "logprob": -0.59765625, + "logprob": -0.62890625, "special": false, "text": " is" }, { "id": 247, - "logprob": -0.7001953, + "logprob": -0.64453125, "special": false, "text": " a" }, { "id": 747, - "logprob": -2.0585938, + "logprob": -2.078125, "special": false, "text": " new" }, { "id": 1511, - "logprob": -2.3789062, + "logprob": -2.28125, "special": false, "text": " type" }, { "id": 273, - "logprob": -0.0027446747, + "logprob": -0.0030670166, "special": false, "text": " of" }, { "id": 5145, - "logprob": -1.2851562, + "logprob": -1.3125, "special": false, "text": " machine" } @@ -309,22 +309,22 @@ }, { "id": 310, - "logprob": -0.78027344, + "logprob": -0.80078125, "text": " is" }, { "id": 18147, - "logprob": -12.8203125, + "logprob": -13.25, "text": " Deep" }, { "id": 20727, - "logprob": -2.9902344, + "logprob": -2.828125, "text": " Learning" }, { "id": 32, - "logprob": -1.1523438, + "logprob": -1.1953125, "text": "?" } ], @@ -332,61 +332,61 @@ "tokens": [ { "id": 187, - "logprob": -0.35351562, + "logprob": -0.296875, "special": false, "text": "\n" }, { "id": 187, - "logprob": -0.38256836, + "logprob": -0.3359375, "special": false, "text": "\n" }, { "id": 30763, - "logprob": -1.1269531, + "logprob": -1.2578125, "special": false, "text": "Deep" }, { "id": 4715, - "logprob": -0.54541016, + "logprob": -0.5546875, "special": false, "text": " learning" }, { "id": 310, - "logprob": -0.59765625, + "logprob": -0.62890625, "special": false, "text": " is" }, { "id": 247, - "logprob": -0.7001953, + "logprob": -0.64453125, "special": false, "text": " a" }, { "id": 747, - "logprob": -2.0585938, + "logprob": -2.078125, "special": false, "text": " new" }, { "id": 1511, - "logprob": -2.3789062, + "logprob": -2.28125, "special": false, "text": " type" }, { "id": 273, - "logprob": -0.0027446747, + "logprob": -0.0030670166, "special": false, "text": " of" }, { "id": 5145, - "logprob": -1.2851562, + "logprob": -1.3125, "special": false, "text": " machine" } diff --git a/integration-tests/models/test_mamba.py b/integration-tests/models/test_mamba.py index bf398999..4cc863f0 100644 --- a/integration-tests/models/test_mamba.py +++ b/integration-tests/models/test_mamba.py @@ -47,14 +47,14 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot): assert response.details.generated_tokens == 10 assert ( response.generated_text - == "blue, red, yellow, \nand orange (in the order they appear in" + == "blue, red, yellow, \nand blue colors. 
A number of different color" ) assert response == response_snapshot @pytest.mark.asyncio @pytest.mark.private -async def test_mamba_load(fused_kernel_mamba, generate_load, response_snapshot): +async def test_mamba_load(fused_kernel_mamba, generate_load, generous_response_snapshot): responses = await generate_load( fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4 ) @@ -63,4 +63,4 @@ async def test_mamba_load(fused_kernel_mamba, generate_load, response_snapshot): assert all([r.generated_text == responses[0].generated_text for r in responses]) assert responses[0].generated_text == "\n\nDeep learning is a new type of machine" - assert responses == response_snapshot + assert responses == generous_response_snapshot diff --git a/server/text_generation_server/models/custom_modeling/mamba_modeling.py b/server/text_generation_server/models/custom_modeling/mamba_modeling.py index 017c0341..53e939bb 100644 --- a/server/text_generation_server/models/custom_modeling/mamba_modeling.py +++ b/server/text_generation_server/models/custom_modeling/mamba_modeling.py @@ -3,7 +3,6 @@ import torch.distributed from mamba_ssm.ops.triton.selective_state_update import selective_state_update from mamba_ssm.ops.selective_scan_interface import selective_scan_fn -from mamba_ssm.utils.generation import InferenceParams from torch import nn from typing import Optional, Tuple, Any from transformers.configuration_utils import PretrainedConfig @@ -18,6 +17,17 @@ from text_generation_server.utils.layers import ( from einops import rearrange from causal_conv1d import causal_conv1d_fn, causal_conv1d_update import math +from dataclasses import dataclass + +@dataclass +class InferenceParams: + """Inference parameters that are passed to the main model in order + to efficiently calculate and store the context during inference.""" + max_seqlen: int + max_batch_size: int + conv_states: torch.Tensor + ssm_states: torch.Tensor + seqlen_offset: int class MambaConfig(PretrainedConfig): @@ -56,9 +66,9 @@ class MambaConfig(PretrainedConfig): class MambaBlock(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix, config, weights, layer_id): super().__init__() - self.layer_idx = int(prefix.split(".")[2]) + self.layer_id = layer_id self.in_proj = FastLinear.load(config, f"{prefix}.in_proj", weights, bias=False) self.x_proj = FastLinear.load(config, f"{prefix}.x_proj", weights, bias=False) self.dt_proj = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=True) @@ -79,21 +89,20 @@ class MambaBlock(nn.Module): # inference_params def forward(self, hidden_states: torch.Tensor, inference_params=None): - _, seqlen, _ = hidden_states.shape - conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] - if inference_params.seqlen_offset > 0: + conv_state = inference_params.conv_states[self.layer_id] + ssm_state = inference_params.ssm_states[self.layer_id] out, conv_state, ssm_state = self.step(hidden_states, conv_state, ssm_state) return out, conv_state, ssm_state + _, seqlen, _ = hidden_states.shape projected_states = self.in_proj(hidden_states).transpose(1, 2) + # assert projected_states.shape == [batch_size, 2 * dstate, seqlen], f"{projected_states.shape} [{batch_size}, {dstate}, {seqlen}]" x, z = projected_states.chunk(2, dim=1) conv_state = F.pad(x, (self.d_conv - seqlen, 0)) x = causal_conv1d_fn( x=x, - weight=self.conv1d.weight.view( - self.conv1d.weight.size(0), self.conv1d.weight.size(2) - ), + weight=self.conv1d.weight.squeeze(1), bias=self.conv1d.bias,
activation=self.activation, ) @@ -126,56 +135,28 @@ class MambaBlock(nn.Module): return attn_outputs, conv_state, last_state def step(self, hidden_states, conv_state, ssm_state): - _xz = self.in_proj(hidden_states) - _x, _z = _xz.chunk(2, dim=-1) # (B D) - conv_state_new = torch.cat([conv_state, _x.transpose(1, 2)], dim=-1) - conv_out = causal_conv1d_fn( - x=conv_state_new, - weight=self.conv1d.weight.view( - self.conv1d.weight.size(0), self.conv1d.weight.size(2) - ), - bias=self.conv1d.bias, - activation=self.activation, + xz = self.in_proj(hidden_states.squeeze(1)) + x, z = xz.chunk(2, dim=-1) # (B D) + x = causal_conv1d_update(x, conv_state, self.conv1d.weight.squeeze(1), self.conv1d.bias, self.activation) + x_db = self.x_proj(x) # (B dt_rank+2*d_state) + dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) + dt = F.linear(dt, self.dt_proj.weight) + A = self.negA + y = selective_state_update( + ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True ) - conv_state = conv_state_new[:, :, 1:] - bsz, seqlen, dim = hidden_states.shape - output_tensor = torch.zeros( - (bsz, seqlen, dim), device=hidden_states.device, dtype=hidden_states.dtype - ) - for i in range(0, bsz): - x = conv_out[i : i + 1, :, -1] - z = _z[i : i + 1, -1, :] - x_db = self.x_proj(x) - dt, B, C = torch.split( - x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1 - ) - dt = F.linear(dt, self.dt_proj.weight) - y = selective_state_update( - ssm_state[i : i + 1, :, :], - x, - dt, - self.negA, - B, - C, - self.D, - z=z, - dt_bias=self.dt_proj.bias, - dt_softplus=True, - ) - out = self.out_proj(y) - output_tensor[i] = out - - return output_tensor, conv_state, ssm_state + out = self.out_proj(y) + return out.unsqueeze(1), conv_state.clone(), ssm_state.clone() class ResidualBlock(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix, config, weights, layer_id): super().__init__() self.mamba_block = MambaBlock( - prefix=f"{layer_id}.mixer", config=config, weights=weights + prefix=f"{prefix}.mixer", config=config, weights=weights, layer_id=layer_id ) self.layer_norm = FastRMSNorm.load( - prefix=f"{layer_id}.norm", weights=weights, eps=config.layer_norm_epsilon + prefix=f"{prefix}.norm", weights=weights, eps=config.layer_norm_epsilon ) def forward( @@ -200,7 +181,7 @@ class MambaModel(nn.Module): self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights) self.blocks = nn.ModuleList( [ - ResidualBlock(f"{prefix}.layers.{i}", config, weights) + ResidualBlock(f"{prefix}.layers.{i}", config, weights, layer_id=i) for i in range(config.n_layer) ] ) @@ -216,14 +197,12 @@ class MambaModel(nn.Module): self, input_ids: torch.Tensor, inference_params=None, residual=None ) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]: hidden_states = self.embed_tokens(input_ids) - for block in self.blocks: + for i, block in enumerate(self.blocks): hidden_states, residual, conv_state, ssm_state = block( hidden_states, residual, inference_params ) - inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = ( - conv_state, - ssm_state, - ) + inference_params.conv_states[i].copy_(conv_state) + inference_params.ssm_states[i].copy_(ssm_state) hidden_states = ( hidden_states + residual if residual is not None else hidden_states @@ -234,4 +213,4 @@ class MambaModel(nn.Module): # update the offset for the next inference using these params inference_params.seqlen_offset += input_ids.size(1) - return logits, input_ids, inference_params + 
return logits diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index c7fda516..e04a9719 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -28,12 +28,12 @@ from text_generation_server.models.cache_manager import ( BLOCK_SIZE, ) from text_generation_server.pb import generate_pb2 +from text_generation_server.models.globals import MEM_POOL from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils.dist import MEMORY_FRACTION tracer = trace.get_tracer(__name__) -MEM_POOL = torch.cuda.graph_pool_handle() @dataclass diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py new file mode 100644 index 00000000..b0dca376 --- /dev/null +++ b/server/text_generation_server/models/globals.py @@ -0,0 +1,3 @@ +import torch + +MEM_POOL = torch.cuda.graph_pool_handle() diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index c51e1e20..8f18e475 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -2,17 +2,20 @@ import torch import torch.distributed from transformers import AutoTokenizer, PreTrainedTokenizerBase from typing import Optional +import os from text_generation_server.models.custom_modeling.mamba_modeling import ( MambaConfig, ) +from loguru import logger from text_generation_server.pb import generate_pb2 from text_generation_server.utils import ( initialize_torch_distributed, weight_files, Weights, ) +from text_generation_server.models.globals import MEM_POOL import time -from text_generation_server.models.custom_modeling.mamba_modeling import MambaModel +from text_generation_server.models.custom_modeling.mamba_modeling import MambaModel, InferenceParams from text_generation_server.models import Model from typing import Any, List, Optional, Tuple, Type, Dict from text_generation_server.models.types import ( @@ -24,7 +27,34 @@ from text_generation_server.models.types import ( from text_generation_server.utils.tokens import batch_top_tokens, Sampling from dataclasses import dataclass from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling -from mamba_ssm.utils.generation import InferenceParams + +def new_inference_params(n_blocks: int, batch_size: int, d_inner: int, d_conv: int, d_state: int, seqlen_offset: int, dtype: torch.dtype, device: torch.device): + max_seqlen = 0 + conv_states = torch.zeros( + (n_blocks, + batch_size, + d_inner, + d_conv,), + device=device, + dtype=dtype, + ) + ssm_states = torch.zeros( + (n_blocks, + batch_size, + d_inner, + d_state,), + device=device, + dtype=dtype, + ) + inference_params = InferenceParams( + max_seqlen=max_seqlen, + max_batch_size=batch_size, + seqlen_offset=seqlen_offset, + conv_states=conv_states, + ssm_states=ssm_states, + + ) + return inference_params @dataclass @@ -221,14 +251,8 @@ class MambaBatch(Batch): # TODO # Kept it simple by just updating the state, maybe updating the other CPU values is necessary. 
- key_value_memory_dict = {} - for i, ( - conv_state, - ssm_state, - ) in self.inference_params.key_value_memory_dict.items(): - key_value_memory_dict[i] = (conv_state[indices], ssm_state[indices]) - self.inference_params.key_value_memory_dict = key_value_memory_dict - + self.inference_params.conv_states = self.inference_params.conv_states[:, indices] + self.inference_params.ssm_states = self.inference_params.ssm_states[:, indices] return self @classmethod @@ -254,9 +278,16 @@ class MambaBatch(Batch): top_n_tokens = [] max_tokens = 0 max_seqlen = 0 - batch_size = 0 seqlen_offset = 0 + (n_blocks, _, d_inner, d_conv) = ( + batches[0].inference_params.conv_states.shape + ) + (_, _, _, d_state) = batches[0].inference_params.ssm_states.shape + dtype = batches[0].inference_params.conv_states.dtype + device = batches[0].inference_params.conv_states.device + inference_params = new_inference_params(n_blocks=n_blocks, batch_size=total_batch_size, d_state=d_state, d_conv=d_conv, d_inner=d_inner, seqlen_offset=seqlen_offset, device=device, dtype=dtype) + # Batch tensors input_ids = None top_n_tokens_tensor = None @@ -303,63 +334,16 @@ class MambaBatch(Batch): max_input_length - batch.max_input_length ) * len(batch) - max_seqlen = max(max_seqlen, batch.inference_params.max_seqlen) - seqlen_offset = max(seqlen_offset, batch.inference_params.seqlen_offset) - batch_size += batch.inference_params.max_batch_size + inference_params.max_seqlen = max(inference_params.max_seqlen, batch.inference_params.max_seqlen) + assert batch.inference_params.seqlen_offset != 0, "Invalid seqlen offset" + inference_params.seqlen_offset = max(inference_params.seqlen_offset, batch.inference_params.seqlen_offset) + + + inference_params.conv_states[:, start_index:end_index] = batch.inference_params.conv_states + inference_params.ssm_states[:, start_index:end_index] = batch.inference_params.ssm_states start_index = end_index - (_, d_model, d_conv) = ( - batches[0].inference_params.key_value_memory_dict[0][0].shape - ) - (_, _, d_state) = batches[0].inference_params.key_value_memory_dict[0][1].shape - n_blocks = len(batches[0].inference_params.key_value_memory_dict) - dtype = batches[0].inference_params.key_value_memory_dict[0][0].dtype - device = batches[0].inference_params.key_value_memory_dict[0][0].device - - key_value_memory_dict = {} - for i in range(n_blocks): - conv_state = torch.zeros( - batch_size, - d_model, - d_conv, - device=device, - dtype=dtype, - ) - ssm_state = torch.zeros( - batch_size, - d_model, - d_state, - device=device, - dtype=dtype, - ) - key_value_memory_dict[i] = (conv_state, ssm_state) - lengths_per_sample = torch.zeros(batch_size, dtype=torch.int32, device=device) - - inference_params = InferenceParams( - max_seqlen=max_seqlen, - max_batch_size=batch_size, - seqlen_offset=seqlen_offset, - key_value_memory_dict=key_value_memory_dict, - lengths_per_sample=lengths_per_sample, - ) - - current_batch = 0 - for batch in batches: - for i in range(n_blocks): - conv_state, ssm_state = batch.inference_params.key_value_memory_dict[i] - batch_size = batch.inference_params.max_batch_size - inference_params.key_value_memory_dict[i][0][ - current_batch : current_batch + batch_size - ] = conv_state - inference_params.key_value_memory_dict[i][1][ - current_batch : current_batch + batch_size - ] = ssm_state - inference_params.lengths_per_sample[ - current_batch : current_batch + batch_size - ] = batch.inference_params.lengths_per_sample - current_batch += batch_size - return cls( batch_id=batches[0].batch_id, 
requests=requests, @@ -394,9 +378,13 @@ class Mamba(Model): trust_remote_code: bool = False, ): self.process_group, _rank, _world_size = initialize_torch_distributed() + self.cuda_graphs = {} if torch.cuda.is_available(): device = torch.device("cuda") - dtype = torch.float16 if dtype is None else dtype + # Bf16 is important. In f16, accumulations in the matmul are causing + # differences while the server is under load. + # This is detectable by the integration load test + dtype = torch.bfloat16 if dtype is None else dtype else: if quantize: raise ValueError("quantization is not available on CPU") @@ -439,17 +427,93 @@ class Mamba(Model): def warmup(self, batch) -> Optional[int]: # TODO: implement warmup for Mamba if needed + if os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True": + if self.speculate is None or self.speculate == 0: + try: + logger.info("Experimental support for Cuda Graphs is enabled") + # Warmup cuda graphs + for bs in [1, 2, 4] + [8 * i for i in range(1, 9)]: + self.cuda_graph_warmup(bs) + except Exception: + logger.exception(f"Decode cuda graph warmup failed") + return None + def cuda_graph_warmup(self, batch_size: int): + input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device) + n_blocks = len(self.model.blocks) + + d_state = self.model.config.d_state + d_conv = self.model.config.d_conv + # Inner takes the expand multiplication + d_inner = self.model.config.d_inner + + # Important: seqlen_offset > 0 so the forward pass goes through the state update mechanism + seqlen_offset = 1 + inference_params = new_inference_params(n_blocks=n_blocks, batch_size=batch_size, d_state=d_state, d_conv=d_conv, d_inner=d_inner, seqlen_offset=seqlen_offset, device=self.device, dtype=self.dtype) + + graph = torch.cuda.CUDAGraph() + + torch.cuda.synchronize() + # Run once outside the graph to warm up + self.model.forward( + input_ids=input_ids, + inference_params=inference_params + ) + torch.cuda.synchronize() + + with torch.cuda.graph(graph, pool=MEM_POOL): + logits = self.model.forward( + input_ids=input_ids, + inference_params=inference_params + ) + torch.cuda.synchronize() + graph_dict = { + "input_ids": input_ids, + "inference_params": inference_params, + "graph": graph, + "logits": logits + } + self.cuda_graphs[batch_size] = graph_dict + def forward( self, input_ids: torch.Tensor, - past: Optional[List[torch.Tensor]] = None, + inference_params: Any ) -> Tuple[torch.Tensor, torch.Tensor]: - return self.model( - input_ids, - past=past, - ) + bs = input_ids.shape[0] + padded_bs = bs + if bs == 3: + padded_bs = 4 + elif 3 < bs <= 8: + padded_bs = 8 + elif bs > 8: + padded_bs = (bs + 7) // 8 * 8 + + # Try to find an associated cuda graph + cuda_graph = self.cuda_graphs.get(padded_bs, None) + is_prefill = inference_params is None or inference_params.seqlen_offset == 0 + + if is_prefill or cuda_graph is None: + return self.model( + input_ids, + inference_params=inference_params, + ) + + # Copy inputs to the static inputs of the cuda graph + # Static inputs are potentially padded + cuda_graph["input_ids"][: bs] = input_ids + cuda_graph["inference_params"].conv_states[:, : bs] = inference_params.conv_states + cuda_graph["inference_params"].ssm_states[:, : bs] = inference_params.ssm_states + + # Replay the graph + cuda_graph["graph"].replay() + + inference_params.conv_states.copy_(cuda_graph["inference_params"].conv_states[:, :bs]) + inference_params.ssm_states.copy_(cuda_graph["inference_params"].ssm_states[:, :bs]) + + # Slice output to the correct shape + return cuda_graph["logits"][:bs] def
generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]: start = time.time_ns() @@ -457,56 +521,26 @@ class Mamba(Model): batch.input_ids ) # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids - batch_size = input_ids.shape[0] - max_seqlen = input_ids.shape[1] - dtype = input_ids.dtype - + batch_size, max_seqlen = input_ids.shape # Inference params - seqlen_og = 0 - inf_cache = {} - lengths_per_sample = ( - torch.ones(batch_size, dtype=torch.int32, device=input_ids.device) - * max_seqlen - ) if batch.inference_params is None: - inference_params = InferenceParams( - max_seqlen=max_seqlen, - max_batch_size=batch_size, - seqlen_offset=seqlen_og, - key_value_memory_dict=inf_cache, - lengths_per_sample=lengths_per_sample, - ) - - # Allocate inference cache - for res_block in self.model.blocks: - block = res_block.mamba_block - conv_state = torch.zeros( - batch_size, - self.model.config.d_model * self.model.config.expand, - self.model.config.d_conv, - device=block.conv1d.weight.device, - dtype=block.conv1d.weight.dtype, - ) - ssm_state = torch.zeros( - batch_size, - self.model.config.d_model * self.model.config.expand, - self.model.config.d_state, - device=block.dt_proj.weight.device, - dtype=block.dt_proj.weight.dtype, - ) - inference_params.key_value_memory_dict[block.layer_idx] = ( - conv_state, - ssm_state, - ) + # 0 is important here + seqlen_offset = 0 + n_blocks = len(self.model.blocks) + d_state = self.model.config.d_state + d_conv = self.model.config.d_conv + d_inner = self.model.config.d_inner + inference_params = new_inference_params(n_blocks=n_blocks, batch_size=batch_size, d_state=d_state, d_conv=d_conv, d_inner=d_inner, seqlen_offset=seqlen_offset, device=self.device, dtype=self.dtype) batch.inference_params = inference_params # Forward pass - logits, past_input_ids, new_inference_params = self.model( - input_ids, batch.inference_params + logits = self.forward( + input_ids, inference_params=batch.inference_params ) - batch.inference_params = new_inference_params + + # batch.inference_params = new_inference_params # Results generations: List[Generation] = [] stopped = True
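Note on the CUDA graph usage introduced in server/text_generation_server/models/mamba.py above: cuda_graph_warmup captures one single-token decode forward pass into a torch.cuda.CUDAGraph over static input and state buffers, and Mamba.forward later replays that graph after copying the live batch (padded to one of the pre-warmed sizes 1, 2, 4, 8, 16, ..., 64) into those buffers. The following is a minimal, standalone sketch of that capture/replay pattern, not code from this repository; decode_step, the buffer shapes, and the pool variable are illustrative assumptions, and a CUDA device is required.

import torch

# Stand-in for a single-token decode forward pass; illustrative only.
def decode_step(x: torch.Tensor) -> torch.Tensor:
    return torch.tanh(x @ x.T)

device = torch.device("cuda")

# Static buffer: the captured graph always reads from this exact tensor.
static_input = torch.zeros((8, 64), device=device)

# Shared memory pool, analogous to MEM_POOL in models/globals.py.
pool = torch.cuda.graph_pool_handle()

# Run once outside the graph so lazy CUDA initialization is not captured,
# mirroring the warmup forward call before capture in cuda_graph_warmup.
static_output = decode_step(static_input)
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, pool=pool):
    static_output = decode_step(static_input)
torch.cuda.synchronize()

# Per decode step: copy live data into the static buffer, replay, read results.
live_batch = torch.randn((8, 64), device=device)
static_input.copy_(live_batch)
graph.replay()
result = static_output.clone()

Replaying a pre-captured graph removes per-operator CPU launch overhead during decoding, which is why Mamba.forward pads the batch to one of the pre-warmed sizes and slices the static logits back down to the real batch size.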