feat: prefer triton ops and batch conv

drbh 2024-02-06 20:38:28 +00:00
parent 8319e854c8
commit 5e102183d8
2 changed files with 15 additions and 24 deletions


@@ -1,6 +1,7 @@
 import torch
 import torch.distributed
+from mamba_ssm.ops.triton.selective_state_update import selective_state_update
 from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
 from mamba_ssm.utils.generation import InferenceParams
 from torch import nn
@@ -14,7 +15,7 @@ from text_generation_server.utils.layers import (
     FastLinear,
 )
-from einops import rearrange, repeat
+from einops import rearrange
 from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
 import math
@@ -118,35 +119,29 @@ class MambaBlock(nn.Module):
         _xz = self.in_proj(hidden_states)
         _x, _z = _xz.chunk(2, dim=-1)  # (B D)
         conv_state_new = torch.cat([conv_state, _x.transpose(1,2)], dim=-1)
-        conv_out = causal_conv1d_fn(x=conv_state_new, weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), bias=self.conv1d.bias, activation=self.activation)
+        conv_out = causal_conv1d_fn(
+            x=conv_state_new,
+            weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)),
+            bias=self.conv1d.bias,
+            activation=self.activation
+        )
         conv_state = conv_state_new[:, :, 1:]
         bsz, seqlen, dim = hidden_states.shape
+        # empty output tensor for the loop
         output_tensor = torch.zeros(
             (bsz, seqlen, dim),
             device=hidden_states.device,
             dtype=hidden_states.dtype
         )
         for i in range(0, bsz):
             x = conv_out[i:i+1,:,-1]
             z = _z[i:i+1, -1, :]
             x_db = self.x_proj(x)
             dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
             dt = self.dt_proj_no_bias(dt)
-            dt = F.softplus(dt + self.dt_proj.bias).view((dt.size(1), -1))
-            dA = torch.exp(dt * self.negA)
-            dB = dt * B.view(-1, B.size(-1))
-            x_shape = (-1, x.size(-1), 1)
-            ssm_state[i] = (ssm_state[i] * dA + dB * x.view(x_shape))
-            c_shape = (C.size(0), C.size(1), -1)
-            out_mm_shape = (C.size(0), -1)
-            out = torch.matmul(ssm_state[i].to(C.dtype), C.view(c_shape)).view(out_mm_shape)
-            # in-place ops
-            out.add_((x * self.D).to(out.dtype))
-            out.mul_(F.silu(z))
-            out = self.out_proj(out)
+            y = selective_state_update(
+                ssm_state[i:i+1,:,:], x, dt, self.negA, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
+            )
+            out = self.out_proj(y)
             output_tensor[i] = out
         return output_tensor, conv_state, ssm_state
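
For reference, the deleted loop body was an eager-PyTorch spelling of exactly the single-token recurrence that selective_state_update now runs as one fused Triton kernel. A minimal sketch of that arithmetic, reconstructed from the removed lines above (the name selective_state_update_ref and the exact shape layout are illustrative, not taken from mamba_ssm):

import torch
import torch.nn.functional as F

def selective_state_update_ref(state, x, dt, A, B, C, D, z=None,
                               dt_bias=None, dt_softplus=False):
    # state: (bsz, dim, d_state), updated in place; x, dt, z: (bsz, dim)
    # A: (dim, d_state); B, C: (bsz, d_state); D: (dim,)
    if dt_bias is not None:
        dt = dt + dt_bias
    if dt_softplus:
        dt = F.softplus(dt)                          # matches dt_softplus=True above
    dA = torch.exp(dt.unsqueeze(-1) * A)             # discretized decay, cf. the removed dA
    dB = dt.unsqueeze(-1) * B.unsqueeze(1)           # discretized input, cf. the removed dB
    state.copy_(state * dA + dB * x.unsqueeze(-1))   # recurrence: h' = h * dA + dB * x
    out = torch.einsum("bdn,bn->bd", state.to(C.dtype), C) + D * x
    if z is not None:
        out = out * F.silu(z)                        # gating, cf. the removed out.mul_(F.silu(z))
    return out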


@@ -344,12 +344,8 @@ class MambaBatch(Batch):
         for i in range(n_blocks):
             conv_state, ssm_state = batch.inference_params.key_value_memory_dict[i]
             batch_size = batch.inference_params.max_batch_size
-            try:
-                inference_params.key_value_memory_dict[i][0][current_batch:current_batch + batch_size] = conv_state
-                inference_params.key_value_memory_dict[i][1][current_batch:current_batch + batch_size] = ssm_state
-            except Exception:
-                import ipdb;ipdb.set_trace()
-                pass
+            inference_params.key_value_memory_dict[i][0][current_batch:current_batch + batch_size] = conv_state
+            inference_params.key_value_memory_dict[i][1][current_batch:current_batch + batch_size] = ssm_state
             inference_params.lengths_per_sample[current_batch: current_batch + batch_size] = batch.inference_params.lengths_per_sample
             current_batch += batch_size
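
Stripped of the inference-params bookkeeping, the copy in this hunk is plain slice assignment into preallocated state buffers at a running batch offset. A small self-contained sketch of the same pattern (tensor names here are made up for illustration):

import torch

dim, d_state = 4, 3
src_states = [torch.randn(2, dim, d_state), torch.randn(3, dim, d_state)]  # states from two batches
merged = torch.empty(sum(s.size(0) for s in src_states), dim, d_state)     # preallocated destination

offset = 0
for state in src_states:
    bsz = state.size(0)
    merged[offset:offset + bsz] = state  # same pattern as key_value_memory_dict[i][0][...] = conv_state
    offset += bsz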