feat: support batching

This commit is contained in:
drbh 2024-02-06 01:22:25 +00:00
parent 63bc4c59d4
commit 3caa9b9cb7
4 changed files with 373 additions and 71 deletions

View File

@ -0,0 +1,338 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 5089,
"logprob": null,
"text": "Test"
},
{
"id": 2748,
"logprob": -9.7265625,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 187,
"logprob": -2.4746094,
"special": false,
"text": "\n"
},
{
"id": 187,
"logprob": -1.3857422,
"special": false,
"text": "\n"
},
{
"id": 510,
"logprob": -2.703125,
"special": false,
"text": "The"
},
{
"id": 806,
"logprob": -4.1992188,
"special": false,
"text": " first"
},
{
"id": 2181,
"logprob": -2.703125,
"special": false,
"text": " thing"
},
{
"id": 309,
"logprob": -1.4160156,
"special": false,
"text": " I"
},
{
"id": 8344,
"logprob": -1.6171875,
"special": false,
"text": " noticed"
},
{
"id": 369,
"logprob": -1.0039062,
"special": false,
"text": " was"
},
{
"id": 326,
"logprob": -0.8823242,
"special": false,
"text": " that"
},
{
"id": 253,
"logprob": -1.3173828,
"special": false,
"text": " the"
}
],
"top_tokens": null
},
"generated_text": "\n\nThe first thing I noticed was that the"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 5089,
"logprob": null,
"text": "Test"
},
{
"id": 2748,
"logprob": -9.7265625,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 187,
"logprob": -2.4941406,
"special": false,
"text": "\n"
},
{
"id": 187,
"logprob": -1.3857422,
"special": false,
"text": "\n"
},
{
"id": 510,
"logprob": -2.703125,
"special": false,
"text": "The"
},
{
"id": 806,
"logprob": -4.1992188,
"special": false,
"text": " first"
},
{
"id": 2181,
"logprob": -2.703125,
"special": false,
"text": " thing"
},
{
"id": 309,
"logprob": -1.4160156,
"special": false,
"text": " I"
},
{
"id": 8344,
"logprob": -1.6171875,
"special": false,
"text": " noticed"
},
{
"id": 369,
"logprob": -1.0039062,
"special": false,
"text": " was"
},
{
"id": 326,
"logprob": -0.8823242,
"special": false,
"text": " that"
},
{
"id": 253,
"logprob": -1.3173828,
"special": false,
"text": " the"
}
],
"top_tokens": null
},
"generated_text": "\n\nThe first thing I noticed was that the"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 5089,
"logprob": null,
"text": "Test"
},
{
"id": 2748,
"logprob": -9.7265625,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 187,
"logprob": -2.4941406,
"special": false,
"text": "\n"
},
{
"id": 187,
"logprob": -1.3857422,
"special": false,
"text": "\n"
},
{
"id": 510,
"logprob": -2.703125,
"special": false,
"text": "The"
},
{
"id": 806,
"logprob": -4.1992188,
"special": false,
"text": " first"
},
{
"id": 2181,
"logprob": -2.703125,
"special": false,
"text": " thing"
},
{
"id": 309,
"logprob": -1.4160156,
"special": false,
"text": " I"
},
{
"id": 8344,
"logprob": -1.6171875,
"special": false,
"text": " noticed"
},
{
"id": 369,
"logprob": -1.0039062,
"special": false,
"text": " was"
},
{
"id": 326,
"logprob": -0.8823242,
"special": false,
"text": " that"
},
{
"id": 253,
"logprob": -1.3173828,
"special": false,
"text": " the"
}
],
"top_tokens": null
},
"generated_text": "\n\nThe first thing I noticed was that the"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 5089,
"logprob": null,
"text": "Test"
},
{
"id": 2748,
"logprob": -9.7265625,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 187,
"logprob": -2.4941406,
"special": false,
"text": "\n"
},
{
"id": 187,
"logprob": -1.3857422,
"special": false,
"text": "\n"
},
{
"id": 510,
"logprob": -2.703125,
"special": false,
"text": "The"
},
{
"id": 806,
"logprob": -4.1992188,
"special": false,
"text": " first"
},
{
"id": 2181,
"logprob": -2.703125,
"special": false,
"text": " thing"
},
{
"id": 309,
"logprob": -1.4160156,
"special": false,
"text": " I"
},
{
"id": 8344,
"logprob": -1.6171875,
"special": false,
"text": " noticed"
},
{
"id": 369,
"logprob": -1.0039062,
"special": false,
"text": " was"
},
{
"id": 326,
"logprob": -0.8823242,
"special": false,
"text": " that"
},
{
"id": 253,
"logprob": -1.3173828,
"special": false,
"text": " the"
}
],
"top_tokens": null
},
"generated_text": "\n\nThe first thing I noticed was that the"
}
]

View File

