Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)
commit 63bc4c59d4 (parent a4f1916a56)

fix: improve step to use batch

MambaBlock.step now projects the whole batch at once, maintains the
convolution state by appending the new token's projection and re-running the
causal conv over the window, and adds a (currently disabled) handle_batched
path that loops over the batch to update each sequence's ssm_state. Per its
own TODO, the active path still reads only the last element of the batch.
@@ -115,16 +115,61 @@ class MambaBlock(nn.Module):
         return attn_outputs, conv_state, last_state

     def step(self, hidden_states, conv_state, ssm_state):
-        # only support decoding with 1 token at a time
-        xz = self.in_proj(hidden_states.view((1, -1)))
-        x, z = xz.chunk(2, dim=-1)  # (B D)
-        x = causal_conv1d_update(
-            x,
-            conv_state,
-            self.conv1d.weight.view(self.conv1d.weight.size(0), -1),
-            self.conv1d.bias,
-            self.activation,
-        )
+        _xz = self.in_proj(hidden_states)
+        _x, _z = _xz.chunk(2, dim=-1)  # (B D)
+        conv_state_new = torch.cat([conv_state, _x.transpose(1, 2)], dim=-1)
+        conv_out = causal_conv1d_fn(x=conv_state_new, weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), bias=self.conv1d.bias, activation=self.activation)
+        conv_state = conv_state_new[:, :, 1:]
+
+        handle_batched = False
+        if handle_batched:
+            bsz, seqlen, dim = hidden_states.shape
+            # empty output tensor for the loop
+            output_tensor = torch.zeros(
+                (bsz, seqlen, dim),
+                device=hidden_states.device,
+                dtype=hidden_states.dtype
+            )
+
+            for i in range(0, bsz):
+                x = conv_out[:, :, i]
+                z = _z[:, i, :]
+                x_db = self.x_proj(x)
+                dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
+                dt = self.dt_proj_no_bias(dt)
+                dt = F.softplus(dt + self.dt_proj.bias).view((dt.size(1), -1))
+                dA = torch.exp(dt * self.negA)
+                dB = dt * B.view(-1, B.size(-1))
+                x_shape = (-1, x.size(-1), 1)
+                # ssm_state = (ssm_state * dA + dB * x.view(x_shape))
+                ssm_state[i] = (ssm_state[i] * dA + dB * x)  # .view(x_shape))
+                c_shape = (C.size(0), C.size(1), -1)
+                out_mm_shape = (C.size(0), -1)
+                out = torch.matmul(ssm_state.to(C.dtype), C.view(c_shape)).view(out_mm_shape)
+                # in-place ops
+                out.add_((x * self.D).to(out.dtype))
+                out.mul_(F.silu(z))
+                out = self.out_proj(out)
+                output_tensor[i] = out
+
+            return output_tensor, conv_state, ssm_state
+
+        # TODO: remove this code, only left for reference
+        # # only support decoding with 1 token at a time
+        # xz = self.in_proj(hidden_states.view((1, -1)))
+        # x, z = xz.chunk(2, dim=-1)  # (B D)
+        # x = causal_conv1d_update(
+        #     x,
+        #     conv_state,
+        #     self.conv1d.weight.view(self.conv1d.weight.size(0), -1),
+        #     self.conv1d.bias,
+        #     self.activation,
+        # )
+
+        # TODO: prefer using batched logic in all cases;
+        # this just pulls the last element of the batch
+        x = conv_out[:, :, -1]
+        z = _z[:, -1, :]
         x_db = self.x_proj(x)
         dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
         dt = self.dt_proj_no_bias(dt)
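The conv-state handling above swaps causal_conv1d_update for a concat-and-recompute pattern: append the newest token's projection to the stored window, run the causal convolution over the whole window, then drop the oldest column. A minimal sketch of that pattern in plain PyTorch, with a per-channel dot product standing in for causal_conv1d_fn (the function name and shapes below are illustrative, not from the repo):

import torch
import torch.nn.functional as F

def roll_conv_state(conv_state, x_new, weight, bias):
    # conv_state: (batch, dim, width - 1), previous token projections per channel
    # x_new:      (batch, dim, 1), projection of the incoming token
    # weight:     (dim, width), depthwise causal-conv kernel
    window = torch.cat([conv_state, x_new], dim=-1)           # (batch, dim, width)
    # the causal-conv output at the newest position is the per-channel
    # dot product of the kernel with the full window
    out = torch.einsum("bdw,dw->bd", window, weight) + bias   # (batch, dim)
    out = F.silu(out)                                         # matches the "silu" activation
    return out, window[:, :, 1:]                              # new state drops the oldest column

# hypothetical shapes: batch=2, dim=16, kernel width=4
state = torch.zeros(2, 16, 3)
w, b = torch.randn(16, 4), torch.randn(16)
y, state = roll_conv_state(state, torch.randn(2, 16, 1), w, b)

Recomputing from the stored window costs a little extra arithmetic per step, but it lets a single fused conv call cover the whole batch, which is presumably the point of the change.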
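The body of the handle_batched loop is the standard selective-SSM decode recurrence: discretize with the softplus-positive step size dt, update the hidden state, read out through C, add the D skip, and gate with silu(z). A self-contained sketch of one such step for a single sequence (names and shapes here are assumptions for illustration, not the repo's):

import torch
import torch.nn.functional as F

def ssm_decode_step(ssm_state, x, dt, A, B, C, D, z):
    # ssm_state: (d_inner, d_state), hidden state carried across tokens
    # x, dt, D, z: (d_inner,) -- conv output, raw step size, skip weight, gate
    # A: (d_inner, d_state), negative (like self.negA) so exp(dt * A) decays
    # B, C: (d_state,), input-dependent projections for this token
    dt = F.softplus(dt)                             # keep step sizes positive
    dA = torch.exp(dt[:, None] * A)                 # discretized state transition
    dB = dt[:, None] * B[None, :]                   # discretized input matrix
    ssm_state = ssm_state * dA + dB * x[:, None]    # recurrence: h = dA * h + dB * x
    y = ssm_state @ C                               # read-out through C
    y = y + D * x                                   # skip connection
    y = y * F.silu(z)                               # gated output
    return y, ssm_state

# hypothetical sizes: d_inner=32, d_state=16
h = torch.zeros(32, 16)
A = -torch.rand(32, 16)
y, h = ssm_decode_step(h, torch.randn(32), torch.randn(32), A,
                       torch.randn(16), torch.randn(16), torch.randn(32), torch.randn(32))

The diff folds dt_proj.bias into the softplus and stores the negated state matrix as self.negA so the exponential decays; the sketch makes the same sign assumption explicit in A's comment.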
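Design-wise, the notable part of this commit is that ssm_state is now indexed per sequence (ssm_state[i]) and returned alongside the output, which is what a continuously batched decode needs: each running request carries its own conv and SSM state between step calls. The handle_batched flag ships disabled, and the two TODOs in the diff suggest the intent: the batched loop is the target end state, while the conv_out[:, :, -1] fallback preserves the old single-token behavior in the meantime.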