mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
feat: support batching
This commit is contained in:
parent
63bc4c59d4
commit
3caa9b9cb7
@ -0,0 +1,338 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 5089,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2748,
|
||||||
|
"logprob": -9.7265625,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 187,
|
||||||
|
"logprob": -2.4746094,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 187,
|
||||||
|
"logprob": -1.3857422,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 510,
|
||||||
|
"logprob": -2.703125,
|
||||||
|
"special": false,
|
||||||
|
"text": "The"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 806,
|
||||||
|
"logprob": -4.1992188,
|
||||||
|
"special": false,
|
||||||
|
"text": " first"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2181,
|
||||||
|
"logprob": -2.703125,
|
||||||
|
"special": false,
|
||||||
|
"text": " thing"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 309,
|
||||||
|
"logprob": -1.4160156,
|
||||||
|
"special": false,
|
||||||
|
"text": " I"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 8344,
|
||||||
|
"logprob": -1.6171875,
|
||||||
|
"special": false,
|
||||||
|
"text": " noticed"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 369,
|
||||||
|
"logprob": -1.0039062,
|
||||||
|
"special": false,
|
||||||
|
"text": " was"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 326,
|
||||||
|
"logprob": -0.8823242,
|
||||||
|
"special": false,
|
||||||
|
"text": " that"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 253,
|
||||||
|
"logprob": -1.3173828,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\n\nThe first thing I noticed was that the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 5089,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2748,
|
||||||
|
"logprob": -9.7265625,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 187,
|
||||||
|
"logprob": -2.4941406,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 187,
|
||||||
|
"logprob": -1.3857422,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 510,
|
||||||
|
"logprob": -2.703125,
|
||||||
|
"special": false,
|
||||||
|
"text": "The"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 806,
|
||||||
|
"logprob": -4.1992188,
|
||||||
|
"special": false,
|
||||||
|
"text": " first"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2181,
|
||||||
|
"logprob": -2.703125,
|
||||||
|
"special": false,
|
||||||
|
"text": " thing"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 309,
|
||||||
|
"logprob": -1.4160156,
|
||||||
|
"special": false,
|
||||||
|
"text": " I"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 8344,
|
||||||
|
"logprob": -1.6171875,
|
||||||
|
"special": false,
|
||||||
|
"text": " noticed"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 369,
|
||||||
|
"logprob": -1.0039062,
|
||||||
|
"special": false,
|
||||||
|
"text": " was"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 326,
|
||||||
|
"logprob": -0.8823242,
|
||||||
|
"special": false,
|
||||||
|
"text": " that"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 253,
|
||||||
|
"logprob": -1.3173828,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\n\nThe first thing I noticed was that the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 5089,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2748,
|
||||||
|
"logprob": -9.7265625,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 187,
|
||||||
|
"logprob": -2.4941406,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 187,
|
||||||
|
"logprob": -1.3857422,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 510,
|
||||||
|
"logprob": -2.703125,
|
||||||
|
"special": false,
|
||||||
|
"text": "The"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 806,
|
||||||
|
"logprob": -4.1992188,
|
||||||
|
"special": false,
|
||||||
|
"text": " first"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2181,
|
||||||
|
"logprob": -2.703125,
|
||||||
|
"special": false,
|
||||||
|
"text": " thing"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 309,
|
||||||
|
"logprob": -1.4160156,
|
||||||
|
"special": false,
|
||||||
|
"text": " I"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 8344,
|
||||||
|
"logprob": -1.6171875,
|
||||||
|
"special": false,
|
||||||
|
"text": " noticed"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 369,
|
||||||
|
"logprob": -1.0039062,
|
||||||
|
"special": false,
|
||||||
|
"text": " was"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 326,
|
||||||
|
"logprob": -0.8823242,
|
||||||
|
"special": false,
|
||||||
|
"text": " that"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 253,
|
||||||
|
"logprob": -1.3173828,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\n\nThe first thing I noticed was that the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 5089,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2748,
|
||||||
|
"logprob": -9.7265625,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 187,
|
||||||
|
"logprob": -2.4941406,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 187,
|
||||||
|
"logprob": -1.3857422,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 510,
|
||||||
|
"logprob": -2.703125,
|
||||||
|
"special": false,
|
||||||
|
"text": "The"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 806,
|
||||||
|
"logprob": -4.1992188,
|
||||||
|
"special": false,
|
||||||
|
"text": " first"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2181,
|
||||||
|
"logprob": -2.703125,
|
||||||
|
"special": false,
|
||||||
|
"text": " thing"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 309,
|
||||||
|
"logprob": -1.4160156,
|
||||||
|
"special": false,
|
||||||
|
"text": " I"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 8344,
|
||||||
|
"logprob": -1.6171875,
|
||||||
|
"special": false,
|
||||||
|
"text": " noticed"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 369,
|
||||||
|
"logprob": -1.0039062,
|
||||||
|
"special": false,
|
||||||
|
"text": " was"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 326,
|
||||||
|
"logprob": -0.8823242,
|
||||||
|
"special": false,
|
||||||
|
"text": " that"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 253,
|
||||||
|
"logprob": -1.3173828,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\n\nThe first thing I noticed was that the"
|
||||||
|
}
|
||||||
|
]
|
@ -47,13 +47,12 @@ async def test_fused_kernel_mamba_all_params(fused_kernel_mamba, response_snapsh
|
|||||||
assert response.generated_text == "blue, red, yellow, \nand orange (in the order they appear in"
|
assert response.generated_text == "blue, red, yellow, \nand orange (in the order they appear in"
|
||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
# TODO: Fix batching
|
@pytest.mark.asyncio
|
||||||
# @pytest.mark.asyncio
|
@pytest.mark.private
|
||||||
# @pytest.mark.private
|
async def test_fused_kernel_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
|
||||||
# async def test_fused_kernel_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
|
responses = await generate_load(fused_kernel_mamba, "Test request", max_new_tokens=10, n=4)
|
||||||
# responses = await generate_load(fused_kernel_mamba, "Test request", max_new_tokens=10, n=2)
|
|
||||||
|
|
||||||
# assert len(responses) == 4
|
assert len(responses) == 4
|
||||||
# assert all([r.generated_text == responses[0].generated_text for r in responses])
|
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||||
|
|
||||||
# assert responses == response_snapshot
|
assert responses == response_snapshot
|
||||||
|
@ -121,71 +121,36 @@ class MambaBlock(nn.Module):
|
|||||||
conv_out = causal_conv1d_fn( x=conv_state_new, weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), bias=self.conv1d.bias, activation=self.activation)
|
conv_out = causal_conv1d_fn( x=conv_state_new, weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), bias=self.conv1d.bias, activation=self.activation)
|
||||||
conv_state = conv_state_new[:, :, 1:]
|
conv_state = conv_state_new[:, :, 1:]
|
||||||
|
|
||||||
handle_batched = False
|
bsz, seqlen, dim = hidden_states.shape
|
||||||
if handle_batched:
|
# empty output tensor for the loop
|
||||||
bsz, seqlen, dim = hidden_states.shape
|
output_tensor = torch.zeros(
|
||||||
# empty output tensor for the loop
|
(bsz, seqlen, dim),
|
||||||
output_tensor = torch.zeros(
|
device=hidden_states.device,
|
||||||
(bsz, seqlen, dim),
|
dtype=hidden_states.dtype
|
||||||
device=hidden_states.device,
|
)
|
||||||
dtype=hidden_states.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
for i in range(0, bsz):
|
for i in range(0, bsz):
|
||||||
x = conv_out[:,:,i]
|
x = conv_out[i:i+1,:,-1]
|
||||||
z = _z[:, i, :]
|
z = _z[i:i+1, -1, :]
|
||||||
x_db = self.x_proj(x)
|
x_db = self.x_proj(x)
|
||||||
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
||||||
dt = self.dt_proj_no_bias(dt)
|
dt = self.dt_proj_no_bias(dt)
|
||||||
dt = F.softplus(dt + self.dt_proj.bias).view((dt.size(1), -1))
|
dt = F.softplus(dt + self.dt_proj.bias).view((dt.size(1), -1))
|
||||||
dA = torch.exp(dt * self.negA)
|
dA = torch.exp(dt * self.negA)
|
||||||
dB = dt * B.view(-1, B.size(-1))
|
dB = dt * B.view(-1, B.size(-1))
|
||||||
x_shape = (-1, x.size(-1), 1)
|
x_shape = (-1, x.size(-1), 1)
|
||||||
# ssm_state = (ssm_state * dA + dB * x.view(x_shape))
|
ssm_state[i] = (ssm_state[i] * dA + dB * x.view(x_shape))
|
||||||
ssm_state[i] = (ssm_state[i] * dA + dB * x)#.view(x_shape))
|
c_shape = (C.size(0), C.size(1), -1)
|
||||||
c_shape = (C.size(0), C.size(1), -1)
|
out_mm_shape = (C.size(0), -1)
|
||||||
out_mm_shape = (C.size(0), -1)
|
out = torch.matmul(ssm_state[i].to(C.dtype), C.view(c_shape)).view(out_mm_shape)
|
||||||
out = torch.matmul(ssm_state.to(C.dtype), C.view(c_shape)).view(out_mm_shape)
|
# in-place ops
|
||||||
# in-place ops
|
out.add_((x * self.D).to(out.dtype))
|
||||||
out.add_((x * self.D).to(out.dtype))
|
out.mul_(F.silu(z))
|
||||||
out.mul_(F.silu(z))
|
out = self.out_proj(out)
|
||||||
out = self.out_proj(out)
|
output_tensor[i] = out
|
||||||
output_tensor[i] = out
|
|
||||||
|
|
||||||
return output_tensor, conv_state, ssm_state
|
return output_tensor, conv_state, ssm_state
|
||||||
|
|
||||||
# TODO: remove this code only left for reference
|
|
||||||
# # only support decoding with 1 token at a time
|
|
||||||
# xz = self.in_proj(hidden_states.view((1, -1)))
|
|
||||||
# x, z = xz.chunk(2, dim=-1) # (B D)
|
|
||||||
# x = causal_conv1d_update(
|
|
||||||
# x,
|
|
||||||
# conv_state,
|
|
||||||
# self.conv1d.weight.view(self.conv1d.weight.size(0), -1),
|
|
||||||
# self.conv1d.bias,
|
|
||||||
# self.activation,
|
|
||||||
# )
|
|
||||||
|
|
||||||
# TODO: prefer using batched logic in all cases
|
|
||||||
# this just pulls the last element of the batch
|
|
||||||
x = conv_out[:,:,-1]
|
|
||||||
z = _z[:, -1, :]
|
|
||||||
x_db = self.x_proj(x)
|
|
||||||
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
|
||||||
dt = self.dt_proj_no_bias(dt)
|
|
||||||
dt = F.softplus(dt + self.dt_proj.bias).view((dt.size(1), -1))
|
|
||||||
dA = torch.exp(dt * self.negA)
|
|
||||||
dB = dt * B.view(-1, B.size(-1))
|
|
||||||
x_shape = (-1, x.size(-1), 1)
|
|
||||||
ssm_state = (ssm_state * dA + dB * x.view(x_shape))
|
|
||||||
c_shape = (C.size(0), C.size(1), -1)
|
|
||||||
out_mm_shape = (C.size(0), -1)
|
|
||||||
out = torch.matmul(ssm_state.to(C.dtype), C.view(c_shape)).view(out_mm_shape)
|
|
||||||
# in-place ops
|
|
||||||
out.add_((x * self.D).to(out.dtype))
|
|
||||||
out.mul_(F.silu(z))
|
|
||||||
out = self.out_proj(out)
|
|
||||||
return out.view((1, -1, out.size(-1))), conv_state, ssm_state
|
|
||||||
|
|
||||||
|
|
||||||
class ResidualBlock(nn.Module):
|
class ResidualBlock(nn.Module):
|
||||||
|
@ -557,7 +557,7 @@ class Mamba(Model):
|
|||||||
top_tokens = None
|
top_tokens = None
|
||||||
|
|
||||||
generation = Generation(
|
generation = Generation(
|
||||||
batch.batch_id,
|
request.id,
|
||||||
prefill_tokens,
|
prefill_tokens,
|
||||||
Tokens(
|
Tokens(
|
||||||
[next_token_id_squeezed],
|
[next_token_id_squeezed],
|
||||||
|
Loading…
Reference in New Issue
Block a user