feat: support batching

This commit is contained in:
drbh 2024-02-06 01:22:25 +00:00
parent 63bc4c59d4
commit 3caa9b9cb7
4 changed files with 373 additions and 71 deletions

View File

@ -0,0 +1,338 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 5089,
"logprob": null,
"text": "Test"
},
{
"id": 2748,
"logprob": -9.7265625,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 187,
"logprob": -2.4746094,
"special": false,
"text": "\n"
},
{
"id": 187,
"logprob": -1.3857422,
"special": false,
"text": "\n"
},
{
"id": 510,
"logprob": -2.703125,
"special": false,
"text": "The"
},
{
"id": 806,
"logprob": -4.1992188,
"special": false,
"text": " first"
},
{
"id": 2181,
"logprob": -2.703125,
"special": false,
"text": " thing"
},
{
"id": 309,
"logprob": -1.4160156,
"special": false,
"text": " I"
},
{
"id": 8344,
"logprob": -1.6171875,
"special": false,
"text": " noticed"
},
{
"id": 369,
"logprob": -1.0039062,
"special": false,
"text": " was"
},
{
"id": 326,
"logprob": -0.8823242,
"special": false,
"text": " that"
},
{
"id": 253,
"logprob": -1.3173828,
"special": false,
"text": " the"
}
],
"top_tokens": null
},
"generated_text": "\n\nThe first thing I noticed was that the"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 5089,
"logprob": null,
"text": "Test"
},
{
"id": 2748,
"logprob": -9.7265625,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 187,
"logprob": -2.4941406,
"special": false,
"text": "\n"
},
{
"id": 187,
"logprob": -1.3857422,
"special": false,
"text": "\n"
},
{
"id": 510,
"logprob": -2.703125,
"special": false,
"text": "The"
},
{
"id": 806,
"logprob": -4.1992188,
"special": false,
"text": " first"
},
{
"id": 2181,
"logprob": -2.703125,
"special": false,
"text": " thing"
},
{
"id": 309,
"logprob": -1.4160156,
"special": false,
"text": " I"
},
{
"id": 8344,
"logprob": -1.6171875,
"special": false,
"text": " noticed"
},
{
"id": 369,
"logprob": -1.0039062,
"special": false,
"text": " was"
},
{
"id": 326,
"logprob": -0.8823242,
"special": false,
"text": " that"
},
{
"id": 253,
"logprob": -1.3173828,
"special": false,
"text": " the"
}
],
"top_tokens": null
},
"generated_text": "\n\nThe first thing I noticed was that the"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 5089,
"logprob": null,
"text": "Test"
},
{
"id": 2748,
"logprob": -9.7265625,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 187,
"logprob": -2.4941406,
"special": false,
"text": "\n"
},
{
"id": 187,
"logprob": -1.3857422,
"special": false,
"text": "\n"
},
{
"id": 510,
"logprob": -2.703125,
"special": false,
"text": "The"
},
{
"id": 806,
"logprob": -4.1992188,
"special": false,
"text": " first"
},
{
"id": 2181,
"logprob": -2.703125,
"special": false,
"text": " thing"
},
{
"id": 309,
"logprob": -1.4160156,
"special": false,
"text": " I"
},
{
"id": 8344,
"logprob": -1.6171875,
"special": false,
"text": " noticed"
},
{
"id": 369,
"logprob": -1.0039062,
"special": false,
"text": " was"
},
{
"id": 326,
"logprob": -0.8823242,
"special": false,
"text": " that"
},
{
"id": 253,
"logprob": -1.3173828,
"special": false,
"text": " the"
}
],
"top_tokens": null
},
"generated_text": "\n\nThe first thing I noticed was that the"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 5089,
"logprob": null,
"text": "Test"
},
{
"id": 2748,
"logprob": -9.7265625,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 187,
"logprob": -2.4941406,
"special": false,
"text": "\n"
},
{
"id": 187,
"logprob": -1.3857422,
"special": false,
"text": "\n"
},
{
"id": 510,
"logprob": -2.703125,
"special": false,
"text": "The"
},
{
"id": 806,
"logprob": -4.1992188,
"special": false,
"text": " first"
},
{
"id": 2181,
"logprob": -2.703125,
"special": false,
"text": " thing"
},
{
"id": 309,
"logprob": -1.4160156,
"special": false,
"text": " I"
},
{
"id": 8344,
"logprob": -1.6171875,
"special": false,
"text": " noticed"
},
{
"id": 369,
"logprob": -1.0039062,
"special": false,
"text": " was"
},
{
"id": 326,
"logprob": -0.8823242,
"special": false,
"text": " that"
},
{
"id": 253,
"logprob": -1.3173828,
"special": false,
"text": " the"
}
],
"top_tokens": null
},
"generated_text": "\n\nThe first thing I noticed was that the"
}
]

