mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 06:42:10 +00:00
feat: use fused kernel in forward pass
This commit is contained in:
parent
c2681b2bea
commit
966f3ba35c
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
|
||||
from torch import nn
|
||||
from typing import Optional, List, Tuple, Any
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
@ -12,12 +13,6 @@ from text_generation_server.utils.layers import (
|
||||
TensorParallelEmbedding,
|
||||
)
|
||||
|
||||
|
||||
# note torch must be imported before the custom cuda modules
|
||||
# since they rely on torch's libc10.so
|
||||
import causal_conv1d_cuda
|
||||
import selective_scan_cuda
|
||||
|
||||
class MambaConfig(PretrainedConfig):
|
||||
def __init__(
|
||||
self,
|
||||
@ -50,155 +45,56 @@ class MambaConfig(PretrainedConfig):
|
||||
class MambaBlock(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
# TODO: use model config to set the dt_rank instead of hardcoding it
|
||||
self.dt_rank = (config.d_model + 15) // 16
|
||||
|
||||
# TODO: improve how we load the conv1d weights
|
||||
# explore a transposed conv1d that avoids the need for
|
||||
# a transpose during inference
|
||||
self.conv1 = nn.Conv1d(
|
||||
config.d_inner,
|
||||
config.d_inner,
|
||||
kernel_size=config.d_conv,
|
||||
groups=config.d_inner,
|
||||
padding=config.d_conv - 1,
|
||||
)
|
||||
self.conv1.weight = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.weight"))
|
||||
self.conv1.bias = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.bias"))
|
||||
|
||||
self.dt_proj = TensorParallelColumnLinear.load(
|
||||
config=config,
|
||||
prefix=f"{prefix}.dt_proj",
|
||||
weights=weights,
|
||||
bias=True,
|
||||
)
|
||||
self.x_proj_weight = weights.get_tensor(f"{prefix}.x_proj.weight")
|
||||
self.dt_proj_weight = weights.get_tensor(f"{prefix}.dt_proj.weight")
|
||||
self.dt_proj_bias = weights.get_tensor(f"{prefix}.dt_proj.bias")
|
||||
self.out_proj_weight = weights.get_tensor(f"{prefix}.out_proj.weight")
|
||||
self.out_proj_bias = None
|
||||
# TODO: avoid loading the same weights twice
|
||||
self.in_proj_weight = weights.get_tensor(f"{prefix}.in_proj.weight")
|
||||
self.in_proj_bias = None
|
||||
self.in_proj = TensorParallelColumnLinear.load(
|
||||
config=config,
|
||||
prefix=f"{prefix}.in_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.x_proj = TensorParallelColumnLinear.load(
|
||||
config=config,
|
||||
prefix=f"{prefix}.x_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
self.conv1d = nn.Conv1d(
|
||||
config.d_inner,
|
||||
config.d_inner,
|
||||
kernel_size=config.d_conv,
|
||||
groups=config.d_inner,
|
||||
padding=config.d_conv - 1,
|
||||
)
|
||||
self.out_proj = TensorParallelColumnLinear.load(
|
||||
config=config,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
# TODO: improve how we load the weights
|
||||
self.conv1d.weight = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.weight"))
|
||||
self.conv1d.bias = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.bias"))
|
||||
self.A_log = nn.Parameter(weights.get_tensor(f"{prefix}.A_log"))
|
||||
self.D = nn.Parameter(weights.get_tensor(f"{prefix}.D"))
|
||||
|
||||
def selective_scan(
|
||||
self, input_tensor, delta, a_tensor, b_tensor, c_tensor, d_tensor
|
||||
):
|
||||
batch_size, sequence_length, input_dim = input_tensor.shape
|
||||
num_cols = a_tensor.shape[1]
|
||||
|
||||
# TODO: revisit this math to avoid the transposes when possible
|
||||
# reshape and process delta
|
||||
delta = delta.transpose(1, 2).view((batch_size, input_dim, sequence_length, 1))
|
||||
exp_delta_a = (delta * a_tensor.view((1, input_dim, 1, num_cols))).exp()
|
||||
|
||||
# calc involving delta, b_tensor, and input_tensor
|
||||
delta_b_input = (
|
||||
delta
|
||||
* b_tensor.view((batch_size, 1, sequence_length, num_cols))
|
||||
* input_tensor.transpose(1, 2).view(
|
||||
(batch_size, input_dim, sequence_length, 1)
|
||||
)
|
||||
)
|
||||
|
||||
# init output tensor
|
||||
output_tensor = torch.zeros(
|
||||
(batch_size, input_dim, num_cols),
|
||||
dtype=exp_delta_a.dtype,
|
||||
device=exp_delta_a.device,
|
||||
)
|
||||
|
||||
# iterate over sequence_length
|
||||
output_sequence = []
|
||||
for i in range(sequence_length):
|
||||
multiplier = exp_delta_a[:, :, i]
|
||||
output_tensor = (multiplier * output_tensor) + delta_b_input[:, :, i]
|
||||
y = output_tensor.matmul(c_tensor[:, i, :].unsqueeze(2)).squeeze(2)
|
||||
output_sequence.append(y)
|
||||
|
||||
stacked_output = torch.stack(output_sequence, 1)
|
||||
return stacked_output + input_tensor * d_tensor
|
||||
|
||||
def ssm(self, hidden_states):
|
||||
_input_dim, num_cols = self.A_log.shape
|
||||
negative_exponential_a = self.A_log.exp().neg()
|
||||
d_matrix = self.D
|
||||
projected_hidden_states = self.x_proj(hidden_states)
|
||||
|
||||
# narrow operations for delta, b, and c
|
||||
delta = projected_hidden_states.narrow(-1, 0, self.dt_rank)
|
||||
b_matrix = projected_hidden_states.narrow(-1, self.dt_rank, num_cols)
|
||||
c_matrix = projected_hidden_states.narrow(-1, self.dt_rank + num_cols, num_cols)
|
||||
|
||||
# process delta
|
||||
delta = self.dt_proj(delta)
|
||||
delta = torch.log(torch.exp(delta) + 1)
|
||||
|
||||
# apply selective scan
|
||||
selective_scan_output = self.selective_scan(
|
||||
hidden_states, delta, negative_exponential_a, b_matrix, c_matrix, d_matrix
|
||||
)
|
||||
return selective_scan_output
|
||||
|
||||
def forward(self, index, hidden_states, past_transformed_state):
|
||||
sequence_length = hidden_states.shape[1]
|
||||
|
||||
# minimal amount of new work on single hidden state (previous hidden state are cached)
|
||||
only_last = hidden_states[:, -1, :]
|
||||
projected_only_last = self.in_proj(only_last)
|
||||
transformed_only_last, residual_only_last = torch.chunk(
|
||||
projected_only_last, 2, dim=-1
|
||||
)
|
||||
|
||||
if past_transformed_state is not None:
|
||||
# build a new transformed_states tensor with past_transformed_state and transformed_only_last
|
||||
new_transformed_states = torch.cat(
|
||||
[past_transformed_state, transformed_only_last.unsqueeze(1)], dim=1
|
||||
)
|
||||
transformed_states = new_transformed_states
|
||||
residual_states = residual_only_last
|
||||
else:
|
||||
# prefilling the cache with the last transformed state
|
||||
projected_states = self.in_proj(hidden_states)
|
||||
split_states = torch.chunk(projected_states, 2, dim=-1)
|
||||
transformed_states, residual_states = split_states
|
||||
|
||||
# NOTE: we need the past hidden states to produce the correct output
|
||||
# therefore we cannot simply compute the most recent and append it as we
|
||||
# did for the transformed states
|
||||
A = -torch.exp(self.A_log.float())
|
||||
|
||||
# TODO: avoid the transpose by using a transposed conv1d
|
||||
# apply convolution and narrowing operation
|
||||
conv_output = (
|
||||
self.conv1(transformed_states.transpose(1, 2))
|
||||
.narrow(-1, 0, sequence_length)
|
||||
.transpose(1, 2)
|
||||
# conv1d, ssm, and selective_scan are all fused into one kernel
|
||||
attn_outputs = mamba_inner_fn(
|
||||
projected_states.transpose(1,2),
|
||||
self.conv1d.weight,
|
||||
self.conv1d.bias,
|
||||
self.x_proj_weight,
|
||||
self.dt_proj_weight,
|
||||
self.out_proj_weight,
|
||||
self.out_proj_bias,
|
||||
A,
|
||||
None,
|
||||
None,
|
||||
self.D.float(),
|
||||
delta_bias=self.dt_proj_bias.float(),
|
||||
delta_softplus=True,
|
||||
)
|
||||
|
||||
# apply silu (Swish) activation function
|
||||
activated_transformed = F.silu(conv_output)
|
||||
activated_residual = F.silu(residual_states)
|
||||
|
||||
# Subsequent operations
|
||||
output = self.ssm(activated_transformed)
|
||||
combined_output = output * activated_residual
|
||||
|
||||
return self.out_proj(combined_output), transformed_states
|
||||
return attn_outputs, projected_states
|
||||
|
||||
|
||||
# TODO: prefer a more optimized implementation of RMSNorm if possible
|
||||
|
@ -256,10 +256,10 @@ class Mamba(Model):
|
||||
)
|
||||
|
||||
generations.append(generation)
|
||||
|
||||
next_token_tensor = next_token_id_squeezed.view(1, 1)
|
||||
# Update values
|
||||
batch.input_ids = torch.cat(
|
||||
[batch.input_ids, torch.tensor([[next_token_id_squeezed]])], dim=1
|
||||
[batch.input_ids, next_token_tensor], dim=1
|
||||
)
|
||||
batch.all_input_ids[i] = all_input_ids
|
||||
batch.input_lengths[i] = new_input_length
|
||||
|
Loading…
Reference in New Issue
Block a user