diff --git a/server/text_generation_server/models/custom_modeling/mamba_modeling.py b/server/text_generation_server/models/custom_modeling/mamba_modeling.py
index 7fb203447..8df5f4df7 100644
--- a/server/text_generation_server/models/custom_modeling/mamba_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/mamba_modeling.py
@@ -1,6 +1,7 @@
 import torch
 import torch.distributed
+from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
 from torch import nn
 from typing import Optional, List, Tuple, Any
 from transformers.configuration_utils import PretrainedConfig
@@ -12,12 +13,6 @@ from text_generation_server.utils.layers import (
     TensorParallelEmbedding,
 )
 
-
-# note torch must be imported before the custom cuda modules
-# since they rely on torch's libc10.so
-import causal_conv1d_cuda
-import selective_scan_cuda
-
 
 class MambaConfig(PretrainedConfig):
     def __init__(
         self,
@@ -50,155 +45,56 @@ class MambaConfig(PretrainedConfig):
 class MambaBlock(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
-
-        # TODO: use model config to set the dt_rank instead of hardcoding it
         self.dt_rank = (config.d_model + 15) // 16
-
-        # TODO: improve how we load the conv1d weights
-        # explore a transposed conv1d that avoids the need for
-        # a transpose during inference
-        self.conv1 = nn.Conv1d(
-            config.d_inner,
-            config.d_inner,
-            kernel_size=config.d_conv,
-            groups=config.d_inner,
-            padding=config.d_conv - 1,
-        )
-        self.conv1.weight = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.weight"))
-        self.conv1.bias = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.bias"))
-
-        self.dt_proj = TensorParallelColumnLinear.load(
-            config=config,
-            prefix=f"{prefix}.dt_proj",
-            weights=weights,
-            bias=True,
-        )
+        self.x_proj_weight = weights.get_tensor(f"{prefix}.x_proj.weight")
+        self.dt_proj_weight = weights.get_tensor(f"{prefix}.dt_proj.weight")
+        self.dt_proj_bias = weights.get_tensor(f"{prefix}.dt_proj.bias")
+        self.out_proj_weight = weights.get_tensor(f"{prefix}.out_proj.weight")
+        self.out_proj_bias = None
+        # TODO: avoid loading the same weights twice
+        self.in_proj_weight = weights.get_tensor(f"{prefix}.in_proj.weight")
+        self.in_proj_bias = None
         self.in_proj = TensorParallelColumnLinear.load(
             config=config,
             prefix=f"{prefix}.in_proj",
             weights=weights,
             bias=False,
         )
-        self.x_proj = TensorParallelColumnLinear.load(
-            config=config,
-            prefix=f"{prefix}.x_proj",
-            weights=weights,
-            bias=False,
+        self.conv1d = nn.Conv1d(
+            config.d_inner,
+            config.d_inner,
+            kernel_size=config.d_conv,
+            groups=config.d_inner,
+            padding=config.d_conv - 1,
         )
-        self.out_proj = TensorParallelColumnLinear.load(
-            config=config,
-            prefix=f"{prefix}.out_proj",
-            weights=weights,
-            bias=False,
-        )
-
-        # TODO: improve how we load the weights
+        self.conv1d.weight = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.weight"))
+        self.conv1d.bias = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.bias"))
         self.A_log = nn.Parameter(weights.get_tensor(f"{prefix}.A_log"))
         self.D = nn.Parameter(weights.get_tensor(f"{prefix}.D"))
 
-    def selective_scan(
-        self, input_tensor, delta, a_tensor, b_tensor, c_tensor, d_tensor
-    ):
-        batch_size, sequence_length, input_dim = input_tensor.shape
-        num_cols = a_tensor.shape[1]
-
-        # TODO: revisit this math to avoid the transposes when possible
-        # reshape and process delta
-        delta = delta.transpose(1, 2).view((batch_size, input_dim, sequence_length, 1))
-        exp_delta_a = (delta * a_tensor.view((1, input_dim, 1, num_cols))).exp()
-
-        # calc involving delta, b_tensor, and input_tensor
-        delta_b_input = (
-            delta
-            * b_tensor.view((batch_size, 1, sequence_length, num_cols))
-            * input_tensor.transpose(1, 2).view(
-                (batch_size, input_dim, sequence_length, 1)
-            )
-        )
-
-        # init output tensor
-        output_tensor = torch.zeros(
-            (batch_size, input_dim, num_cols),
-            dtype=exp_delta_a.dtype,
-            device=exp_delta_a.device,
-        )
-
-        # iterate over sequence_length
-        output_sequence = []
-        for i in range(sequence_length):
-            multiplier = exp_delta_a[:, :, i]
-            output_tensor = (multiplier * output_tensor) + delta_b_input[:, :, i]
-            y = output_tensor.matmul(c_tensor[:, i, :].unsqueeze(2)).squeeze(2)
-            output_sequence.append(y)
-
-        stacked_output = torch.stack(output_sequence, 1)
-        return stacked_output + input_tensor * d_tensor
-
-    def ssm(self, hidden_states):
-        _input_dim, num_cols = self.A_log.shape
-        negative_exponential_a = self.A_log.exp().neg()
-        d_matrix = self.D
-        projected_hidden_states = self.x_proj(hidden_states)
-
-        # narrow operations for delta, b, and c
-        delta = projected_hidden_states.narrow(-1, 0, self.dt_rank)
-        b_matrix = projected_hidden_states.narrow(-1, self.dt_rank, num_cols)
-        c_matrix = projected_hidden_states.narrow(-1, self.dt_rank + num_cols, num_cols)
-
-        # process delta
-        delta = self.dt_proj(delta)
-        delta = torch.log(torch.exp(delta) + 1)
-
-        # apply selective scan
-        selective_scan_output = self.selective_scan(
-            hidden_states, delta, negative_exponential_a, b_matrix, c_matrix, d_matrix
-        )
-        return selective_scan_output
-
     def forward(self, index, hidden_states, past_transformed_state):
-        sequence_length = hidden_states.shape[1]
+        projected_states = self.in_proj(hidden_states)
 
-        # minimal amount of new work on single hidden state (previous hidden state are cached)
-        only_last = hidden_states[:, -1, :]
-        projected_only_last = self.in_proj(only_last)
-        transformed_only_last, residual_only_last = torch.chunk(
-            projected_only_last, 2, dim=-1
+        A = -torch.exp(self.A_log.float())
+
+        # conv1d, ssm, and selective_scan are all fused into one kernel
+        attn_outputs = mamba_inner_fn(
+            projected_states.transpose(1,2),
+            self.conv1d.weight,
+            self.conv1d.bias,
+            self.x_proj_weight,
+            self.dt_proj_weight,
+            self.out_proj_weight,
+            self.out_proj_bias,
+            A,
+            None,
+            None,
+            self.D.float(),
+            delta_bias=self.dt_proj_bias.float(),
+            delta_softplus=True,
         )
-
-        if past_transformed_state is not None:
-            # build a new transformed_states tensor with past_transformed_state and transformed_only_last
-            new_transformed_states = torch.cat(
-                [past_transformed_state, transformed_only_last.unsqueeze(1)], dim=1
-            )
-            transformed_states = new_transformed_states
-            residual_states = residual_only_last
-        else:
-            # prefilling the cache with the last transformed state
-            projected_states = self.in_proj(hidden_states)
-            split_states = torch.chunk(projected_states, 2, dim=-1)
-            transformed_states, residual_states = split_states
-
-        # NOTE: we need the past hidden states to produce the correct output
-        # therefore we cannot simply compute the most recent and append it as we
-        # did for the transformed states
-
-        # TODO: avoid the transpose by using a transposed conv1d
-        # apply convolution and narrowing operation
-        conv_output = (
-            self.conv1(transformed_states.transpose(1, 2))
-            .narrow(-1, 0, sequence_length)
-            .transpose(1, 2)
-        )
-
-        # apply silu (Swish) activation function
-        activated_transformed = F.silu(conv_output)
-        activated_residual = F.silu(residual_states)
-
-        # Subsequent operations
-        output = self.ssm(activated_transformed)
-        combined_output = output * activated_residual
-
-        return self.out_proj(combined_output), transformed_states
+
+        return attn_outputs, projected_states
 
 
 # TODO: prefer a more optimized implementation of RMSNorm if possible
diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py
index 68c2bc9c2..8b8904260 100644
--- a/server/text_generation_server/models/mamba.py
+++ b/server/text_generation_server/models/mamba.py
@@ -256,10 +256,10 @@ class Mamba(Model):
             )
 
             generations.append(generation)
-
+            next_token_tensor = next_token_id_squeezed.view(1, 1)
             # Update values
             batch.input_ids = torch.cat(
                [batch.input_ids, next_token_tensor], dim=1
-                [batch.input_ids, torch.tensor([[next_token_id_squeezed]])], dim=1
+                [batch.input_ids, next_token_tensor], dim=1
             )
             batch.all_input_ids[i] = all_input_ids
             batch.input_lengths[i] = new_input_length
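
For reference, the eager-mode pipeline that the fused `mamba_inner_fn` call
replaces can be sketched in plain PyTorch, reconstructed from the removed
`selective_scan`/`ssm` methods above. This is a minimal sketch under stated
assumptions, not the kernel's exact contract: the `mamba_block_reference`
helper, its argument order, and float32 inputs are illustrative only.

    import torch
    import torch.nn.functional as F

    def mamba_block_reference(
        hidden_states,    # (batch, seq_len, d_model), float32
        in_proj_weight,   # (2 * d_inner, d_model)
        conv1d_weight,    # (d_inner, 1, d_conv), depthwise
        conv1d_bias,      # (d_inner,)
        x_proj_weight,    # (dt_rank + 2 * d_state, d_inner)
        dt_proj_weight,   # (d_inner, dt_rank)
        dt_proj_bias,     # (d_inner,)
        out_proj_weight,  # (d_model, d_inner)
        A_log,            # (d_inner, d_state)
        D,                # (d_inner,)
        dt_rank,
    ):
        batch, seq_len, _ = hidden_states.shape
        d_inner, d_state = A_log.shape

        # in_proj yields the scan input x and the gating branch z in one matmul
        x, z = (hidden_states @ in_proj_weight.t()).chunk(2, dim=-1)

        # depthwise causal conv over time, trimmed back to seq_len, then SiLU
        x = F.conv1d(
            x.transpose(1, 2), conv1d_weight, conv1d_bias,
            padding=conv1d_weight.shape[-1] - 1, groups=d_inner,
        )[..., :seq_len].transpose(1, 2)
        x = F.silu(x)

        # x_proj yields dt, B, C; dt_proj plus softplus recovers the step size
        # (delta_softplus=True folds this softplus into the fused kernel)
        dt, B, C = torch.split(
            x @ x_proj_weight.t(), [dt_rank, d_state, d_state], dim=-1
        )
        dt = F.softplus(dt @ dt_proj_weight.t() + dt_proj_bias)

        # sequential selective scan: h[t] = exp(dt * A) * h[t-1] + dt * B[t] * x[t]
        A = -torch.exp(A_log.float())
        h = x.new_zeros(batch, d_inner, d_state)
        ys = []
        for t in range(seq_len):
            dA = torch.exp(dt[:, t, :, None] * A)
            dBx = dt[:, t, :, None] * B[:, t, None, :] * x[:, t, :, None]
            h = dA * h + dBx
            ys.append((h @ C[:, t, :, None]).squeeze(-1))
        y = torch.stack(ys, dim=1) + x * D  # skip connection through D

        # gate with silu(z) and project back to the model dimension
        return (y * F.silu(z)) @ out_proj_weight.t()

The two positional `None` arguments in the fused call are the optional fixed
B and C tensors; leaving them unset tells the kernel to derive the
data-dependent B and C from the `x_proj` output, as the reference above does.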
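
The `mamba.py` change also avoids rebuilding a tensor that may already live
on the accelerator: `torch.tensor([[next_token_id_squeezed]])` allocates a
fresh CPU tensor from the scalar's value (a host round-trip when the scalar
is on the GPU), while `.view(1, 1)` is a zero-copy reshape that keeps dtype
and device. A minimal sketch of the difference, with the scalar value and
device assumed for illustration:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    next_token_id_squeezed = torch.tensor(42, device=device)  # 0-d scalar

    # Old path: builds a brand-new CPU tensor from the scalar's value.
    old_token = torch.tensor([[next_token_id_squeezed]])

    # New path: zero-copy reshape, same storage, same dtype and device.
    new_token = next_token_id_squeezed.view(1, 1)
    assert new_token.shape == (1, 1)
    assert new_token.device == next_token_id_squeezed.device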