feat: use fused kernel in forward pass

drbh 2024-01-29 21:54:23 +00:00
parent c2681b2bea
commit 966f3ba35c
2 changed files with 39 additions and 143 deletions
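The gist of the change: instead of running the causal conv1d, SiLU activation, selective scan, gating, and output projection as separate PyTorch ops, MambaBlock.forward now hands its raw weights to mamba_inner_fn from mamba_ssm, which performs all of that in one fused CUDA kernel. Below is a minimal sketch of the call shape, assuming mamba_ssm is installed and all tensors live on a CUDA device; the wrapper name and parameter names are illustrative, but the argument order mirrors the call added in this commit:

import torch
from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn

def fused_mamba_forward(xz, conv1d, x_proj_w, dt_proj_w, dt_proj_b,
                        out_proj_w, A_log, D):
    # xz: in_proj output transposed to (batch, 2 * d_inner, seq_len); the kernel
    # splits it into the scan input and the SiLU gate internally
    A = -torch.exp(A_log.float())
    return mamba_inner_fn(
        xz,
        conv1d.weight,      # depthwise causal conv1d weight, shape (d_inner, 1, d_conv)
        conv1d.bias,
        x_proj_w,           # projects to (dt, B, C) inside the kernel
        dt_proj_w,
        out_proj_w,
        None,               # out_proj has no bias in this model
        A,
        None,               # B
        None,               # C
        D.float(),
        delta_bias=dt_proj_b.float(),
        delta_softplus=True,
    )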

@@ -1,6 +1,7 @@
import torch
import torch.distributed
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
from torch import nn
from typing import Optional, List, Tuple, Any
from transformers.configuration_utils import PretrainedConfig
@@ -12,12 +13,6 @@ from text_generation_server.utils.layers import (
TensorParallelEmbedding,
)
# note torch must be imported before the custom cuda modules
# since they rely on torch's libc10.so
import causal_conv1d_cuda
import selective_scan_cuda
class MambaConfig(PretrainedConfig):
def __init__(
self,
@@ -50,155 +45,56 @@ class MambaConfig(PretrainedConfig):
class MambaBlock(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
# TODO: use model config to set the dt_rank instead of hardcoding it
self.dt_rank = (config.d_model + 15) // 16
# TODO: improve how we load the conv1d weights
# explore a transposed conv1d that avoids the need for
# a transpose during inference
self.conv1 = nn.Conv1d(
config.d_inner,
config.d_inner,
kernel_size=config.d_conv,
groups=config.d_inner,
padding=config.d_conv - 1,
)
self.conv1.weight = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.weight"))
self.conv1.bias = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.bias"))
self.dt_proj = TensorParallelColumnLinear.load(
config=config,
prefix=f"{prefix}.dt_proj",
weights=weights,
bias=True,
)
self.x_proj_weight = weights.get_tensor(f"{prefix}.x_proj.weight")
self.dt_proj_weight = weights.get_tensor(f"{prefix}.dt_proj.weight")
self.dt_proj_bias = weights.get_tensor(f"{prefix}.dt_proj.bias")
self.out_proj_weight = weights.get_tensor(f"{prefix}.out_proj.weight")
self.out_proj_bias = None
# TODO: avoid loading the same weights twice
self.in_proj_weight = weights.get_tensor(f"{prefix}.in_proj.weight")
self.in_proj_bias = None
self.in_proj = TensorParallelColumnLinear.load(
config=config,
prefix=f"{prefix}.in_proj",
weights=weights,
bias=False,
)
self.x_proj = TensorParallelColumnLinear.load(
config=config,
prefix=f"{prefix}.x_proj",
weights=weights,
bias=False,
self.conv1d = nn.Conv1d(
config.d_inner,
config.d_inner,
kernel_size=config.d_conv,
groups=config.d_inner,
padding=config.d_conv - 1,
)
self.out_proj = TensorParallelColumnLinear.load(
config=config,
prefix=f"{prefix}.out_proj",
weights=weights,
bias=False,
)
# TODO: improve how we load the weights
self.conv1d.weight = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.weight"))
self.conv1d.bias = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.bias"))
self.A_log = nn.Parameter(weights.get_tensor(f"{prefix}.A_log"))
self.D = nn.Parameter(weights.get_tensor(f"{prefix}.D"))
def selective_scan(
self, input_tensor, delta, a_tensor, b_tensor, c_tensor, d_tensor
):
batch_size, sequence_length, input_dim = input_tensor.shape
num_cols = a_tensor.shape[1]
# TODO: revisit this math to avoid the transposes when possible
# reshape and process delta
delta = delta.transpose(1, 2).view((batch_size, input_dim, sequence_length, 1))
exp_delta_a = (delta * a_tensor.view((1, input_dim, 1, num_cols))).exp()
# calc involving delta, b_tensor, and input_tensor
delta_b_input = (
delta
* b_tensor.view((batch_size, 1, sequence_length, num_cols))
* input_tensor.transpose(1, 2).view(
(batch_size, input_dim, sequence_length, 1)
)
)
# init output tensor
output_tensor = torch.zeros(
(batch_size, input_dim, num_cols),
dtype=exp_delta_a.dtype,
device=exp_delta_a.device,
)
# iterate over sequence_length
output_sequence = []
for i in range(sequence_length):
multiplier = exp_delta_a[:, :, i]
output_tensor = (multiplier * output_tensor) + delta_b_input[:, :, i]
y = output_tensor.matmul(c_tensor[:, i, :].unsqueeze(2)).squeeze(2)
output_sequence.append(y)
stacked_output = torch.stack(output_sequence, 1)
return stacked_output + input_tensor * d_tensor
def ssm(self, hidden_states):
_input_dim, num_cols = self.A_log.shape
negative_exponential_a = self.A_log.exp().neg()
d_matrix = self.D
projected_hidden_states = self.x_proj(hidden_states)
# narrow operations for delta, b, and c
delta = projected_hidden_states.narrow(-1, 0, self.dt_rank)
b_matrix = projected_hidden_states.narrow(-1, self.dt_rank, num_cols)
c_matrix = projected_hidden_states.narrow(-1, self.dt_rank + num_cols, num_cols)
# process delta
delta = self.dt_proj(delta)
delta = torch.log(torch.exp(delta) + 1)
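# (note: log(exp(delta) + 1) is softplus, i.e. torch.nn.functional.softplus(delta);
#  the fused kernel applies the same transform via delta_softplus=True)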
# apply selective scan
selective_scan_output = self.selective_scan(
hidden_states, delta, negative_exponential_a, b_matrix, c_matrix, d_matrix
)
return selective_scan_output
def forward(self, index, hidden_states, past_transformed_state):
sequence_length = hidden_states.shape[1]
# minimal amount of new work on a single hidden state (previous hidden states are cached)
only_last = hidden_states[:, -1, :]
projected_only_last = self.in_proj(only_last)
transformed_only_last, residual_only_last = torch.chunk(
projected_only_last, 2, dim=-1
)
if past_transformed_state is not None:
# build a new transformed_states tensor with past_transformed_state and transformed_only_last
new_transformed_states = torch.cat(
[past_transformed_state, transformed_only_last.unsqueeze(1)], dim=1
)
transformed_states = new_transformed_states
residual_states = residual_only_last
else:
# prefilling the cache with the last transformed state
projected_states = self.in_proj(hidden_states)
split_states = torch.chunk(projected_states, 2, dim=-1)
transformed_states, residual_states = split_states
# NOTE: we need the past hidden states to produce the correct output
# therefore we cannot simply compute the most recent and append it as we
# did for the transformed states
A = -torch.exp(self.A_log.float())
# TODO: avoid the transpose by using a transposed conv1d
# apply convolution and narrowing operation
conv_output = (
self.conv1(transformed_states.transpose(1, 2))
.narrow(-1, 0, sequence_length)
.transpose(1, 2)
# conv1d, ssm, and selective_scan are all fused into one kernel
attn_outputs = mamba_inner_fn(
projected_states.transpose(1,2),
self.conv1d.weight,
self.conv1d.bias,
self.x_proj_weight,
self.dt_proj_weight,
self.out_proj_weight,
self.out_proj_bias,
A,
None,
None,
self.D.float(),
delta_bias=self.dt_proj_bias.float(),
delta_softplus=True,
)
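# (the two literal None arguments above are the B and C inputs; leaving them as
#  None lets the kernel derive B and C from the x_proj output, the way the removed
#  ssm method did with narrow, and delta_softplus=True replaces the manual
#  log(exp(delta) + 1); self.out_proj_bias is likewise None, matching the
#  bias-free out_proj)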
# apply silu (Swish) activation function
activated_transformed = F.silu(conv_output)
activated_residual = F.silu(residual_states)
# Subsequent operations
output = self.ssm(activated_transformed)
combined_output = output * activated_residual
return self.out_proj(combined_output), transformed_states
return attn_outputs, projected_states
# TODO: prefer a more optimized implementation of RMSNorm if possible
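For reference, here is the recurrence the fused kernel evaluates once dt, B, and C are formed, written as a compact pure-PyTorch sketch of the selective_scan method removed above. Tensor names and shapes are taken from that method; this is a readability aid, not the kernel's actual implementation:

import torch

def selective_scan_reference(x, delta, A, B, C, D):
    # x, delta: (batch, seq, d_inner); A: (d_inner, d_state), already negative
    # (i.e. -exp(A_log)); B, C: (batch, seq, d_state); D: (d_inner,)
    batch, seq, d_inner = x.shape
    h = x.new_zeros(batch, d_inner, A.shape[1])       # SSM hidden state
    ys = []
    for t in range(seq):
        dt = delta[:, t].unsqueeze(-1)                # (batch, d_inner, 1)
        # discretized state update: h = exp(dt * A) * h + dt * B_t * x_t
        h = torch.exp(dt * A) * h + dt * B[:, t].unsqueeze(1) * x[:, t].unsqueeze(-1)
        # readout: y_t = h @ C_t
        ys.append(torch.bmm(h, C[:, t].unsqueeze(-1)).squeeze(-1))
    return torch.stack(ys, dim=1) + x * D             # skip connection through D

The fused path produces the same output while avoiding both the Python loop over the sequence and the (batch, d_inner, seq, d_state) intermediates that the removed method materialized.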

@@ -256,10 +256,10 @@ class Mamba(Model):
)
generations.append(generation)
next_token_tensor = next_token_id_squeezed.view(1, 1)
# Update values
batch.input_ids = torch.cat(
[batch.input_ids, torch.tensor([[next_token_id_squeezed]])], dim=1
[batch.input_ids, next_token_tensor], dim=1
)
batch.all_input_ids[i] = all_input_ids
batch.input_lengths[i] = new_input_length
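The second hunk is a small decode-loop cleanup: the freshly sampled id is reshaped with .view(1, 1) and concatenated directly, rather than being re-wrapped in torch.tensor, which allocates and fills a brand-new tensor from the value on every step. A toy illustration of the two patterns (tensor values are made up):

import torch

input_ids = torch.tensor([[1, 2, 3]])
next_token_id_squeezed = torch.tensor(4)

# before: builds a new tensor around the value we already have
before = torch.cat([input_ids, torch.tensor([[next_token_id_squeezed]])], dim=1)

# after: a zero-copy view that keeps the original tensor's device and dtype
next_token_tensor = next_token_id_squeezed.view(1, 1)
after = torch.cat([input_ids, next_token_tensor], dim=1)

assert torch.equal(before, after)  # same result, cheaper construction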