diff --git a/server/text_generation_server/models/custom_modeling/mamba_modeling.py b/server/text_generation_server/models/custom_modeling/mamba_modeling.py
index 3180d686..0c019212 100644
--- a/server/text_generation_server/models/custom_modeling/mamba_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/mamba_modeling.py
@@ -1,8 +1,7 @@
 import torch
 import torch.distributed
 
-from mamba_ssm.ops.triton.selective_state_update import selective_state_update
-from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
+from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
 from mamba_ssm.utils.generation import InferenceParams
 from torch import nn
 from typing import Optional, Tuple, Any
@@ -60,6 +59,7 @@ class MambaBlock(nn.Module):
         self.in_proj = FastLinear.load(config, f"{prefix}.in_proj", weights, bias=False)
         self.x_proj = FastLinear.load(config, f"{prefix}.x_proj", weights, bias=False)
         self.dt_proj = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=True)
+        self.dt_proj_no_bias = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=False)
         self.out_proj = FastLinear.load(config, f"{prefix}.out_proj", weights, bias=False)
         self.conv1d = FastLinear.load(config, f"{prefix}.conv1d", weights, bias=True)
         self.negA = -torch.exp(weights.get_tensor(f"{prefix}.A_log").float())
@@ -116,25 +116,32 @@ class MambaBlock(nn.Module):
 
     def step(self, hidden_states, conv_state, ssm_state):
         # only support decoding with 1 token at a time
-        xz = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
+        xz = self.in_proj(hidden_states.view((1, -1)))
         x, z = xz.chunk(2, dim=-1)  # (B D)
         x = causal_conv1d_update(
             x,
             conv_state,
-            self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)),
+            self.conv1d.weight.view(self.conv1d.weight.size(0), -1),
             self.conv1d.bias,
             self.activation,
         )
-        x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
+        x_db = self.x_proj(x)
         dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
-        # Don't add dt_bias here
-        dt = F.linear(dt, self.dt_proj.weight)  # (B d_inner)
-        y = selective_state_update(
-            ssm_state, x, dt, self.negA, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
-        )
-        out = self.out_proj(y)
-        # conv and ssm are updated in place but we return them to make the control flow more explicit
-        return out.unsqueeze(1), conv_state, ssm_state
+        dt = self.dt_proj_no_bias(dt)
+        dt = F.softplus(dt + self.dt_proj.bias).view((dt.size(1), -1))
+        dA = torch.exp(dt * self.negA)
+        dB = dt * B.view(-1, B.size(-1))
+        x_shape = (-1, x.size(-1), 1)
+        ssm_state = (ssm_state * dA + dB * x.view(x_shape))
+        c_shape = (C.size(0), C.size(1), -1)
+        out_mm_shape = (C.size(0), -1)
+        out = torch.matmul(ssm_state.to(C.dtype), C.view(c_shape)).view(out_mm_shape)
+        # in-place ops
+        out.add_((x * self.D).to(out.dtype))
+        out.mul_(F.silu(z))
+        out = self.out_proj(out)
+        return out.view((1, -1, out.size(-1))), conv_state, ssm_state
+
 
 class ResidualBlock(nn.Module):
     def __init__(self, layer_id, config, weights):
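
Note on the new `step` body: it re-implements, in eager PyTorch, the per-token selective state-space update that was previously delegated to the removed `mamba_ssm.ops.triton.selective_state_update` kernel, so single-token decode no longer depends on that Triton op. A minimal sketch of the underlying recurrence follows; the `ssm_step` helper and its unbatched shapes are assumptions for illustration, not part of the patch.

```python
import torch
import torch.nn.functional as F


def ssm_step(ssm_state, x, dt, negA, B, C, D, z):
    """Single-token SSM update for one sequence (shapes are illustrative assumptions).

    ssm_state: (d_inner, d_state)  recurrent state, updated and returned
    x, dt, z, D: (d_inner,)        conv output, raw step size, gate, skip weights
    negA: (d_inner, d_state)       -exp(A_log), as stored by the module
    B, C: (d_state,)               input/output projections for this token
    """
    dt = F.softplus(dt)                           # keep the discretization step positive
    dA = torch.exp(dt[:, None] * negA)            # discretized state transition
    dB = dt[:, None] * B[None, :]                 # discretized input matrix
    ssm_state = ssm_state * dA + dB * x[:, None]  # state recurrence
    y = ssm_state @ C + D * x                     # readout plus skip connection
    return y * F.silu(z), ssm_state               # SiLU gating, as in the patched step()
```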