Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-11 04:14:52 +00:00
Fix mamba load.

Commit 8319e854c8 (parent 3caa9b9cb7)
@@ -6,82 +6,97 @@
       "generated_tokens": 10,
       "prefill": [
         {
-          "id": 5089,
+          "id": 1276,
           "logprob": null,
-          "text": "Test"
+          "text": "What"
         },
         {
-          "id": 2748,
-          "logprob": -9.7265625,
-          "text": " request"
+          "id": 310,
+          "logprob": -0.8125,
+          "text": " is"
+        },
+        {
+          "id": 18147,
+          "logprob": -12.828125,
+          "text": " Deep"
+        },
+        {
+          "id": 20727,
+          "logprob": -3.0,
+          "text": " Learning"
+        },
+        {
+          "id": 32,
+          "logprob": -1.1484375,
+          "text": "?"
         }
       ],
       "seed": null,
       "tokens": [
         {
           "id": 187,
-          "logprob": -2.4746094,
+          "logprob": -0.3552246,
           "special": false,
           "text": "\n"
         },
         {
           "id": 187,
-          "logprob": -1.3857422,
+          "logprob": -0.38989258,
           "special": false,
           "text": "\n"
         },
         {
-          "id": 510,
-          "logprob": -2.703125,
+          "id": 30763,
+          "logprob": -1.1386719,
           "special": false,
-          "text": "The"
+          "text": "Deep"
         },
         {
-          "id": 806,
-          "logprob": -4.1992188,
+          "id": 4715,
+          "logprob": -0.5576172,
           "special": false,
-          "text": " first"
+          "text": " learning"
         },
         {
-          "id": 2181,
-          "logprob": -2.703125,
+          "id": 310,
+          "logprob": -0.5913086,
           "special": false,
-          "text": " thing"
+          "text": " is"
         },
         {
-          "id": 309,
-          "logprob": -1.4160156,
+          "id": 247,
+          "logprob": -0.69970703,
           "special": false,
-          "text": " I"
+          "text": " a"
         },
         {
-          "id": 8344,
-          "logprob": -1.6171875,
+          "id": 747,
+          "logprob": -2.0449219,
           "special": false,
-          "text": " noticed"
+          "text": " new"
         },
         {
-          "id": 369,
-          "logprob": -1.0039062,
+          "id": 1511,
+          "logprob": -2.3847656,
           "special": false,
-          "text": " was"
+          "text": " type"
         },
         {
-          "id": 326,
-          "logprob": -0.8823242,
+          "id": 273,
+          "logprob": -0.0026626587,
           "special": false,
-          "text": " that"
+          "text": " of"
         },
         {
-          "id": 253,
-          "logprob": -1.3173828,
+          "id": 5145,
+          "logprob": -1.2841797,
           "special": false,
-          "text": " the"
+          "text": " machine"
         }
       ],
       "top_tokens": null
     },
-    "generated_text": "\n\nThe first thing I noticed was that the"
+    "generated_text": "\n\nDeep learning is a new type of machine"
   },
   {
     "details": {
@@ -90,82 +105,97 @@
       "generated_tokens": 10,
       "prefill": [
         {
-          "id": 5089,
+          "id": 1276,
           "logprob": null,
-          "text": "Test"
+          "text": "What"
         },
         {
-          "id": 2748,
-          "logprob": -9.7265625,
-          "text": " request"
+          "id": 310,
+          "logprob": -0.78027344,
+          "text": " is"
+        },
+        {
+          "id": 18147,
+          "logprob": -12.8203125,
+          "text": " Deep"
+        },
+        {
+          "id": 20727,
+          "logprob": -2.9902344,
+          "text": " Learning"
+        },
+        {
+          "id": 32,
+          "logprob": -1.1523438,
+          "text": "?"
         }
       ],
       "seed": null,
       "tokens": [
         {
           "id": 187,
-          "logprob": -2.4941406,
+          "logprob": -0.35351562,
           "special": false,
           "text": "\n"
         },
         {
           "id": 187,
-          "logprob": -1.3857422,
+          "logprob": -0.38476562,
           "special": false,
           "text": "\n"
         },
         {
-          "id": 510,
-          "logprob": -2.703125,
+          "id": 30763,
+          "logprob": -1.1308594,
           "special": false,
-          "text": "The"
+          "text": "Deep"
         },
         {
-          "id": 806,
-          "logprob": -4.1992188,
+          "id": 4715,
+          "logprob": -0.5522461,
           "special": false,
-          "text": " first"
+          "text": " learning"
         },
         {
-          "id": 2181,
-          "logprob": -2.703125,
+          "id": 310,
+          "logprob": -0.59375,
           "special": false,
-          "text": " thing"
+          "text": " is"
         },
         {
-          "id": 309,
-          "logprob": -1.4160156,
+          "id": 247,
+          "logprob": -0.7036133,
           "special": false,
-          "text": " I"
+          "text": " a"
         },
         {
-          "id": 8344,
-          "logprob": -1.6171875,
+          "id": 747,
+          "logprob": -2.0507812,
           "special": false,
-          "text": " noticed"
+          "text": " new"
         },
         {
-          "id": 369,
-          "logprob": -1.0039062,
+          "id": 1511,
+          "logprob": -2.3808594,
           "special": false,
-          "text": " was"
+          "text": " type"
         },
         {
-          "id": 326,
-          "logprob": -0.8823242,
+          "id": 273,
+          "logprob": -0.002664566,
           "special": false,
-          "text": " that"
+          "text": " of"
         },
         {
-          "id": 253,
-          "logprob": -1.3173828,
+          "id": 5145,
+          "logprob": -1.2851562,
           "special": false,
-          "text": " the"
+          "text": " machine"
         }
       ],
       "top_tokens": null
     },
-    "generated_text": "\n\nThe first thing I noticed was that the"
+    "generated_text": "\n\nDeep learning is a new type of machine"
   },
   {
     "details": {
@@ -174,82 +204,97 @@
       "generated_tokens": 10,
       "prefill": [
         {
-          "id": 5089,
+          "id": 1276,
           "logprob": null,
-          "text": "Test"
+          "text": "What"
         },
         {
-          "id": 2748,
-          "logprob": -9.7265625,
-          "text": " request"
+          "id": 310,
+          "logprob": -0.78027344,
+          "text": " is"
+        },
+        {
+          "id": 18147,
+          "logprob": -12.8203125,
+          "text": " Deep"
+        },
+        {
+          "id": 20727,
+          "logprob": -2.9902344,
+          "text": " Learning"
+        },
+        {
+          "id": 32,
+          "logprob": -1.1523438,
+          "text": "?"
         }
       ],
       "seed": null,
       "tokens": [
         {
           "id": 187,
-          "logprob": -2.4941406,
+          "logprob": -0.35351562,
           "special": false,
           "text": "\n"
         },
         {
           "id": 187,
-          "logprob": -1.3857422,
+          "logprob": -0.38476562,
           "special": false,
           "text": "\n"
         },
         {
-          "id": 510,
-          "logprob": -2.703125,
+          "id": 30763,
+          "logprob": -1.1308594,
           "special": false,
-          "text": "The"
+          "text": "Deep"
         },
         {
-          "id": 806,
-          "logprob": -4.1992188,
+          "id": 4715,
+          "logprob": -0.5522461,
           "special": false,
-          "text": " first"
+          "text": " learning"
         },
         {
-          "id": 2181,
-          "logprob": -2.703125,
+          "id": 310,
+          "logprob": -0.59375,
           "special": false,
-          "text": " thing"
+          "text": " is"
         },
         {
-          "id": 309,
-          "logprob": -1.4160156,
+          "id": 247,
+          "logprob": -0.7036133,
           "special": false,
-          "text": " I"
+          "text": " a"
         },
         {
-          "id": 8344,
-          "logprob": -1.6171875,
+          "id": 747,
+          "logprob": -2.0507812,
           "special": false,
-          "text": " noticed"
+          "text": " new"
         },
         {
-          "id": 369,
-          "logprob": -1.0039062,
+          "id": 1511,
+          "logprob": -2.3808594,
           "special": false,
-          "text": " was"
+          "text": " type"
         },
         {
-          "id": 326,
-          "logprob": -0.8823242,
+          "id": 273,
+          "logprob": -0.002664566,
           "special": false,
-          "text": " that"
+          "text": " of"
         },
         {
-          "id": 253,
-          "logprob": -1.3173828,
+          "id": 5145,
+          "logprob": -1.2851562,
           "special": false,
-          "text": " the"
+          "text": " machine"
         }
       ],
       "top_tokens": null
     },
-    "generated_text": "\n\nThe first thing I noticed was that the"
+    "generated_text": "\n\nDeep learning is a new type of machine"
   },
   {
     "details": {
@@ -258,81 +303,96 @@
       "generated_tokens": 10,
       "prefill": [
         {
-          "id": 5089,
+          "id": 1276,
           "logprob": null,
-          "text": "Test"
+          "text": "What"
         },
         {
-          "id": 2748,
-          "logprob": -9.7265625,
-          "text": " request"
+          "id": 310,
+          "logprob": -0.78027344,
+          "text": " is"
+        },
+        {
+          "id": 18147,
+          "logprob": -12.8203125,
+          "text": " Deep"
+        },
+        {
+          "id": 20727,
+          "logprob": -2.9902344,
+          "text": " Learning"
+        },
+        {
+          "id": 32,
+          "logprob": -1.1523438,
+          "text": "?"
         }
       ],
       "seed": null,
       "tokens": [
         {
           "id": 187,
-          "logprob": -2.4941406,
+          "logprob": -0.35351562,
           "special": false,
           "text": "\n"
         },
         {
           "id": 187,
-          "logprob": -1.3857422,
+          "logprob": -0.38476562,
           "special": false,
           "text": "\n"
         },
         {
-          "id": 510,
-          "logprob": -2.703125,
+          "id": 30763,
+          "logprob": -1.1308594,
           "special": false,
-          "text": "The"
+          "text": "Deep"
         },
         {
-          "id": 806,
-          "logprob": -4.1992188,
+          "id": 4715,
+          "logprob": -0.5522461,
           "special": false,
-          "text": " first"
+          "text": " learning"
         },
         {
-          "id": 2181,
-          "logprob": -2.703125,
+          "id": 310,
+          "logprob": -0.59375,
           "special": false,
-          "text": " thing"
+          "text": " is"
         },
         {
-          "id": 309,
-          "logprob": -1.4160156,
+          "id": 247,
+          "logprob": -0.7036133,
           "special": false,
-          "text": " I"
+          "text": " a"
         },
         {
-          "id": 8344,
-          "logprob": -1.6171875,
+          "id": 747,
+          "logprob": -2.0507812,
           "special": false,
-          "text": " noticed"
+          "text": " new"
         },
         {
-          "id": 369,
-          "logprob": -1.0039062,
+          "id": 1511,
+          "logprob": -2.3808594,
           "special": false,
-          "text": " was"
+          "text": " type"
         },
         {
-          "id": 326,
-          "logprob": -0.8823242,
+          "id": 273,
+          "logprob": -0.002664566,
           "special": false,
-          "text": " that"
+          "text": " of"
         },
         {
-          "id": 253,
-          "logprob": -1.3173828,
+          "id": 5145,
+          "logprob": -1.2851562,
           "special": false,
-          "text": " the"
+          "text": " machine"
         }
       ],
       "top_tokens": null
     },
-    "generated_text": "\n\nThe first thing I noticed was that the"
+    "generated_text": "\n\nDeep learning is a new type of machine"
   }
 ]
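For reference, the updated snapshot entries are internally consistent: the ten generated tokens concatenate to the new generated_text, and only the first prefill token lacks a logprob. A minimal sketch that checks this, assuming a snapshot file with the shape shown above (the file name here is hypothetical; the real path is not visible in this diff):

import json

# Hypothetical snapshot path; the actual file name is not shown in this diff.
with open("test_fused_kernel_mamba_load.json") as f:
    responses = json.load(f)

for r in responses:
    details = r["details"]
    # Ten generated tokens per response, matching max_new_tokens=10.
    assert details["generated_tokens"] == len(details["tokens"]) == 10
    # generated_text is the concatenation of the generated token texts.
    assert r["generated_text"] == "".join(t["text"] for t in details["tokens"])
    # The first prefill token carries no logprob; the rest do.
    assert details["prefill"][0]["logprob"] is None
    assert all(t["logprob"] is not None for t in details["prefill"][1:])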
@@ -17,10 +17,11 @@ async def fused_kernel_mamba(fused_kernel_mamba_handle):
 @pytest.mark.private
 async def test_fused_kernel_mamba(fused_kernel_mamba, response_snapshot):
     response = await fused_kernel_mamba.generate(
-        "Test request", max_new_tokens=10, decoder_input_details=True
+        "What is Deep Learning?", max_new_tokens=10
     )
 
     assert response.details.generated_tokens == 10
+    assert response.generated_text == "\n\nDeep learning is a new type of machine"
     assert response == response_snapshot
 
 
@@ -50,9 +51,10 @@ async def test_fused_kernel_mamba_all_params(fused_kernel_mamba, response_snapsh
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_fused_kernel_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
-    responses = await generate_load(fused_kernel_mamba, "Test request", max_new_tokens=10, n=4)
+    responses = await generate_load(fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4)
 
     assert len(responses) == 4
     assert all([r.generated_text == responses[0].generated_text for r in responses])
+    assert responses[0].generated_text == "\n\nDeep learning is a new type of machine"
 
     assert responses == response_snapshot
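The generate_load fixture itself is not part of this diff; presumably it fires n identical requests concurrently so the router batches them together, which is exactly what exercises the MambaBatch concatenation path fixed below. A rough sketch of that behaviour, under that assumption:

import asyncio

# Sketch only: the real generate_load fixture lives in the integration-test
# conftest and may differ in its details.
async def generate_load_sketch(client, prompt: str, max_new_tokens: int, n: int):
    # Launch n identical generate calls without awaiting in between, so the
    # server sees them as concurrent requests and batches them.
    futures = [
        client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n)
    ]
    return await asyncio.gather(*futures)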
@@ -169,6 +169,7 @@ class MambaBatch(Batch):
         total_remaining_decode_tokens = 0
         new_padding_right_offset = 0
 
+        indices = []
         for i, request_id in enumerate(request_ids):
             idx = self.requests_idx_mapping[request_id]
             requests_idx_mapping[request_id] = i
@@ -182,6 +183,7 @@ class MambaBatch(Batch):
             request_input_length = self.input_lengths[idx]
             input_lengths.append(request_input_length)
             max_input_length = max(max_input_length, request_input_length)
+            indices.append(idx)
 
             next_token_choosers.append(self.next_token_choosers[idx])
             stopping_criteria = self.stopping_criterias[idx]
@@ -216,6 +218,13 @@ class MambaBatch(Batch):
         self.padding_right_offset = new_padding_right_offset
         self.max_tokens = max_tokens
 
+        # TODO
+        # Kept it simple by just updating the state, maybe updating the other CPU values is necessary.
+        key_value_memory_dict = {}
+        for i, (conv_state, ssm_state) in self.inference_params.key_value_memory_dict.items():
+            key_value_memory_dict[i] = (conv_state[indices], ssm_state[indices])
+        self.inference_params.key_value_memory_dict = key_value_memory_dict
+
         return self
 
     @classmethod
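Standalone illustration of the state filtering added above: each layer keeps a (conv_state, ssm_state) pair batched along dimension 0, and indexing with the kept request positions drops the rows of finished requests. A minimal sketch; the shapes below are placeholders, not the model's real dimensions:

import torch

# Two layers of fake mamba state for a batch of 4 requests.
key_value_memory_dict = {
    i: (torch.randn(4, 1536, 4), torch.randn(4, 1536, 16)) for i in range(2)
}
indices = [0, 2]  # positions of the requests that survive the filter

filtered = {
    i: (conv_state[indices], ssm_state[indices])
    for i, (conv_state, ssm_state) in key_value_memory_dict.items()
}
assert filtered[0][0].shape == (2, 1536, 4)
assert filtered[1][1].shape == (2, 1536, 16)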
@@ -240,6 +249,9 @@ class MambaBatch(Batch):
         stopping_criterias = []
         top_n_tokens = []
         max_tokens = 0
+        max_seqlen = 0
+        batch_size = 0
+        seqlen_offset = 0
 
         # Batch tensors
         input_ids = None
@@ -287,8 +299,56 @@ class MambaBatch(Batch):
                 max_input_length - batch.max_input_length
             ) * len(batch)
 
+            max_seqlen = max(max_seqlen, batch.inference_params.max_seqlen)
+            seqlen_offset = max(seqlen_offset, batch.inference_params.seqlen_offset)
+            batch_size += batch.inference_params.max_batch_size
+
             start_index = end_index
 
+
+        (_, d_model, d_conv) = batches[0].inference_params.key_value_memory_dict[0][0].shape
+        (_, _, d_state) = batches[0].inference_params.key_value_memory_dict[0][1].shape
+        n_blocks = len(batches[0].inference_params.key_value_memory_dict)
+        dtype = batches[0].inference_params.key_value_memory_dict[0][0].dtype
+        device = batches[0].inference_params.key_value_memory_dict[0][0].device
+
+        key_value_memory_dict = {}
+        for i in range(n_blocks):
+            conv_state = torch.zeros(
+                batch_size,
+                d_model,
+                d_conv,
+                device=device,
+                dtype=dtype,
+            )
+            ssm_state = torch.zeros(
+                batch_size,
+                d_model,
+                d_state,
+                device=device,
+                dtype=dtype,
+            )
+            key_value_memory_dict[i] = (conv_state, ssm_state)
+        lengths_per_sample = torch.zeros(batch_size, dtype=torch.int32, device=device)
+
+        inference_params = InferenceParams(
+            max_seqlen=max_seqlen,
+            max_batch_size=batch_size,
+            seqlen_offset=seqlen_offset,
+            key_value_memory_dict=key_value_memory_dict,
+            lengths_per_sample=lengths_per_sample,
+        )
+
+        current_batch = 0
+        for batch in batches:
+            for i in range(n_blocks):
+                conv_state, ssm_state = batch.inference_params.key_value_memory_dict[i]
+                batch_size = batch.inference_params.max_batch_size
+                inference_params.key_value_memory_dict[i][0][current_batch:current_batch + batch_size] = conv_state
+                inference_params.key_value_memory_dict[i][1][current_batch:current_batch + batch_size] = ssm_state
+            inference_params.lengths_per_sample[current_batch: current_batch + batch_size] = batch.inference_params.lengths_per_sample
+            current_batch += batch_size
+
         return cls(
             batch_id=batches[0].batch_id,
             requests=requests,
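The copy loop above follows the usual allocate-then-slice pattern: zero tensors are sized for the merged batch, then each sub-batch's state is written into its contiguous row range. A self-contained sketch of the same pattern with placeholder dimensions:

import torch

d_model, d_conv = 1536, 4
sub_batch_sizes = [2, 3]
total = sum(sub_batch_sizes)

# Per-sub-batch conv states, batched along dim 0.
conv_states = [torch.randn(b, d_model, d_conv) for b in sub_batch_sizes]

merged = torch.zeros(total, d_model, d_conv)
offset = 0
for b, conv_state in zip(sub_batch_sizes, conv_states):
    merged[offset : offset + b] = conv_state
    offset += b

# The second sub-batch lands in rows 2..4 of the merged state.
assert torch.equal(merged[2:5], conv_states[1])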
@@ -306,6 +370,7 @@ class MambaBatch(Batch):
             padding_right_offset=padding_right_offset,
             keys_head_dim_last=batches[0].keys_head_dim_last,
             max_tokens=max_tokens,
+            inference_params=inference_params
         )
 
     def __len__(self):
@@ -380,7 +445,6 @@ class Mamba(Model):
 
     def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
         start = time.time_ns()
-
        input_ids = batch.input_ids # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids
 
         batch_size = input_ids.shape[0]