mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
feat: prefer triton ops and batch conv
commit 5e102183d8
parent 8319e854c8
@@ -1,6 +1,7 @@
 import torch
 import torch.distributed
 
+from mamba_ssm.ops.triton.selective_state_update import selective_state_update
 from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
 from mamba_ssm.utils.generation import InferenceParams
 from torch import nn
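
The new import is mamba_ssm's fused Triton kernel for single-token decoding: it updates the recurrent state in place and returns the gated output in one launch. A minimal shape-check sketch, assuming the signature documented in mamba_ssm (the sizes below are illustrative, not taken from this repo):

import torch
from mamba_ssm.ops.triton.selective_state_update import selective_state_update

bsz, d_inner, d_state = 1, 1536, 16                        # illustrative sizes
state = torch.zeros(bsz, d_inner, d_state, device="cuda")  # updated in place
x = torch.randn(bsz, d_inner, device="cuda")               # conv output for one token
dt = torch.randn(bsz, d_inner, device="cuda")              # pre-softplus step size
A = -torch.rand(d_inner, d_state, device="cuda")           # negative, like self.negA below
B = torch.randn(bsz, d_state, device="cuda")
C = torch.randn(bsz, d_state, device="cuda")
D = torch.randn(d_inner, device="cuda")                    # skip-connection weight
z = torch.randn(bsz, d_inner, device="cuda")               # gate branch
dt_bias = torch.randn(d_inner, device="cuda")

y = selective_state_update(state, x, dt, A, B, C, D, z=z, dt_bias=dt_bias, dt_softplus=True)
assert y.shape == (bsz, d_inner)                           # Triton kernel requires CUDA tensors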
@@ -14,7 +15,7 @@ from text_generation_server.utils.layers import (
     FastLinear,
 )
 
-from einops import rearrange, repeat
+from einops import rearrange
 from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
 import math
 
@@ -118,35 +119,29 @@ class MambaBlock(nn.Module):
         _xz = self.in_proj(hidden_states)
         _x, _z = _xz.chunk(2, dim=-1)  # (B D)
         conv_state_new = torch.cat([conv_state, _x.transpose(1,2)], dim=-1)
-        conv_out = causal_conv1d_fn( x=conv_state_new, weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), bias=self.conv1d.bias, activation=self.activation)
+        conv_out = causal_conv1d_fn(
+            x=conv_state_new,
+            weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)),
+            bias=self.conv1d.bias,
+            activation=self.activation
+        )
         conv_state = conv_state_new[:, :, 1:]
 
         bsz, seqlen, dim = hidden_states.shape
-        # empty output tensor for the loop
         output_tensor = torch.zeros(
             (bsz, seqlen, dim),
             device=hidden_states.device,
             dtype=hidden_states.dtype
         )
 
         for i in range(0, bsz):
             x = conv_out[i:i+1,:,-1]
             z = _z[i:i+1, -1, :]
             x_db = self.x_proj(x)
             dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
             dt = self.dt_proj_no_bias(dt)
-            dt = F.softplus(dt + self.dt_proj.bias).view((dt.size(1), -1))
-            dA = torch.exp(dt * self.negA)
-            dB = dt * B.view(-1, B.size(-1))
-            x_shape = (-1, x.size(-1), 1)
-            ssm_state[i] = (ssm_state[i] * dA + dB * x.view(x_shape))
-            c_shape = (C.size(0), C.size(1), -1)
-            out_mm_shape = (C.size(0), -1)
-            out = torch.matmul(ssm_state[i].to(C.dtype), C.view(c_shape)).view(out_mm_shape)
-            # in-place ops
-            out.add_((x * self.D).to(out.dtype))
-            out.mul_(F.silu(z))
-            out = self.out_proj(out)
+            y = selective_state_update(
+                ssm_state[i:i+1,:,:], x, dt, self.negA, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
+            )
+            out = self.out_proj(y)
             output_tensor[i] = out
 
         return output_tensor, conv_state, ssm_state
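
For reference, the fused Triton call performs, in one kernel launch, the same single-step state-space update that the removed eager-mode lines spelled out. A minimal sketch of that math, assuming the shape conventions documented for mamba_ssm's selective_state_update (the _ref helper itself is hypothetical, not part of either codebase):

import torch
import torch.nn.functional as F

def selective_state_update_ref(state, x, dt, A, B, C, D, z, dt_bias, dt_softplus=True):
    # Shapes (assumed): state (b, d, n); x, dt, z (b, d); A (d, n); B, C (b, n); D, dt_bias (d,)
    if dt_softplus:
        dt = F.softplus(dt + dt_bias)                 # discretization step size
    dA = torch.exp(dt.unsqueeze(-1) * A)              # state transition, exp(dt * A)
    dB = dt.unsqueeze(-1) * B.unsqueeze(1)            # input matrix, dt * B
    state.copy_(state * dA + dB * x.unsqueeze(-1))    # recurrent update, in place
    y = (state.to(C.dtype) * C.unsqueeze(1)).sum(-1)  # read the state out through C
    y = y + x * D                                     # skip connection
    return y * F.silu(z)                              # SiLU gating, as in the removed code

Calling the kernel with ssm_state[i:i+1,:,:] keeps the in-place update writing into the right slice of the batched state tensor, which is why the per-sample loop can stay unchanged.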
@@ -344,12 +344,8 @@ class MambaBatch(Batch):
         for i in range(n_blocks):
             conv_state, ssm_state = batch.inference_params.key_value_memory_dict[i]
             batch_size = batch.inference_params.max_batch_size
-            try:
-                inference_params.key_value_memory_dict[i][0][current_batch:current_batch + batch_size] = conv_state
-                inference_params.key_value_memory_dict[i][1][current_batch:current_batch + batch_size] = ssm_state
-            except Exception:
-                import ipdb;ipdb.set_trace()
-                pass
+            inference_params.key_value_memory_dict[i][0][current_batch:current_batch + batch_size] = conv_state
+            inference_params.key_value_memory_dict[i][1][current_batch:current_batch + batch_size] = ssm_state
             inference_params.lengths_per_sample[current_batch: current_batch + batch_size] = batch.inference_params.lengths_per_sample
             current_batch += batch_size
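
These slice assignments only land if the destination inference_params was preallocated for the combined batch. A hedged sketch of that precondition, with illustrative sizes (the state shapes follow mamba_ssm's conventions, and the constructor fields are assumptions about the installed mamba_ssm version):

import torch
from mamba_ssm.utils.generation import InferenceParams

n_blocks, d_inner, d_conv, d_state = 2, 1536, 4, 16   # illustrative model sizes
total_batch = 5                                        # sum of the source batch sizes

inference_params = InferenceParams(max_seqlen=1024, max_batch_size=total_batch)
for i in range(n_blocks):
    # one (conv_state, ssm_state) pair per Mamba block, sized for the merged batch
    inference_params.key_value_memory_dict[i] = (
        torch.zeros(total_batch, d_inner, d_conv),
        torch.zeros(total_batch, d_inner, d_state),
    )
inference_params.lengths_per_sample = torch.zeros(total_batch, dtype=torch.int32)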
|