diff --git a/server/text_generation_server/models/custom_modeling/mamba_modeling.py b/server/text_generation_server/models/custom_modeling/mamba_modeling.py index 0c019212..410a254a 100644 --- a/server/text_generation_server/models/custom_modeling/mamba_modeling.py +++ b/server/text_generation_server/models/custom_modeling/mamba_modeling.py @@ -115,16 +115,61 @@ class MambaBlock(nn.Module): return attn_outputs, conv_state, last_state def step(self, hidden_states, conv_state, ssm_state): - # only support decoding with 1 token at a time - xz = self.in_proj(hidden_states.view((1, -1))) - x, z = xz.chunk(2, dim=-1) # (B D) - x = causal_conv1d_update( - x, - conv_state, - self.conv1d.weight.view(self.conv1d.weight.size(0), -1), - self.conv1d.bias, - self.activation, - ) + _xz = self.in_proj(hidden_states) + _x, _z = _xz.chunk(2, dim=-1) # (B D) + conv_state_new = torch.cat([conv_state, _x.transpose(1,2)], dim=-1) + conv_out = causal_conv1d_fn( x=conv_state_new, weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), bias=self.conv1d.bias, activation=self.activation) + conv_state = conv_state_new[:, :, 1:] + + handle_batched = False + if handle_batched: + bsz, seqlen, dim = hidden_states.shape + # empty output tensor for the loop + output_tensor = torch.zeros( + (bsz, seqlen, dim), + device=hidden_states.device, + dtype=hidden_states.dtype + ) + + for i in range(0, bsz): + x = conv_out[:,:,i] + z = _z[:, i, :] + x_db = self.x_proj(x) + dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) + dt = self.dt_proj_no_bias(dt) + dt = F.softplus(dt + self.dt_proj.bias).view((dt.size(1), -1)) + dA = torch.exp(dt * self.negA) + dB = dt * B.view(-1, B.size(-1)) + x_shape = (-1, x.size(-1), 1) + # ssm_state = (ssm_state * dA + dB * x.view(x_shape)) + ssm_state[i] = (ssm_state[i] * dA + dB * x)#.view(x_shape)) + c_shape = (C.size(0), C.size(1), -1) + out_mm_shape = (C.size(0), -1) + out = torch.matmul(ssm_state.to(C.dtype), C.view(c_shape)).view(out_mm_shape) + # in-place ops + out.add_((x * self.D).to(out.dtype)) + out.mul_(F.silu(z)) + out = self.out_proj(out) + output_tensor[i] = out + + return output_tensor, conv_state, ssm_state + + # TODO: remove this code only left for reference + # # only support decoding with 1 token at a time + # xz = self.in_proj(hidden_states.view((1, -1))) + # x, z = xz.chunk(2, dim=-1) # (B D) + # x = causal_conv1d_update( + # x, + # conv_state, + # self.conv1d.weight.view(self.conv1d.weight.size(0), -1), + # self.conv1d.bias, + # self.activation, + # ) + + # TODO: prefer using batched logic in all cases + # this just pulls the last element of the batch + x = conv_out[:,:,-1] + z = _z[:, -1, :] x_db = self.x_proj(x) dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) dt = self.dt_proj_no_bias(dt)