View File

@ -47,13 +47,12 @@ async def test_fused_kernel_mamba_all_params(fused_kernel_mamba, response_snapsh
assert response.generated_text == "blue, red, yellow, \nand orange (in the order they appear in"
assert response == response_snapshot
# TODO: Fix batching
# @pytest.mark.asyncio
# @pytest.mark.private
# async def test_fused_kernel_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
# responses = await generate_load(fused_kernel_mamba, "Test request", max_new_tokens=10, n=2)
@pytest.mark.asyncio
@pytest.mark.private
async def test_fused_kernel_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
responses = await generate_load(fused_kernel_mamba, "Test request", max_new_tokens=10, n=4)
# assert len(responses) == 4
# assert all([r.generated_text == responses[0].generated_text for r in responses])
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
# assert responses == response_snapshot
assert responses == response_snapshot

View File

@ -121,8 +121,6 @@ class MambaBlock(nn.Module):
conv_out = causal_conv1d_fn( x=conv_state_new, weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), bias=self.conv1d.bias, activation=self.activation)
conv_state = conv_state_new[:, :, 1:]
handle_batched = False
if handle_batched:
bsz, seqlen, dim = hidden_states.shape
# empty output tensor for the loop
output_tensor = torch.zeros(
@ -132,8 +130,8 @@ class MambaBlock(nn.Module):
)
for i in range(0, bsz):
x = conv_out[:,:,i]
z = _z[:, i, :]
x = conv_out[i:i+1,:,-1]
z = _z[i:i+1, -1, :]
x_db = self.x_proj(x)
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
dt = self.dt_proj_no_bias(dt)
@ -141,11 +139,10 @@ class MambaBlock(nn.Module):
dA = torch.exp(dt * self.negA)
dB = dt * B.view(-1, B.size(-1))
x_shape = (-1, x.size(-1), 1)
# ssm_state = (ssm_state * dA + dB * x.view(x_shape))
ssm_state[i] = (ssm_state[i] * dA + dB * x)#.view(x_shape))
ssm_state[i] = (ssm_state[i] * dA + dB * x.view(x_shape))
c_shape = (C.size(0), C.size(1), -1)
out_mm_shape = (C.size(0), -1)
out = torch.matmul(ssm_state.to(C.dtype), C.view(c_shape)).view(out_mm_shape)
out = torch.matmul(ssm_state[i].to(C.dtype), C.view(c_shape)).view(out_mm_shape)
# in-place ops
out.add_((x * self.D).to(out.dtype))
out.mul_(F.silu(z))
@ -154,38 +151,6 @@ class MambaBlock(nn.Module):
return output_tensor, conv_state, ssm_state
# TODO: remove this code only left for reference
# # only support decoding with 1 token at a time
# xz = self.in_proj(hidden_states.view((1, -1)))
# x, z = xz.chunk(2, dim=-1) # (B D)
# x = causal_conv1d_update(
# x,
# conv_state,
# self.conv1d.weight.view(self.conv1d.weight.size(0), -1),
# self.conv1d.bias,
# self.activation,
# )
# TODO: prefer using batched logic in all cases
# this just pulls the last element of the batch
x = conv_out[:,:,-1]
z = _z[:, -1, :]
x_db = self.x_proj(x)
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
dt = self.dt_proj_no_bias(dt)
dt = F.softplus(dt + self.dt_proj.bias).view((dt.size(1), -1))
dA = torch.exp(dt * self.negA)
dB = dt * B.view(-1, B.size(-1))
x_shape = (-1, x.size(-1), 1)
ssm_state = (ssm_state * dA + dB * x.view(x_shape))
c_shape = (C.size(0), C.size(1), -1)
out_mm_shape = (C.size(0), -1)
out = torch.matmul(ssm_state.to(C.dtype), C.view(c_shape)).view(out_mm_shape)
# in-place ops
out.add_((x * self.D).to(out.dtype))
out.mul_(F.silu(z))
out = self.out_proj(out)
return out.view((1, -1, out.size(-1))), conv_state, ssm_state
class ResidualBlock(nn.Module):

View File

@ -557,7 +557,7 @@ class Mamba(Model):
top_tokens = None
generation = Generation(
batch.batch_id,
request.id,
prefill_tokens,
Tokens(
[next_token_id_squeezed],