@ -47,13 +47,12 @@ async def test_fused_kernel_mamba_all_params(fused_kernel_mamba, response_snapsh
assert response.generated_text == "blue, red, yellow, \nand orange (in the order they appear in" assert response.generated_text == "blue, red, yellow, \nand orange (in the order they appear in"
assert response == response_snapshot assert response == response_snapshot
# TODO: Fix batching @pytest.mark.asyncio
# @pytest.mark.asyncio @pytest.mark.private
# @pytest.mark.private async def test_fused_kernel_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
# async def test_fused_kernel_mamba_load(fused_kernel_mamba, generate_load, response_snapshot): responses = await generate_load(fused_kernel_mamba, "Test request", max_new_tokens=10, n=4)
# responses = await generate_load(fused_kernel_mamba, "Test request", max_new_tokens=10, n=2)
# assert len(responses) == 4 assert len(responses) == 4
# assert all([r.generated_text == responses[0].generated_text for r in responses]) assert all([r.generated_text == responses[0].generated_text for r in responses])
# assert responses == response_snapshot assert responses == response_snapshot

View File

@ -121,71 +121,36 @@ class MambaBlock(nn.Module):
conv_out = causal_conv1d_fn( x=conv_state_new, weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), bias=self.conv1d.bias, activation=self.activation) conv_out = causal_conv1d_fn( x=conv_state_new, weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), bias=self.conv1d.bias, activation=self.activation)
conv_state = conv_state_new[:, :, 1:] conv_state = conv_state_new[:, :, 1:]
handle_batched = False bsz, seqlen, dim = hidden_states.shape
if handle_batched: # empty output tensor for the loop
bsz, seqlen, dim = hidden_states.shape output_tensor = torch.zeros(
# empty output tensor for the loop (bsz, seqlen, dim),
output_tensor = torch.zeros( device=hidden_states.device,
(bsz, seqlen, dim), dtype=hidden_states.dtype
device=hidden_states.device, )
dtype=hidden_states.dtype
)
for i in range(0, bsz): for i in range(0, bsz):
x = conv_out[:,:,i] x = conv_out[i:i+1,:,-1]
z = _z[:, i, :] z = _z[i:i+1, -1, :]
x_db = self.x_proj(x) x_db = self.x_proj(x)
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
dt = self.dt_proj_no_bias(dt) dt = self.dt_proj_no_bias(dt)
dt = F.softplus(dt + self.dt_proj.bias).view((dt.size(1), -1)) dt = F.softplus(dt + self.dt_proj.bias).view((dt.size(1), -1))
dA = torch.exp(dt * self.negA) dA = torch.exp(dt * self.negA)
dB = dt * B.view(-1, B.size(-1)) dB = dt * B.view(-1, B.size(-1))
x_shape = (-1, x.size(-1), 1) x_shape = (-1, x.size(-1), 1)
# ssm_state = (ssm_state * dA + dB * x.view(x_shape)) ssm_state[i] = (ssm_state[i] * dA + dB * x.view(x_shape))
ssm_state[i] = (ssm_state[i] * dA + dB * x)#.view(x_shape)) c_shape = (C.size(0), C.size(1), -1)
c_shape = (C.size(0), C.size(1), -1) out_mm_shape = (C.size(0), -1)
out_mm_shape = (C.size(0), -1) out = torch.matmul(ssm_state[i].to(C.dtype), C.view(c_shape)).view(out_mm_shape)
out = torch.matmul(ssm_state.to(C.dtype), C.view(c_shape)).view(out_mm_shape) # in-place ops
# in-place ops out.add_((x * self.D).to(out.dtype))
out.add_((x * self.D).to(out.dtype)) out.mul_(F.silu(z))
out.mul_(F.silu(z)) out = self.out_proj(out)
out = self.out_proj(out) output_tensor[i] = out
output_tensor[i] = out
return output_tensor, conv_state, ssm_state return output_tensor, conv_state, ssm_state
# TODO: remove this code only left for reference
# # only support decoding with 1 token at a time
# xz = self.in_proj(hidden_states.view((1, -1)))
# x, z = xz.chunk(2, dim=-1) # (B D)
# x = causal_conv1d_update(
# x,
# conv_state,
# self.conv1d.weight.view(self.conv1d.weight.size(0), -1),
# self.conv1d.bias,
# self.activation,
# )
# TODO: prefer using batched logic in all cases
# this just pulls the last element of the batch
x = conv_out[:,:,-1]
z = _z[:, -1, :]
x_db = self.x_proj(x)
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
dt = self.dt_proj_no_bias(dt)
dt = F.softplus(dt + self.dt_proj.bias).view((dt.size(1), -1))
dA = torch.exp(dt * self.negA)
dB = dt * B.view(-1, B.size(-1))
x_shape = (-1, x.size(-1), 1)
ssm_state = (ssm_state * dA + dB * x.view(x_shape))
c_shape = (C.size(0), C.size(1), -1)
out_mm_shape = (C.size(0), -1)
out = torch.matmul(ssm_state.to(C.dtype), C.view(c_shape)).view(out_mm_shape)
# in-place ops
out.add_((x * self.D).to(out.dtype))
out.mul_(F.silu(z))
out = self.out_proj(out)
return out.view((1, -1, out.size(-1))), conv_state, ssm_state
class ResidualBlock(nn.Module): class ResidualBlock(nn.Module):

View File

@ -557,7 +557,7 @@ class Mamba(Model):
top_tokens = None top_tokens = None
generation = Generation( generation = Generation(
batch.batch_id, request.id,
prefill_tokens, prefill_tokens,
Tokens( Tokens(
[next_token_id_squeezed], [next_token_id_squeezed],