Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 04:14:52 +00:00
feat: support batching
parent 63bc4c59d4
commit 3caa9b9cb7
@@ -0,0 +1,338 @@
[
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        {
          "id": 5089,
          "logprob": null,
          "text": "Test"
        },
        {
          "id": 2748,
          "logprob": -9.7265625,
          "text": " request"
        }
      ],
      "seed": null,
      "tokens": [
        {
          "id": 187,
          "logprob": -2.4746094,
          "special": false,
          "text": "\n"
        },
        {
          "id": 187,
          "logprob": -1.3857422,
          "special": false,
          "text": "\n"
        },
        {
          "id": 510,
          "logprob": -2.703125,
          "special": false,
          "text": "The"
        },
        {
          "id": 806,
          "logprob": -4.1992188,
          "special": false,
          "text": " first"
        },
        {
          "id": 2181,
          "logprob": -2.703125,
          "special": false,
          "text": " thing"
        },
        {
          "id": 309,
          "logprob": -1.4160156,
          "special": false,
          "text": " I"
        },
        {
          "id": 8344,
          "logprob": -1.6171875,
          "special": false,
          "text": " noticed"
        },
        {
          "id": 369,
          "logprob": -1.0039062,
          "special": false,
          "text": " was"
        },
        {
          "id": 326,
          "logprob": -0.8823242,
          "special": false,
          "text": " that"
        },
        {
          "id": 253,
          "logprob": -1.3173828,
          "special": false,
          "text": " the"
        }
      ],
      "top_tokens": null
    },
    "generated_text": "\n\nThe first thing I noticed was that the"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        {
          "id": 5089,
          "logprob": null,
          "text": "Test"
        },
        {
          "id": 2748,
          "logprob": -9.7265625,
          "text": " request"
        }
      ],
      "seed": null,
      "tokens": [
        {
          "id": 187,
          "logprob": -2.4941406,
          "special": false,
          "text": "\n"
        },
        {
          "id": 187,
          "logprob": -1.3857422,
          "special": false,
          "text": "\n"
        },
        {
          "id": 510,
          "logprob": -2.703125,
          "special": false,
          "text": "The"
        },
        {
          "id": 806,
          "logprob": -4.1992188,
          "special": false,
          "text": " first"
        },
        {
          "id": 2181,
          "logprob": -2.703125,
          "special": false,
          "text": " thing"
        },
        {
          "id": 309,
          "logprob": -1.4160156,
          "special": false,
          "text": " I"
        },
        {
          "id": 8344,
          "logprob": -1.6171875,
          "special": false,
          "text": " noticed"
        },
        {
          "id": 369,
          "logprob": -1.0039062,
          "special": false,
          "text": " was"
        },
        {
          "id": 326,
          "logprob": -0.8823242,
          "special": false,
          "text": " that"
        },
        {
          "id": 253,
          "logprob": -1.3173828,
          "special": false,
          "text": " the"
        }
      ],
      "top_tokens": null
    },
    "generated_text": "\n\nThe first thing I noticed was that the"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        {
          "id": 5089,
          "logprob": null,
          "text": "Test"
        },
        {
          "id": 2748,
          "logprob": -9.7265625,
          "text": " request"
        }
      ],
      "seed": null,
      "tokens": [
        {
          "id": 187,
          "logprob": -2.4941406,
          "special": false,
          "text": "\n"
        },
        {
          "id": 187,
          "logprob": -1.3857422,
          "special": false,
          "text": "\n"
        },
        {
          "id": 510,
          "logprob": -2.703125,
          "special": false,
          "text": "The"
        },
        {
          "id": 806,
          "logprob": -4.1992188,
          "special": false,
          "text": " first"
        },
        {
          "id": 2181,
          "logprob": -2.703125,
          "special": false,
          "text": " thing"
        },
        {
          "id": 309,
          "logprob": -1.4160156,
          "special": false,
          "text": " I"
        },
        {
          "id": 8344,
          "logprob": -1.6171875,
          "special": false,
          "text": " noticed"
        },
        {
          "id": 369,
          "logprob": -1.0039062,
          "special": false,
          "text": " was"
        },
        {
          "id": 326,
          "logprob": -0.8823242,
          "special": false,
          "text": " that"
        },
        {
          "id": 253,
          "logprob": -1.3173828,
          "special": false,
          "text": " the"
        }
      ],
      "top_tokens": null
    },
    "generated_text": "\n\nThe first thing I noticed was that the"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        {
          "id": 5089,
          "logprob": null,
          "text": "Test"
        },
        {
          "id": 2748,
          "logprob": -9.7265625,
          "text": " request"
        }
      ],
      "seed": null,
      "tokens": [
        {
          "id": 187,
          "logprob": -2.4941406,
          "special": false,
          "text": "\n"
        },
        {
          "id": 187,
          "logprob": -1.3857422,
          "special": false,
          "text": "\n"
        },
        {
          "id": 510,
          "logprob": -2.703125,
          "special": false,
          "text": "The"
        },
        {
          "id": 806,
          "logprob": -4.1992188,
          "special": false,
          "text": " first"
        },
        {
          "id": 2181,
          "logprob": -2.703125,
          "special": false,
          "text": " thing"
        },
        {
          "id": 309,
          "logprob": -1.4160156,
          "special": false,
          "text": " I"
        },
        {
          "id": 8344,
          "logprob": -1.6171875,
          "special": false,
          "text": " noticed"
        },
        {
          "id": 369,
          "logprob": -1.0039062,
          "special": false,
          "text": " was"
        },
        {
          "id": 326,
          "logprob": -0.8823242,
          "special": false,
          "text": " that"
        },
        {
          "id": 253,
          "logprob": -1.3173828,
          "special": false,
          "text": " the"
        }
      ],
      "top_tokens": null
    },
    "generated_text": "\n\nThe first thing I noticed was that the"
  }
]
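Note on the snapshot above: seed is null and finish_reason is "length" in all four entries, so decoding is greedy and every request in the batch is expected to produce the same ten tokens; only the first token's logprob differs slightly between the first entry (-2.4746094) and the other three (-2.4941406). Below is a minimal sketch of the invariant the snapshot encodes; the file path is a placeholder, and only the assertions mirror the data above.

import json

def check_snapshot(path: str) -> None:
    # path is hypothetical, not the real snapshot location
    with open(path) as f:
        responses = json.load(f)
    assert len(responses) == 4
    assert all(r["details"]["generated_tokens"] == 10 for r in responses)
    assert all(
        r["generated_text"] == responses[0]["generated_text"] for r in responses
    )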
@@ -47,13 +47,12 @@ async def test_fused_kernel_mamba_all_params(fused_kernel_mamba, response_snapsh
     assert response.generated_text == "blue, red, yellow, \nand orange (in the order they appear in"
     assert response == response_snapshot


-# TODO: Fix batching
-# @pytest.mark.asyncio
-# @pytest.mark.private
-# async def test_fused_kernel_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
-#     responses = await generate_load(fused_kernel_mamba, "Test request", max_new_tokens=10, n=2)
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_fused_kernel_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
+    responses = await generate_load(fused_kernel_mamba, "Test request", max_new_tokens=10, n=4)

-#     assert len(responses) == 4
-#     assert all([r.generated_text == responses[0].generated_text for r in responses])
+    assert len(responses) == 4
+    assert all([r.generated_text == responses[0].generated_text for r in responses])

-#     assert responses == response_snapshot
+    assert responses == response_snapshot
@@ -121,8 +121,6 @@ class MambaBlock(nn.Module):
         conv_out = causal_conv1d_fn(x=conv_state_new, weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), bias=self.conv1d.bias, activation=self.activation)
         conv_state = conv_state_new[:, :, 1:]

-        handle_batched = False
-        if handle_batched:
         bsz, seqlen, dim = hidden_states.shape
         # empty output tensor for the loop
         output_tensor = torch.zeros(
@@ -132,8 +130,8 @@ class MambaBlock(nn.Module):
         )

         for i in range(0, bsz):
-            x = conv_out[:, :, i]
-            z = _z[:, i, :]
+            x = conv_out[i : i + 1, :, -1]
+            z = _z[i : i + 1, -1, :]
             x_db = self.x_proj(x)
             dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
             dt = self.dt_proj_no_bias(dt)
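The indexing change above does two things: it selects one sequence at a time (index i rather than the whole batch) and it slices with i:i+1 so the batch dimension survives for the downstream projections. A small self-contained check of that slicing behavior, with made-up sizes:

import torch

conv_out = torch.randn(4, 8, 5)   # (batch, dim, seq), illustrative shapes
x_keepdim = conv_out[2:3, :, -1]  # range slice keeps the batch dim -> (1, 8)
x_dropdim = conv_out[2, :, -1]    # integer index drops it -> (8,)
assert x_keepdim.shape == (1, 8) and x_dropdim.shape == (8,)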
@@ -141,11 +139,10 @@ class MambaBlock(nn.Module):
             dA = torch.exp(dt * self.negA)
             dB = dt * B.view(-1, B.size(-1))
             x_shape = (-1, x.size(-1), 1)
-            # ssm_state = (ssm_state * dA + dB * x.view(x_shape))
-            ssm_state[i] = (ssm_state[i] * dA + dB * x)#.view(x_shape))
+            ssm_state[i] = (ssm_state[i] * dA + dB * x.view(x_shape))
             c_shape = (C.size(0), C.size(1), -1)
             out_mm_shape = (C.size(0), -1)
-            out = torch.matmul(ssm_state.to(C.dtype), C.view(c_shape)).view(out_mm_shape)
+            out = torch.matmul(ssm_state[i].to(C.dtype), C.view(c_shape)).view(out_mm_shape)
             # in-place ops
             out.add_((x * self.D).to(out.dtype))
             out.mul_(F.silu(z))
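The corrected lines read and write only sequence i's state, and restore the x.view(x_shape) broadcast that the earlier single-sequence path used; the matmul likewise switches from the full ssm_state to ssm_state[i]. For reference, a standalone sketch of the discretized state update this loop performs per sequence is below; the function name and exact shapes are illustrative assumptions, not the module's internals.

import torch
import torch.nn.functional as F

def decode_step(ssm_state_i, x, dt, A, B, C, D, z):
    # ssm_state_i: (1, d, n); x, dt, z: (1, d); A: (d, n); B, C: (1, n); D: (d,)
    dA = torch.exp(dt.unsqueeze(-1) * A)                          # discretized decay
    dB = dt.unsqueeze(-1) * B.unsqueeze(1)                        # discretized input
    ssm_state_i = ssm_state_i * dA + dB * x.unsqueeze(-1)         # recurrence
    out = torch.einsum("bdn,bn->bd", ssm_state_i.to(C.dtype), C)  # readout
    out = out + x * D                                             # skip connection
    return out * F.silu(z), ssm_state_i                           # silu-gated output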
@@ -154,38 +151,6 @@ class MambaBlock(nn.Module):

             return output_tensor, conv_state, ssm_state

-        # TODO: remove this code only left for reference
-        # # only support decoding with 1 token at a time
-        # xz = self.in_proj(hidden_states.view((1, -1)))
-        # x, z = xz.chunk(2, dim=-1)  # (B D)
-        # x = causal_conv1d_update(
-        #     x,
-        #     conv_state,
-        #     self.conv1d.weight.view(self.conv1d.weight.size(0), -1),
-        #     self.conv1d.bias,
-        #     self.activation,
-        # )
-
-        # TODO: prefer using batched logic in all cases
-        # this just pulls the last element of the batch
-        x = conv_out[:, :, -1]
-        z = _z[:, -1, :]
-        x_db = self.x_proj(x)
-        dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
-        dt = self.dt_proj_no_bias(dt)
-        dt = F.softplus(dt + self.dt_proj.bias).view((dt.size(1), -1))
-        dA = torch.exp(dt * self.negA)
-        dB = dt * B.view(-1, B.size(-1))
-        x_shape = (-1, x.size(-1), 1)
-        ssm_state = (ssm_state * dA + dB * x.view(x_shape))
-        c_shape = (C.size(0), C.size(1), -1)
-        out_mm_shape = (C.size(0), -1)
-        out = torch.matmul(ssm_state.to(C.dtype), C.view(c_shape)).view(out_mm_shape)
-        # in-place ops
-        out.add_((x * self.D).to(out.dtype))
-        out.mul_(F.silu(z))
-        out = self.out_proj(out)
-        return out.view((1, -1, out.size(-1))), conv_state, ssm_state


 class ResidualBlock(nn.Module):

@@ -557,7 +557,7 @@ class Mamba(Model):
                 top_tokens = None

                 generation = Generation(
-                    batch.batch_id,
+                    request.id,
                     prefill_tokens,
                     Tokens(
                         [next_token_id_squeezed],
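Passing request.id instead of batch.batch_id matters once several requests share a batch: generations are demultiplexed by id on the way back to clients, and a shared batch id would collapse every stream onto a single request. A toy illustration with made-up types (not the server's real Generation class):

from dataclasses import dataclass

@dataclass
class Generation:
    request_id: int
    token_text: str

request_ids = [101, 102, 103, 104]  # four requests sharing one batch
generations = [Generation(rid, " the") for rid in request_ids]

streams = {rid: [] for rid in request_ids}
for g in generations:
    streams[g.request_id].append(g.token_text)  # routed per originating request

assert all(len(tokens) == 1 for tokens in streams.values())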