diff --git a/server/text_generation_server/models/custom_modeling/mamba_modeling.py b/server/text_generation_server/models/custom_modeling/mamba_modeling.py
index 567927f6..c515a5ca 100644
--- a/server/text_generation_server/models/custom_modeling/mamba_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/mamba_modeling.py
@@ -1,6 +1,7 @@
 import torch
 import torch.distributed
 
+from mamba_ssm.ops.triton.selective_state_update import selective_state_update
 from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
 from mamba_ssm.utils.generation import InferenceParams
 from torch import nn
@@ -14,7 +15,7 @@ from text_generation_server.utils.layers import (
     FastLinear,
 )
 
-from einops import rearrange, repeat
+from einops import rearrange
 from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
 import math
 
@@ -118,35 +119,29 @@ class MambaBlock(nn.Module):
         _xz = self.in_proj(hidden_states)
         _x, _z = _xz.chunk(2, dim=-1)  # (B D)
         conv_state_new = torch.cat([conv_state, _x.transpose(1, 2)], dim=-1)
-        conv_out = causal_conv1d_fn( x=conv_state_new, weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), bias=self.conv1d.bias, activation=self.activation)
+        conv_out = causal_conv1d_fn(
+            x=conv_state_new,
+            weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)),
+            bias=self.conv1d.bias,
+            activation=self.activation
+        )
         conv_state = conv_state_new[:, :, 1:]
-
         bsz, seqlen, dim = hidden_states.shape
-
         # empty output tensor for the loop
         output_tensor = torch.zeros(
             (bsz, seqlen, dim),
             device=hidden_states.device,
             dtype=hidden_states.dtype
         )
-
         for i in range(0, bsz):
             x = conv_out[i:i+1, :, -1]
             z = _z[i:i+1, -1, :]
             x_db = self.x_proj(x)
             dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
-            dt = self.dt_proj_no_bias(dt)
-            dt = F.softplus(dt + self.dt_proj.bias).view((dt.size(1), -1))
-            dA = torch.exp(dt * self.negA)
-            dB = dt * B.view(-1, B.size(-1))
-            x_shape = (-1, x.size(-1), 1)
-            ssm_state[i] = (ssm_state[i] * dA + dB * x.view(x_shape))
-            c_shape = (C.size(0), C.size(1), -1)
-            out_mm_shape = (C.size(0), -1)
-            out = torch.matmul(ssm_state[i].to(C.dtype), C.view(c_shape)).view(out_mm_shape)
-            # in-place ops
-            out.add_((x * self.D).to(out.dtype))
-            out.mul_(F.silu(z))
-            out = self.out_proj(out)
+            dt = self.dt_proj_no_bias(dt)
+            y = selective_state_update(
+                ssm_state[i:i+1, :, :], x, dt, self.negA, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
+            )
+            out = self.out_proj(y)
             output_tensor[i] = out
         return output_tensor, conv_state, ssm_state
diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py
index 4750d90a..af09f70f 100644
--- a/server/text_generation_server/models/mamba.py
+++ b/server/text_generation_server/models/mamba.py
@@ -344,12 +344,8 @@ class MambaBatch(Batch):
         for i in range(n_blocks):
             conv_state, ssm_state = batch.inference_params.key_value_memory_dict[i]
             batch_size = batch.inference_params.max_batch_size
-            try:
-                inference_params.key_value_memory_dict[i][0][current_batch:current_batch + batch_size] = conv_state
-                inference_params.key_value_memory_dict[i][1][current_batch:current_batch + batch_size] = ssm_state
-            except Exception:
-                import ipdb; ipdb.set_trace()
-                pass
+            inference_params.key_value_memory_dict[i][0][current_batch:current_batch + batch_size] = conv_state
+            inference_params.key_value_memory_dict[i][1][current_batch:current_batch + batch_size] = ssm_state
 
         inference_params.lengths_per_sample[current_batch: current_batch + batch_size] = batch.inference_params.lengths_per_sample
         current_batch += batch_size
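
Note for reviewers: the fused `selective_state_update` Triton kernel from `mamba_ssm` replaces the hand-rolled per-token SSM update deleted above. Passing `dt_bias=self.dt_proj.bias` with `dt_softplus=True` folds the old `softplus(dt + self.dt_proj.bias)` step into the kernel, which is why the Python side now only applies the bias-free `dt_proj_no_bias` projection. A minimal non-fused sketch of the semantics the kernel is assumed to implement (the `_ref` name and shape comments are illustrative, mirroring the deleted loop body rather than the kernel's actual internals):

    import torch
    import torch.nn.functional as F

    def selective_state_update_ref(state, x, dt, A, B, C, D, z=None,
                                   dt_bias=None, dt_softplus=False):
        # state: (batch, dim, d_state) -- running SSM state, updated in place
        # x, z, dt: (batch, dim); A: (dim, d_state); B, C: (batch, d_state); D: (dim,)
        if dt_bias is not None:
            dt = dt + dt_bias
        if dt_softplus:
            dt = F.softplus(dt)
        dt = dt.unsqueeze(-1)                           # (batch, dim, 1)
        dA = torch.exp(dt * A)                          # discretized state transition
        dB = dt * B.unsqueeze(1)                        # discretized input projection
        state.copy_(state * dA + dB * x.unsqueeze(-1))  # h[t] = h[t-1] * dA + dB * x
        y = torch.einsum("bdn,bn->bd", state.to(C.dtype), C)
        y = y + D * x                                   # skip connection
        if z is not None:
            y = y * F.silu(z)                           # output gating
        return y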
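
The mamba.py hunk drops the `ipdb` debugging scaffold and the blanket `try/except`: batch concatenation is plain slice assignment of each incoming batch's conv/ssm states into the preallocated merged buffers. A toy sketch of the pattern (all sizes hypothetical):

    import torch

    # key_value_memory_dict maps block index -> (conv_state, ssm_state)
    n_blocks, d_inner, d_conv, d_state, total = 2, 8, 4, 16, 5
    merged = {
        i: (torch.zeros(total, d_inner, d_conv), torch.zeros(total, d_inner, d_state))
        for i in range(n_blocks)
    }

    current_batch = 0
    for batch_size in (2, 3):  # two incoming batches being concatenated
        for i in range(n_blocks):
            # stand-ins for the incoming batch's per-block states
            conv_state = torch.randn(batch_size, d_inner, d_conv)
            ssm_state = torch.randn(batch_size, d_inner, d_state)
            merged[i][0][current_batch:current_batch + batch_size] = conv_state
            merged[i][1][current_batch:current_batch + batch_size] = ssm_state
        current_batch += batch_size  # advance once per batch, not once per block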