Commit 7ed07bcc05 ("Speculative decoding + mistral"), parent e7e07342bd.
Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 04:14:52 +00:00.
@@ -57,22 +57,28 @@
        "text": " learning"
      },
      {
-       "id": 313,
-       "logprob": -1.0712891,
+       "id": 508,
+       "logprob": -1.5087891,
        "special": false,
-       "text": " ("
+       "text": " can"
      },
      {
-       "id": 15189,
-       "logprob": -0.7578125,
-       "special": false,
-       "text": "also"
-     },
-     {
-       "id": 2998,
+       "id": 367,
        "logprob": 0.0,
        "special": false,
-       "text": " known"
+       "text": " be"
+     },
+     {
+       "id": 2714,
+       "logprob": -0.6538086,
+       "special": false,
+       "text": " thought"
+     },
+     {
+       "id": 310,
+       "logprob": 0.0,
+       "special": false,
+       "text": " of"
      },
      {
        "id": 408,
@@ -81,18 +87,18 @@
        "text": " as"
      },
      {
-       "id": 6483,
+       "id": 263,
        "logprob": 0.0,
        "special": false,
-       "text": " deep"
+       "text": " a"
      },
      {
-       "id": 19677,
-       "logprob": 0.0,
+       "id": 11306,
+       "logprob": -0.5488281,
        "special": false,
-       "text": " neural"
+       "text": " subset"
      }
    ]
  },
- "generated_text": "What is Deep Learning?\nDeep learning (also known as deep neural"
+ "generated_text": "What is Deep Learning?\nDeep learning can be thought of as a subset"
 }
@@ -17,7 +17,7 @@
        },
        {
          "id": 338,
-         "logprob": -1.5488281,
+         "logprob": -1.5498047,
          "text": "is"
        },
        {
@@ -27,12 +27,12 @@
        },
        {
          "id": 29257,
-         "logprob": -1.2753906,
+         "logprob": -1.2734375,
          "text": "Learning"
        },
        {
          "id": 29973,
-         "logprob": -0.48046875,
+         "logprob": -0.48217773,
          "text": "?"
        }
      ],
@@ -40,19 +40,19 @@
      "tokens": [
        {
          "id": 13,
-         "logprob": -1.1845703,
+         "logprob": -1.1875,
          "special": false,
          "text": "\n"
        },
        {
          "id": 2772,
-         "logprob": -0.5727539,
+         "logprob": -0.5708008,
          "special": false,
          "text": "De"
        },
        {
          "id": 1022,
-         "logprob": -0.00010967255,
+         "logprob": -0.00010931492,
          "special": false,
          "text": "ep"
        },
@@ -64,37 +64,37 @@
        },
        {
          "id": 338,
-         "logprob": -0.04510498,
+         "logprob": -0.044433594,
          "special": false,
          "text": " is"
        },
        {
          "id": 263,
-         "logprob": -0.018295288,
+         "logprob": -0.018310547,
          "special": false,
          "text": " a"
        },
        {
          "id": 11306,
-         "logprob": -0.45922852,
+         "logprob": -0.46044922,
          "special": false,
          "text": " subset"
        },
        {
          "id": 310,
-         "logprob": -0.00020992756,
+         "logprob": -0.0002104044,
          "special": false,
          "text": " of"
        },
        {
          "id": 4933,
-         "logprob": -0.0046539307,
+         "logprob": -0.004711151,
          "special": false,
          "text": " machine"
        },
        {
          "id": 6509,
-         "logprob": -0.00025844574,
+         "logprob": -0.00025820732,
          "special": false,
          "text": " learning"
        },
@@ -112,7 +112,7 @@
        }
      ]
    },
-   "generated_text": "ep learning is a subset of machine learning that involves"
+   "generated_text": "\nDeep learning is a subset of machine learning that involves"
  },
  {
    "details": {
@@ -127,12 +127,12 @@
        },
        {
          "id": 1724,
-         "logprob": -10.734375,
+         "logprob": -10.7421875,
          "text": "What"
        },
        {
          "id": 338,
-         "logprob": -1.5488281,
+         "logprob": -1.5498047,
          "text": "is"
        },
        {
@@ -155,19 +155,19 @@
      "tokens": [
        {
          "id": 13,
-         "logprob": -1.1826172,
+         "logprob": -1.1835938,
          "special": false,
          "text": "\n"
        },
        {
          "id": 2772,
-         "logprob": -0.56689453,
+         "logprob": -0.57470703,
          "special": false,
          "text": "De"
        },
        {
          "id": 1022,
-         "logprob": -0.000108003616,
+         "logprob": -0.00010788441,
          "special": false,
          "text": "ep"
        },
@@ -179,13 +179,13 @@
        },
        {
          "id": 338,
-         "logprob": -0.044433594,
+         "logprob": -0.04510498,
          "special": false,
          "text": " is"
        },
        {
          "id": 263,
-         "logprob": -0.018295288,
+         "logprob": -0.018585205,
          "special": false,
          "text": " a"
        },
@@ -197,19 +197,19 @@
        },
        {
          "id": 310,
-         "logprob": -0.0002104044,
+         "logprob": -0.00021457672,
          "special": false,
          "text": " of"
        },
        {
          "id": 4933,
-         "logprob": -0.004711151,
+         "logprob": -0.004776001,
          "special": false,
          "text": " machine"
        },
        {
          "id": 6509,
-         "logprob": -0.00025892258,
+         "logprob": -0.0002593994,
          "special": false,
          "text": " learning"
        },
@@ -227,7 +227,7 @@
        }
      ]
    },
-   "generated_text": "ep learning is a subset of machine learning that involves"
+   "generated_text": "\nDeep learning is a subset of machine learning that involves"
  },
  {
    "details": {
@@ -242,12 +242,12 @@
        },
        {
          "id": 1724,
-         "logprob": -10.734375,
+         "logprob": -10.7421875,
          "text": "What"
        },
        {
          "id": 338,
-         "logprob": -1.5488281,
+         "logprob": -1.5498047,
          "text": "is"
        },
        {
@@ -270,19 +270,19 @@
      "tokens": [
        {
          "id": 13,
-         "logprob": -1.1826172,
+         "logprob": -1.1835938,
          "special": false,
          "text": "\n"
        },
        {
          "id": 2772,
-         "logprob": -0.56689453,
+         "logprob": -0.57470703,
          "special": false,
          "text": "De"
        },
        {
          "id": 1022,
-         "logprob": -0.000108003616,
+         "logprob": -0.00010788441,
          "special": false,
          "text": "ep"
        },
@@ -294,13 +294,13 @@
        },
        {
          "id": 338,
-         "logprob": -0.044433594,
+         "logprob": -0.04510498,
          "special": false,
          "text": " is"
        },
        {
          "id": 263,
-         "logprob": -0.018295288,
+         "logprob": -0.018585205,
          "special": false,
          "text": " a"
        },
@@ -312,19 +312,19 @@
        },
        {
          "id": 310,
-         "logprob": -0.0002104044,
+         "logprob": -0.00021457672,
          "special": false,
          "text": " of"
        },
        {
          "id": 4933,
-         "logprob": -0.004711151,
+         "logprob": -0.004776001,
          "special": false,
          "text": " machine"
        },
        {
          "id": 6509,
-         "logprob": -0.00025892258,
+         "logprob": -0.0002593994,
          "special": false,
          "text": " learning"
        },
@@ -342,7 +342,7 @@
        }
      ]
    },
-   "generated_text": "ep learning is a subset of machine learning that involves"
+   "generated_text": "\nDeep learning is a subset of machine learning that involves"
  },
  {
    "details": {
@@ -357,12 +357,12 @@
        },
        {
          "id": 1724,
-         "logprob": -10.734375,
+         "logprob": -10.7421875,
          "text": "What"
        },
        {
          "id": 338,
-         "logprob": -1.5488281,
+         "logprob": -1.5498047,
          "text": "is"
        },
        {
@@ -385,19 +385,19 @@
      "tokens": [
        {
          "id": 13,
-         "logprob": -1.1826172,
+         "logprob": -1.1835938,
          "special": false,
          "text": "\n"
        },
        {
          "id": 2772,
-         "logprob": -0.56689453,
+         "logprob": -0.57470703,
          "special": false,
          "text": "De"
        },
        {
          "id": 1022,
-         "logprob": -0.000108003616,
+         "logprob": -0.00010788441,
          "special": false,
          "text": "ep"
        },
@@ -409,13 +409,13 @@
        },
        {
          "id": 338,
-         "logprob": -0.044433594,
+         "logprob": -0.04510498,
          "special": false,
          "text": " is"
        },
        {
          "id": 263,
-         "logprob": -0.018295288,
+         "logprob": -0.018585205,
          "special": false,
          "text": " a"
        },
@@ -427,19 +427,19 @@
        },
        {
          "id": 310,
-         "logprob": -0.0002104044,
+         "logprob": -0.00021457672,
          "special": false,
          "text": " of"
        },
        {
          "id": 4933,
-         "logprob": -0.004711151,
+         "logprob": -0.004776001,
          "special": false,
          "text": " machine"
        },
        {
          "id": 6509,
-         "logprob": -0.00025892258,
+         "logprob": -0.0002593994,
          "special": false,
          "text": " learning"
        },
@@ -457,6 +457,6 @@
        }
      ]
    },
-   "generated_text": "ep learning is a subset of machine learning that involves"
+   "generated_text": "\nDeep learning is a subset of machine learning that involves"
  }
]
@@ -111,5 +111,5 @@
      }
    ]
  },
- "generated_text": "ep learning is a subset of machine learning that involves"
+ "generated_text": "\nDeep learning is a subset of machine learning that involves"
 }
@@ -43,7 +43,7 @@ async def test_flash_llama_all_params(flash_llama, response_snapshot):
         seed=0,
     )

-    assert response.details.generated_tokens == 5
+    assert response.details.generated_tokens == 10
     assert response == response_snapshot

@@ -54,6 +54,6 @@ async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot)

     assert len(responses) == 4
     assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}"
-    assert responses[0].generated_text == 'ep learning is a subset of machine learning that involves'
+    assert responses[0].generated_text == '\nDeep learning is a subset of machine learning that involves'

     assert responses == response_snapshot
@@ -21,6 +21,7 @@ async def test_flash_mistral(flash_mistral, response_snapshot):
     )

     assert response.details.generated_tokens == 10
+    assert response.generated_text == ": Let n = 10 - 1"
     assert response == response_snapshot

@@ -55,6 +56,7 @@ async def test_flash_mistral_load(flash_mistral, generate_load, response_snapsho
     )

     assert len(responses) == 4
-    assert all([r.generated_text == responses[0].generated_text for r in responses])
+    assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}"
+    assert responses[0].generated_text == ": Let n = 10 - 1"

     assert responses == response_snapshot
@@ -101,6 +101,8 @@ def get_model(
     else:
         raise RuntimeError(f"Unknown dtype {dtype}")

+    SPECULATE = 2
+
     if "facebook/galactica" in model_id:
         return GalacticaSharded(
             model_id,
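
The hunk above hard-codes a module-level SPECULATE constant: the number of draft tokens speculated per decoding step. A minimal sketch of how such a constant can feed a per-request slot budget; BLOCK_SIZE and the needed_blocks helper here are illustrative assumptions, not the repository's API:

# Sketch: reserving KV-cache capacity for speculative tokens.
SPECULATE = 2    # mirrors the constant introduced above
BLOCK_SIZE = 16  # illustrative paged-attention block size

def needed_blocks(input_length: int, max_new_tokens: int) -> int:
    # One slot per past token; the first generated token has no past (-1),
    # and each decode step may write SPECULATE extra candidate tokens.
    total_tokens = input_length + max_new_tokens - 1 + SPECULATE
    return (total_tokens + BLOCK_SIZE - 1) // BLOCK_SIZE

print(needed_blocks(input_length=5, max_new_tokens=10))  # -> 1
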
@@ -479,19 +479,19 @@ class FlashCausalLMBatch(Batch):
         max_blocks = 0
         max_length = 0
         max_seqlen = 0
-        speculative_length = 0 if batches[0].speculative_ids is None else batches[0].speculative_ids.shape[1]
         for b in batches:
             total_batch_size += len(b)
             total_slots += len(b.slots)
             blocks += b.blocks
+            speculative_length = 0 if b.speculative_ids is None else b.speculative_ids.shape[1]
             max_blocks = max(max_blocks, b.max_blocks)
             max_seqlen = max(max_seqlen, b.max_seqlen)
             max_length = max(
                 max_length,
                 max(
                     input_length
-                    + speculative_length
                     + stopping_criteria.max_new_tokens
+                    + speculative_length
                     - stopping_criteria.current_tokens
                     for input_length, stopping_criteria in zip(
                         b.input_lengths, b.stopping_criterias
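
In the concatenation hunk above, the speculative length is now read from each batch inside the loop instead of once from batches[0], and it is added into the remaining-token budget alongside max_new_tokens. A toy check of that budget arithmetic, with made-up numbers:

# Sketch: remaining-slot budget for one request during batch concatenation.
input_length = 7         # prompt tokens
max_new_tokens = 10      # requested generation length
current_tokens = 3       # tokens already generated
speculative_length = 2   # speculative ids carried by this batch

budget = input_length + max_new_tokens + speculative_length - current_tokens
print(budget)  # -> 16 slots must stay reserved for this request
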
@@ -994,7 +994,8 @@ class FlashCausalLM(Model):

            # Evaluate stopping criteria

-           for next_token_id in _next_token_ids:
+           left = 0
+           for j, next_token_id in enumerate(_next_token_ids):
                stop, reason = stopping_criteria(
                    next_token_id,
                    next_token_text,
@@ -1002,21 +1003,26 @@

                if stop:
                    stopped = True
+                   left = len(_next_token_ids) - 1 - j
                    break
-               if not stop:
+               else:
                    stopped = False
+           _next_token_ids = _next_token_ids[: len(_next_token_ids) - left]

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
+                   # Remove potentially accepted ids that do not respect
+                   # the stopping_criteria
+                   _ids = all_input_ids[: len(all_input_ids) - left]
                    output_text, _, _ = self.decode_token(
-                       all_input_ids,
-                       prefix_offset=len(all_input_ids)
+                       _ids,
+                       prefix_offset=len(_ids)
                        - stopping_criteria.current_tokens
                        - 1,
-                       read_offset=len(all_input_ids)
+                       read_offset=len(_ids)
                        - stopping_criteria.current_tokens,
                        skip_special_tokens=True,
                    )
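
The two hunks above make the stopping criteria speculative-decoding aware: once a stop fires at position j, the left counter records how many speculated tokens trail it, and both the accepted ids and the text to decode are trimmed by that amount. A standalone sketch of the trimming rule, with a simple token-id match standing in for the real StoppingCriteria callable:

# Sketch: drop speculated tokens that land after a stop condition.
def trim_after_stop(next_token_ids: list, stop_id: int):
    left = 0
    stopped = False
    for j, token_id in enumerate(next_token_ids):
        if token_id == stop_id:
            stopped = True
            # Everything speculated past the stop token is discarded.
            left = len(next_token_ids) - 1 - j
            break
    return next_token_ids[: len(next_token_ids) - left], stopped

print(trim_after_stop([11, 22, 99, 33, 44], stop_id=99))  # -> ([11, 22, 99], True)
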
@@ -132,7 +132,9 @@ class FlashMistralBatch(FlashCausalLMBatch):

             # Paged attention
             # Remove one as the first token does not have a past
-            total_tokens = input_length + max_new_tokens - 1
+            from text_generation_server.models import SPECULATE
+            speculative_length = SPECULATE
+            total_tokens = input_length + max_new_tokens - 1 + speculative_length

             # Needed blocks can not go over SLIDING_WINDOW_BLOCKS
             needed_blocks = min(
@@ -183,7 +185,7 @@
             cumulative_max_length += total_tokens
             max_seqlen = max(max_seqlen, input_length)
             max_blocks = max(max_blocks, needed_blocks)
-            max_length = max(max_length, input_length + max_new_tokens)
+            max_length = max(max_length, input_length + max_new_tokens + speculative_length)

         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             next_token_chooser_parameters, dtype, device
@@ -341,17 +343,55 @@ class FlashMistral(FlashCausalLM):

     def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
+        if batch.speculative_ids is not None:
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = get_cache_manager().kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
+
+            speculative_ids = batch.speculative_ids
+
+            B, speculative_length = speculative_ids.shape
+            new_length = speculative_length + 1
+            new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
+            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
+            arange_int = arange.to(dtype=torch.int32)
+            new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
+            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+            input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+
+            # Copy the block tables for all members
+            block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B * new_length, -1).contiguous()
+            max_s = max_s + speculative_length
+
+            input_ids = new_input_ids
+            position_ids = new_position_ids
+        else:
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = get_cache_manager().kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
         logits = self.model.forward(
-            input_ids=batch.input_ids,
-            position_ids=batch.position_ids,
-            cu_seqlen_prefill=batch.cu_seqlen_prefill,
-            kv_cache=get_cache_manager().kv_cache,
-            block_tables=batch.block_tables_tensor,
-            slots=batch.slots[batch.slot_indices],
-            input_lengths=batch.input_lengths_tensor,
-            max_s=batch.max_seqlen,
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=cu_seqlen_prefill,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            input_lengths=input_lengths,
+            max_s=max_s,
             prefill_cache_indices=batch.prefill_cache_indices,
-            lm_head_indices=batch.prefill_head_indices,
+            lm_head_indices=lm_head_indices,
         )
         if batch.prefill_cache_indices is not None:
             batch.prefill_cache_indices = None
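
When the batch carries speculative ids, the forward hunk above flattens each sequence's next token plus its speculated continuation into one row per candidate position, shifting position ids, slots, and input lengths with an arange and repeating the block tables. A toy reproduction of just the tensor reshaping, under made-up shapes and with no model call:

import torch

# Sketch: expand a batch of size B so each speculated position gets a row.
B, speculative_length = 2, 2
new_length = speculative_length + 1  # current token plus speculated ones

input_ids = torch.tensor([10, 20])                    # next token per sequence
speculative_ids = torch.tensor([[11, 12], [21, 22]])  # draft continuations
position_ids = torch.tensor([5, 9])                   # next position per sequence

new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
arange = torch.arange(new_length).unsqueeze(0)
new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).reshape(-1)

print(new_input_ids)     # tensor([10, 11, 12, 20, 21, 22])
print(new_position_ids)  # tensor([ 5,  6,  7,  9, 10, 11])
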
@@ -258,16 +258,34 @@ class HeterogeneousNextTokenChooser:
         self.device = device

     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculate: int, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None):
-        if self.watermark_processor is not None:
-            scores = self.watermark_processor(input_ids, scores)
-        if self.repetition_processor is not None:
-            scores = self.repetition_processor(input_ids, scores)
+        if speculated_ids is not None:
+            B = scores.shape[0] // (speculated_ids.shape[1] + 1)
+            S = speculated_ids.shape[1] + 1
+            scores = scores.view(B, S, -1)
+        else:
+            B = scores.shape[0]
+            S = 1
+            scores = scores.view(B, S, -1)
+
+        all_next_ids = []
+        all_scores = []
+        for j in range(S):
+            _scores = scores[:, j]
+            if self.watermark_processor is not None:
+                _scores = self.watermark_processor(input_ids, _scores)
+            if self.repetition_processor is not None:
+                _scores = self.repetition_processor(input_ids, _scores)

-        for warper in self.warpers:
-            scores = warper(input_ids, scores)
+            for warper in self.warpers:
+                _scores = warper(input_ids, _scores)

-        next_ids = self.choice(scores)
+            next_ids = self.choice(_scores)
+            scores[:, j] = _scores
+            all_next_ids.append(next_ids.unsqueeze(1))
+        next_ids = torch.cat(all_next_ids, dim=1).reshape(B * S)
+        scores = scores.view(B * S, -1)
+
         if speculated_ids is not None:
             accepted_ids = []
             B = next_ids.shape[0] // (speculated_ids.shape[1] + 1)
@@ -289,6 +307,9 @@
                     else:
                         break
                 accepted_ids.append(accepted)

+            from loguru import logger
+            logger.info(f"ACCEPTED IDS {accepted_ids}")
             accepted_ids = torch.tensor(accepted_ids, device=input_ids.device, dtype=input_ids.dtype)
             next_ids = next_ids[indices]
             scores = scores[indices]
@@ -298,7 +319,6 @@
         else:
             accepted_ids = torch.ones_like(next_ids)

-
         logprobs = torch.log_softmax(scores, -1)
         next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)

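
The chooser hunks above score every candidate position, then keep a speculated id only while it matches the id the model itself picks, stopping at the first mismatch; the freshly sampled token is always kept, which is what the per-sequence acceptance counts in the debug log reflect. A minimal sketch of that greedy acceptance rule, with illustrative names:

import torch

# Sketch: greedy verification of speculated ids against model choices.
def count_accepted(next_ids: torch.Tensor, speculated_ids: torch.Tensor) -> list:
    # next_ids[b, j]: the model's pick at position j; speculated_ids[b, j]:
    # the draft's proposal. Accept while they agree, then stop.
    accepted_ids = []
    for b in range(next_ids.shape[0]):
        accepted = 1  # the newly sampled token is always kept
        for j in range(speculated_ids.shape[1]):
            if next_ids[b, j] == speculated_ids[b, j]:
                accepted += 1
            else:
                break
        accepted_ids.append(accepted)
    return accepted_ids

next_ids = torch.tensor([[4, 7, 9], [3, 3, 3]])
speculated_ids = torch.tensor([[4, 7, 1], [5, 3, 3]])
print(count_accepted(next_ids, speculated_ids))  # -> [3, 1]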