Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-21 06:42:10 +00:00)
feat: use fused kernel in forward pass

commit 966f3ba35c (parent c2681b2bea)

This commit replaces the hand-written conv1d + selective-scan pipeline in MambaBlock.forward with the fused mamba_inner_fn kernel from mamba-ssm, and keeps the next-token update on-device when extending batch.input_ids.
@@ -1,6 +1,7 @@
 import torch
 import torch.distributed

+from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
 from torch import nn
 from typing import Optional, List, Tuple, Any
 from transformers.configuration_utils import PretrainedConfig
@@ -12,12 +13,6 @@ from text_generation_server.utils.layers import (
     TensorParallelEmbedding,
 )

-
-# note torch must be imported before the custom cuda modules
-# since they rely on torch's libc10.so
-import causal_conv1d_cuda
-import selective_scan_cuda
-
 class MambaConfig(PretrainedConfig):
     def __init__(
         self,
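Note on the new import: `mamba_inner_fn` and `selective_scan_fn` come from the `mamba_ssm` package and subsume the hand-written scan that this commit deletes below, along with the direct `causal_conv1d_cuda` / `selective_scan_cuda` extension imports. As a rough orientation, the recurrence that the selective scan computes is the one implemented by the deleted `selective_scan` method; the sketch below restates it in plain PyTorch with illustrative tensor names (it is not this repository's code and makes no claim about the CUDA kernel's exact numerics or layout):

import torch

def reference_selective_scan(x, delta, A, B, C, D):
    # x, delta: (batch, seq, d_inner); A: (d_inner, d_state)
    # B, C: (batch, seq, d_state); D: (d_inner,)
    batch, seq, d_inner = x.shape
    h = torch.zeros(batch, d_inner, A.shape[1], dtype=x.dtype, device=x.device)
    ys = []
    for t in range(seq):
        dt = delta[:, t].unsqueeze(-1)  # (batch, d_inner, 1)
        # h_t = exp(dt * A) * h_{t-1} + dt * B_t * x_t
        h = torch.exp(dt * A) * h + dt * B[:, t].unsqueeze(1) * x[:, t].unsqueeze(-1)
        # y_t = C_t . h_t per channel; the skip term D * x_t is added at the end
        ys.append((h @ C[:, t].unsqueeze(-1)).squeeze(-1))
    return torch.stack(ys, dim=1) + x * D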
@@ -50,155 +45,56 @@ class MambaConfig(PretrainedConfig):
 class MambaBlock(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()

-        # TODO: use model config to set the dt_rank instead of hardcoding it
         self.dt_rank = (config.d_model + 15) // 16
-
-        # TODO: improve how we load the conv1d weights
-        # explore a transposed conv1d that avoids the need for
-        # a transpose during inference
-        self.conv1 = nn.Conv1d(
-            config.d_inner,
-            config.d_inner,
-            kernel_size=config.d_conv,
-            groups=config.d_inner,
-            padding=config.d_conv - 1,
-        )
-        self.conv1.weight = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.weight"))
-        self.conv1.bias = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.bias"))
-
-        self.dt_proj = TensorParallelColumnLinear.load(
-            config=config,
-            prefix=f"{prefix}.dt_proj",
-            weights=weights,
-            bias=True,
-        )
+        self.x_proj_weight = weights.get_tensor(f"{prefix}.x_proj.weight")
+        self.dt_proj_weight = weights.get_tensor(f"{prefix}.dt_proj.weight")
+        self.dt_proj_bias = weights.get_tensor(f"{prefix}.dt_proj.bias")
+        self.out_proj_weight = weights.get_tensor(f"{prefix}.out_proj.weight")
+        self.out_proj_bias = None
+        # TODO: avoid loading the same weights twice
+        self.in_proj_weight = weights.get_tensor(f"{prefix}.in_proj.weight")
+        self.in_proj_bias = None
+
         self.in_proj = TensorParallelColumnLinear.load(
             config=config,
             prefix=f"{prefix}.in_proj",
             weights=weights,
             bias=False,
         )
-        self.x_proj = TensorParallelColumnLinear.load(
-            config=config,
-            prefix=f"{prefix}.x_proj",
-            weights=weights,
-            bias=False,
-        )
-        self.out_proj = TensorParallelColumnLinear.load(
-            config=config,
-            prefix=f"{prefix}.out_proj",
-            weights=weights,
-            bias=False,
-        )
-
-        # TODO: improve how we load the weights
+        self.conv1d = nn.Conv1d(
+            config.d_inner,
+            config.d_inner,
+            kernel_size=config.d_conv,
+            groups=config.d_inner,
+            padding=config.d_conv - 1,
+        )
+        self.conv1d.weight = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.weight"))
+        self.conv1d.bias = nn.Parameter(weights.get_tensor(f"{prefix}.conv1d.bias"))
+
         self.A_log = nn.Parameter(weights.get_tensor(f"{prefix}.A_log"))
         self.D = nn.Parameter(weights.get_tensor(f"{prefix}.D"))

-    def selective_scan(
-        self, input_tensor, delta, a_tensor, b_tensor, c_tensor, d_tensor
-    ):
-        batch_size, sequence_length, input_dim = input_tensor.shape
-        num_cols = a_tensor.shape[1]
-
-        # TODO: revisit this math to avoid the transposes when possible
-        # reshape and process delta
-        delta = delta.transpose(1, 2).view((batch_size, input_dim, sequence_length, 1))
-        exp_delta_a = (delta * a_tensor.view((1, input_dim, 1, num_cols))).exp()
-
-        # calc involving delta, b_tensor, and input_tensor
-        delta_b_input = (
-            delta
-            * b_tensor.view((batch_size, 1, sequence_length, num_cols))
-            * input_tensor.transpose(1, 2).view(
-                (batch_size, input_dim, sequence_length, 1)
-            )
-        )
-
-        # init output tensor
-        output_tensor = torch.zeros(
-            (batch_size, input_dim, num_cols),
-            dtype=exp_delta_a.dtype,
-            device=exp_delta_a.device,
-        )
-
-        # iterate over sequence_length
-        output_sequence = []
-        for i in range(sequence_length):
-            multiplier = exp_delta_a[:, :, i]
-            output_tensor = (multiplier * output_tensor) + delta_b_input[:, :, i]
-            y = output_tensor.matmul(c_tensor[:, i, :].unsqueeze(2)).squeeze(2)
-            output_sequence.append(y)
-
-        stacked_output = torch.stack(output_sequence, 1)
-        return stacked_output + input_tensor * d_tensor
-
-    def ssm(self, hidden_states):
-        _input_dim, num_cols = self.A_log.shape
-        negative_exponential_a = self.A_log.exp().neg()
-        d_matrix = self.D
-        projected_hidden_states = self.x_proj(hidden_states)
-
-        # narrow operations for delta, b, and c
-        delta = projected_hidden_states.narrow(-1, 0, self.dt_rank)
-        b_matrix = projected_hidden_states.narrow(-1, self.dt_rank, num_cols)
-        c_matrix = projected_hidden_states.narrow(-1, self.dt_rank + num_cols, num_cols)
-
-        # process delta
-        delta = self.dt_proj(delta)
-        delta = torch.log(torch.exp(delta) + 1)
-
-        # apply selective scan
-        selective_scan_output = self.selective_scan(
-            hidden_states, delta, negative_exponential_a, b_matrix, c_matrix, d_matrix
-        )
-        return selective_scan_output
-
     def forward(self, index, hidden_states, past_transformed_state):
-        sequence_length = hidden_states.shape[1]
-
-        # minimal amount of new work on single hidden state (previous hidden state are cached)
-        only_last = hidden_states[:, -1, :]
-        projected_only_last = self.in_proj(only_last)
-        transformed_only_last, residual_only_last = torch.chunk(
-            projected_only_last, 2, dim=-1
-        )
-
-        if past_transformed_state is not None:
-            # build a new transformed_states tensor with past_transformed_state and transformed_only_last
-            new_transformed_states = torch.cat(
-                [past_transformed_state, transformed_only_last.unsqueeze(1)], dim=1
-            )
-            transformed_states = new_transformed_states
-            residual_states = residual_only_last
-        else:
-            # prefilling the cache with the last transformed state
-            projected_states = self.in_proj(hidden_states)
-            split_states = torch.chunk(projected_states, 2, dim=-1)
-            transformed_states, residual_states = split_states
-
-        # NOTE: we need the past hidden states to produce the correct output
-        # therefore we cannot simply compute the most recent and append it as we
-        # did for the transformed states
-
-        # TODO: avoid the transpose by using a transposed conv1d
-        # apply convolution and narrowing operation
-        conv_output = (
-            self.conv1(transformed_states.transpose(1, 2))
-            .narrow(-1, 0, sequence_length)
-            .transpose(1, 2)
-        )
-
-        # apply silu (Swish) activation function
-        activated_transformed = F.silu(conv_output)
-        activated_residual = F.silu(residual_states)
-
-        # Subsequent operations
-        output = self.ssm(activated_transformed)
-        combined_output = output * activated_residual
-
-        return self.out_proj(combined_output), transformed_states
+        projected_states = self.in_proj(hidden_states)
+        A = -torch.exp(self.A_log.float())
+        # conv1d, ssm, and selective_scan are all fused into one kernel
+        attn_outputs = mamba_inner_fn(
+            projected_states.transpose(1,2),
+            self.conv1d.weight,
+            self.conv1d.bias,
+            self.x_proj_weight,
+            self.dt_proj_weight,
+            self.out_proj_weight,
+            self.out_proj_bias,
+            A,
+            None,
+            None,
+            self.D.float(),
+            delta_bias=self.dt_proj_bias.float(),
+            delta_softplus=True,
+        )

+        return attn_outputs, projected_states


 # TODO: prefer a more optimized implementation of RMSNorm if possible
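For reference, a standalone sketch of invoking the fused kernel with the same argument order as the new forward above. The weight shapes follow the reference Mamba layout (d_inner = 2 * d_model, depthwise conv1d, A of shape (d_inner, d_state)); the dimensions, dtype, and random weights are assumptions for illustration rather than values taken from this repository, and running it requires a CUDA GPU with the mamba-ssm and causal-conv1d extensions installed:

import torch
from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn

# assumed dimensions, for illustration only
d_model, d_state, d_conv = 768, 16, 4
d_inner = 2 * d_model
dt_rank = (d_model + 15) // 16
batch, seqlen = 1, 32
dev = "cuda"

# xz plays the role of projected_states: the output of in_proj, shape (batch, seqlen, 2 * d_inner)
xz = torch.randn(batch, seqlen, 2 * d_inner, device=dev)
conv1d_weight = torch.randn(d_inner, 1, d_conv, device=dev)   # depthwise conv weight
conv1d_bias = torch.randn(d_inner, device=dev)
x_proj_weight = torch.randn(dt_rank + 2 * d_state, d_inner, device=dev)
dt_proj_weight = torch.randn(d_inner, dt_rank, device=dev)
dt_proj_bias = torch.randn(d_inner, device=dev)
out_proj_weight = torch.randn(d_model, d_inner, device=dev)
A = -torch.exp(torch.randn(d_inner, d_state, device=dev))     # A must stay negative
D = torch.ones(d_inner, device=dev)

out = mamba_inner_fn(
    xz.transpose(1, 2),   # (batch, 2 * d_inner, seqlen), as in the diff above
    conv1d_weight,
    conv1d_bias,
    x_proj_weight,
    dt_proj_weight,
    out_proj_weight,
    None,                 # out_proj_bias
    A,
    None,                 # B is derived from the x_proj output inside the kernel
    None,                 # C is derived from the x_proj output inside the kernel
    D,
    delta_bias=dt_proj_bias,
    delta_softplus=True,
)
print(out.shape)          # expected (batch, seqlen, d_model)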
@@ -256,10 +256,10 @@ class Mamba(Model):
             )

             generations.append(generation)
-
+            next_token_tensor = next_token_id_squeezed.view(1, 1)
             # Update values
             batch.input_ids = torch.cat(
-                [batch.input_ids, torch.tensor([[next_token_id_squeezed]])], dim=1
+                [batch.input_ids, next_token_tensor], dim=1
             )
             batch.all_input_ids[i] = all_input_ids
             batch.input_lengths[i] = new_input_length
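The batching change is smaller but in the same spirit: next_token_id_squeezed.view(1, 1) reshapes the existing token-id tensor on whatever device it already lives, whereas torch.tensor([[next_token_id_squeezed]]) typically materializes a brand-new tensor on the CPU from the scalar, which costs a copy and, for a GPU-resident id, a host synchronization. A minimal sketch of the new pattern (the dummy ids below are illustrative stand-ins, not the server's actual batch state):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
input_ids = torch.tensor([[1, 2, 3]], device=device)       # stand-in for batch.input_ids
next_token_id_squeezed = torch.tensor(4, device=device)    # stand-in for the sampled token id

# reshape on the same device, then append along the sequence dimension
next_token_tensor = next_token_id_squeezed.view(1, 1)
input_ids = torch.cat([input_ids, next_token_tensor], dim=1)
print(input_ids)  # tensor([[1, 2, 3, 4]])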