mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00

Medusa + ngram

This commit is contained in:
parent  b4d97d52cd
commit  657ccd8276
@@ -0,0 +1,58 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "stop_sequence",
    "generated_tokens": 5,
    "prefill": [
      {
        "id": 1,
        "logprob": null,
        "text": "<s>"
      },
      {
        "id": 4321,
        "logprob": -10.0625,
        "text": "Test"
      },
      {
        "id": 2009,
        "logprob": -12.28125,
        "text": "request"
      }
    ],
    "seed": 0,
    "tokens": [
      {
        "id": 5229,
        "logprob": -1.7587891,
        "special": false,
        "text": " failed"
      },
      {
        "id": 363,
        "logprob": -0.5175781,
        "special": false,
        "text": " for"
      },
      {
        "id": 1404,
        "logprob": 0.0,
        "special": false,
        "text": " user"
      },
      {
        "id": 376,
        "logprob": 0.0,
        "special": false,
        "text": " \""
      },
      {
        "id": 1688,
        "logprob": -0.20422363,
        "special": false,
        "text": "test"
      }
    ]
  },
  "generated_text": "Test request failed for user \"test"
}
@@ -0,0 +1,354 @@
[
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        {
          "id": 1,
          "logprob": null,
          "text": "<s>"
        },
        {
          "id": 4321,
          "logprob": -10.0625,
          "text": "Test"
        },
        {
          "id": 2009,
          "logprob": -12.28125,
          "text": "request"
        }
      ],
      "seed": null,
      "tokens": [
        {
          "id": 363,
          "logprob": -2.0878906,
          "special": false,
          "text": " for"
        },
        {
          "id": 278,
          "logprob": -3.4082031,
          "special": false,
          "text": " the"
        },
        {
          "id": 376,
          "logprob": -3.8457031,
          "special": false,
          "text": " \""
        },
        {
          "id": 2577,
          "logprob": -3.5605469,
          "special": false,
          "text": "Get"
        },
        {
          "id": 599,
          "logprob": -3.4707031,
          "special": false,
          "text": " all"
        },
        {
          "id": 4160,
          "logprob": -3.2421875,
          "special": false,
          "text": " users"
        },
        {
          "id": 29908,
          "logprob": -0.49072266,
          "special": false,
          "text": "\""
        },
        {
          "id": 16248,
          "logprob": -1.2353516,
          "special": false,
          "text": " endpoint"
        },
        {
          "id": 29889,
          "logprob": -0.8833008,
          "special": false,
          "text": "."
        },
        {
          "id": 13,
          "logprob": -0.42089844,
          "special": false,
          "text": "\n"
        }
      ]
    },
    "generated_text": " for the \"Get all users\" endpoint.\n"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        {
          "id": 1,
          "logprob": null,
          "text": "<s>"
        },
        {
          "id": 4321,
          "logprob": -10.0625,
          "text": "Test"
        },
        {
          "id": 2009,
          "logprob": -12.28125,
          "text": "request"
        }
      ],
      "seed": null,
      "tokens": [
        {
          "id": 363,
          "logprob": -2.0878906,
          "special": false,
          "text": " for"
        },
        {
          "id": 278,
          "logprob": -3.4082031,
          "special": false,
          "text": " the"
        },
        {
          "id": 376,
          "logprob": -3.8457031,
          "special": false,
          "text": " \""
        },
        {
          "id": 2577,
          "logprob": -3.5625,
          "special": false,
          "text": "Get"
        },
        {
          "id": 599,
          "logprob": -3.4726562,
          "special": false,
          "text": " all"
        },
        {
          "id": 4160,
          "logprob": -3.2382812,
          "special": false,
          "text": " users"
        },
        {
          "id": 29908,
          "logprob": -0.49047852,
          "special": false,
          "text": "\""
        },
        {
          "id": 16248,
          "logprob": -1.2412109,
          "special": false,
          "text": " endpoint"
        },
        {
          "id": 29889,
          "logprob": -0.87402344,
          "special": false,
          "text": "."
        },
        {
          "id": 13,
          "logprob": -0.41723633,
          "special": false,
          "text": "\n"
        }
      ]
    },
    "generated_text": " for the \"Get all users\" endpoint.\n"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        {
          "id": 1,
          "logprob": null,
          "text": "<s>"
        },
        {
          "id": 4321,
          "logprob": -10.0625,
          "text": "Test"
        },
        {
          "id": 2009,
          "logprob": -12.28125,
          "text": "request"
        }
      ],
      "seed": null,
      "tokens": [
        {
          "id": 363,
          "logprob": -2.0878906,
          "special": false,
          "text": " for"
        },
        {
          "id": 278,
          "logprob": -3.4082031,
          "special": false,
          "text": " the"
        },
        {
          "id": 376,
          "logprob": -3.8457031,
          "special": false,
          "text": " \""
        },
        {
          "id": 2577,
          "logprob": -3.5605469,
          "special": false,
          "text": "Get"
        },
        {
          "id": 599,
          "logprob": -3.4707031,
          "special": false,
          "text": " all"
        },
        {
          "id": 4160,
          "logprob": -3.2421875,
          "special": false,
          "text": " users"
        },
        {
          "id": 29908,
          "logprob": -0.49072266,
          "special": false,
          "text": "\""
        },
        {
          "id": 16248,
          "logprob": -1.2353516,
          "special": false,
          "text": " endpoint"
        },
        {
          "id": 29889,
          "logprob": -0.8833008,
          "special": false,
          "text": "."
        },
        {
          "id": 13,
          "logprob": -0.42089844,
          "special": false,
          "text": "\n"
        }
      ]
    },
    "generated_text": " for the \"Get all users\" endpoint.\n"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        {
          "id": 1,
          "logprob": null,
          "text": "<s>"
        },
        {
          "id": 4321,
          "logprob": -10.0625,
          "text": "Test"
        },
        {
          "id": 2009,
          "logprob": -12.28125,
          "text": "request"
        }
      ],
      "seed": null,
      "tokens": [
        {
          "id": 363,
          "logprob": -2.0878906,
          "special": false,
          "text": " for"
        },
        {
          "id": 278,
          "logprob": -3.4082031,
          "special": false,
          "text": " the"
        },
        {
          "id": 376,
          "logprob": -3.8457031,
          "special": false,
          "text": " \""
        },
        {
          "id": 2577,
          "logprob": -3.5605469,
          "special": false,
          "text": "Get"
        },
        {
          "id": 599,
          "logprob": -3.4707031,
          "special": false,
          "text": " all"
        },
        {
          "id": 4160,
          "logprob": -3.2421875,
          "special": false,
          "text": " users"
        },
        {
          "id": 29908,
          "logprob": -0.49072266,
          "special": false,
          "text": "\""
        },
        {
          "id": 16248,
          "logprob": -1.2353516,
          "special": false,
          "text": " endpoint"
        },
        {
          "id": 29889,
          "logprob": -0.8833008,
          "special": false,
          "text": "."
        },
        {
          "id": 13,
          "logprob": -0.42089844,
          "special": false,
          "text": "\n"
        }
      ]
    },
    "generated_text": " for the \"Get all users\" endpoint.\n"
  }
]
@@ -0,0 +1,88 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      {
        "id": 1,
        "logprob": null,
        "text": "<s>"
      },
      {
        "id": 4321,
        "logprob": -10.0625,
        "text": "Test"
      },
      {
        "id": 2009,
        "logprob": -12.28125,
        "text": "request"
      }
    ],
    "seed": null,
    "tokens": [
      {
        "id": 363,
        "logprob": -2.0878906,
        "special": false,
        "text": " for"
      },
      {
        "id": 278,
        "logprob": -3.4121094,
        "special": false,
        "text": " the"
      },
      {
        "id": 376,
        "logprob": -3.8457031,
        "special": false,
        "text": " \""
      },
      {
        "id": 2577,
        "logprob": -3.5566406,
        "special": false,
        "text": "Get"
      },
      {
        "id": 599,
        "logprob": -3.4746094,
        "special": false,
        "text": " all"
      },
      {
        "id": 4160,
        "logprob": -3.2363281,
        "special": false,
        "text": " users"
      },
      {
        "id": 29908,
        "logprob": -0.49023438,
        "special": false,
        "text": "\""
      },
      {
        "id": 16248,
        "logprob": -1.2402344,
        "special": false,
        "text": " endpoint"
      },
      {
        "id": 29889,
        "logprob": -0.88134766,
        "special": false,
        "text": "."
      },
      {
        "id": 13,
        "logprob": -0.41870117,
        "special": false,
        "text": "\n"
      }
    ]
  },
  "generated_text": " for the \"Get all users\" endpoint.\n"
}
integration-tests/models/test_flash_medusa.py (new file, 59 lines)
@@ -0,0 +1,59 @@
import pytest


@pytest.fixture(scope="module")
def flash_medusa_handle(launcher):
    with launcher("FasterDecoding/medusa-vicuna-7b-v1.3", num_shard=2) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_medusa(flash_medusa_handle):
    await flash_medusa_handle.health(300)
    return flash_medusa_handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_medusa_simple(flash_medusa, response_snapshot):
    response = await flash_medusa.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
    response = await flash_medusa.generate(
        "Test request",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 5
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot):
    responses = await generate_load(flash_medusa, "Test request", max_new_tokens=10, n=4)

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}"
    assert responses[0].generated_text == ' for the "Get all users" endpoint.\n'

    assert responses == response_snapshot
@@ -155,6 +155,14 @@ struct Args {
     #[clap(long, env, value_enum)]
     quantize: Option<Quantization>,
 
+    /// The number of input_ids to speculate on
+    /// If using a medusa model, the heads will be picked up automatically
+    /// Otherwise, it will use n-gram speculation, which is relatively free
+    /// in terms of compute, but the speedup heavily depends on the task.
+    #[clap(long, env)]
+    speculate: Option<usize>,
+
     /// The dtype to be forced upon the model. This option cannot be used with `--quantize`.
     #[clap(long, env, value_enum)]
     dtype: Option<Dtype>,
@@ -432,6 +440,11 @@ fn shard_manager(
         shard_args.push(quantize.to_string())
     }
 
+    if let Some(speculate) = speculate {
+        shard_args.push("--speculate".to_string());
+        shard_args.push(speculate.to_string())
+    }
+
     if let Some(dtype) = dtype {
         shard_args.push("--dtype".to_string());
         shard_args.push(dtype.to_string())
@@ -32,6 +32,7 @@ def serve(
     revision: Optional[str] = None,
     sharded: bool = False,
     quantize: Optional[Quantization] = None,
+    speculate: Optional[int] = None,
     dtype: Optional[Dtype] = None,
     trust_remote_code: bool = False,
     uds_path: Path = "/tmp/text-generation-server",
@@ -81,7 +82,7 @@ def serve(
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
         )
     server.serve(
-        model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
+        model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code, uds_path
     )
@@ -77,15 +77,19 @@ except ImportError as e:
 if MISTRAL:
     __all__.append(FlashMistral)
 
+SPECULATE = None
+
 def get_model(
     model_id: str,
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
+    speculate: Optional[int],
     dtype: Optional[str],
     trust_remote_code: bool,
 ) -> Model:
+    global SPECULATE
     if dtype is None:
         # Keep it as default for now and let
         # every model resolve their own default dtype.
@@ -138,9 +142,18 @@ def get_model(
         medusa_config = config_dict
         model_id = config_dict["base_model_name_or_path"]
         revision = "main"
+        SPECULATE = config_dict["medusa_num_heads"]
         config_dict, _ = PretrainedConfig.get_config_dict(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
+        method = "medusa"
+    else:
+        if speculate is not None:
+            SPECULATE = speculate
+        else:
+            SPECULATE = 2
+        method = "n-gram"
+    logger.info(f"Using speculation {method} with {SPECULATE} input ids.")
 
     model_type = config_dict["model_type"]
 
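The branch above boils down to a small decision rule. A minimal standalone sketch (hypothetical helper name, not part of the commit) of how the speculation method and the SPECULATE value are resolved:

from typing import Optional, Tuple


def resolve_speculation(config_dict: dict, speculate: Optional[int]) -> Tuple[str, int]:
    # Hypothetical summary of the branch added to get_model() above.
    if "medusa_num_heads" in config_dict:
        # A Medusa checkpoint declares how many extra heads (speculative ids) it ships.
        return "medusa", config_dict["medusa_num_heads"]
    # Otherwise fall back to n-gram speculation, defaulting to 2 speculative ids.
    return "n-gram", speculate if speculate is not None else 2


# No Medusa config and no --speculate flag -> ("n-gram", 2)
print(resolve_speculation({"model_type": "llama"}, None))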
@@ -450,26 +450,7 @@ class FlashLlamaModel(torch.nn.Module):
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
-        speculative_ids: Optional[torch.Tensor]
     ) -> torch.Tensor:
-        if speculative_ids is not None:
-            speculative_length = speculative_ids.shape[1]
-            new_length = speculative_length + 1
-            new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).squeeze(0)
-            new_position_ids = (position_ids.view((1, -1)).expand(new_length, 1) + torch.arange(new_length).unsqueeze(1).to(device=position_ids.device)).squeeze(0).squeeze(-1)
-
-            # Add an extra block just in case
-            block_tables = torch.cat([block_tables, block_tables[:, -1:] + 1], dim=1)
-            # Add Copy the block tables for all members
-            block_tables = block_tables.expand(new_length, -1).contiguous()
-            slots = slots.expand(new_length) + torch.arange(new_length, dtype=slots.dtype).to(device=slots.device)
-            input_lengths = input_lengths.expand(new_length) + torch.arange(new_length, dtype=input_lengths.dtype).to(device=input_lengths.device)
-            max_s = max_s + speculative_length
-
-            input_ids = new_input_ids
-            position_ids = new_position_ids
-
         hidden_states = self.embed_tokens(input_ids)
 
         # Get rotary cos and sin for this forward
@@ -520,7 +501,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         input_lengths: torch.Tensor,
         max_s: int,
         lm_head_indices: Optional[torch.Tensor] = None,
-        speculative_ids: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
         hidden_states = self.model(
             input_ids,
@@ -531,7 +511,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             slots,
             input_lengths,
             max_s,
-            speculative_ids,
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
@@ -46,7 +46,6 @@ class FlashCausalLMBatch(Batch):
 
     # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
     cu_seqlen_prefill: Optional[torch.Tensor]
-    cu_seqlen_speculative: Optional[torch.Tensor]
 
     # Paged Attention values
 
@@ -123,7 +122,6 @@ class FlashCausalLMBatch(Batch):
         position_ids = []
         speculative_ids = []
         cu_seqlen_prefill = [0]
-        cu_seqlen_speculative = [0]
         needed_blocks_slots = []
         start_slots = []
         slot_indices = []
@@ -163,10 +161,6 @@ class FlashCausalLMBatch(Batch):
 
                 tokenized_input = tokenized_input[-r.truncate :]
 
-            # # TODO remove this
-            # # Scaffolding to speculate some ids
-            # speculate_ids = [1, 2]
-            # tokenized_input.extend([1, 2])
             speculate_ids = []
 
 
@@ -186,7 +180,6 @@ class FlashCausalLMBatch(Batch):
 
             # Add cumulative lengths of all previous inputs
             cu_seqlen_prefill.append(cumulative_length + input_length)
-            cu_seqlen_speculative.append(cumulative_length + input_length - len(speculate_ids))
 
             next_token_chooser_parameters.append(r.parameters)
 
@@ -199,7 +192,9 @@ class FlashCausalLMBatch(Batch):
 
             # Paged attention
             # Remove one as the first token does not have a past
-            total_tokens = input_length + max_new_tokens - 1
+            from text_generation_server.models import SPECULATE
+            speculative_length = SPECULATE
+            total_tokens = input_length + max_new_tokens - 1 + speculative_length
             needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
             blocks += needed_blocks
             needed_blocks_slots.append((needed_blocks, total_tokens))
@@ -268,10 +263,6 @@ class FlashCausalLMBatch(Batch):
         cu_seqlen_prefill = torch.tensor(
             cu_seqlen_prefill, device=device, dtype=torch.int32
         )
-        cu_seqlen_speculative = torch.tensor(
-            cu_seqlen_speculative, device=device, dtype=torch.int32
-        )
 
         position_ids = position_ids.to(device)
         slot_indices = slot_indices.to(device)
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
@@ -303,7 +294,6 @@ class FlashCausalLMBatch(Batch):
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
-            cu_seqlen_speculative=cu_seqlen_speculative,
             start_slots=start_slots,
             slot_indices=slot_indices,
             needed_blocks_slots=needed_blocks_slots,
@@ -437,6 +427,7 @@ class FlashCausalLMBatch(Batch):
         slots = self.slots[slot_filtering_indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
         top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
+        speculative_ids = self.speculative_ids[indices]
 
         start_slots = torch.tensor(start_slots, dtype=torch.int64)
 
@@ -472,6 +463,7 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens_tensor=top_n_tokens_tensor,
             blocks=blocks,
             max_blocks=max_blocks,
+            speculative_ids=speculative_ids,
         )
 
     @classmethod
@@ -595,6 +587,8 @@ class FlashCausalLMBatch(Batch):
             device=batches[0].next_token_chooser.device,
         )
 
+        speculative_ids = None if batches[0].speculative_ids is None else torch.cat([b.speculative_ids for b in batches], dim=0)
+
         # Needed to avoid dropping blocks when the batches will go out of scope
         for b in batches:
             b.block_tables = None
@@ -629,6 +623,7 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens_tensor=top_n_tokens_tensor,
             blocks=blocks,
             max_blocks=max_blocks,
+            speculative_ids=speculative_ids
         )
 
     def __del__(self):
@@ -732,17 +727,55 @@ class FlashCausalLM(Model):
 
     def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
+        if batch.speculative_ids is not None:
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = get_cache_manager().kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
+
+            speculative_ids = batch.speculative_ids
+
+            B, speculative_length = speculative_ids.shape
+            new_length = speculative_length + 1
+            new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
+            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
+            arange_int = arange.to(dtype=torch.int32)
+            new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
+            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+            input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+
+            # Copy the block tables for all members
+            block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B * new_length, -1).contiguous()
+            max_s = max_s + speculative_length
+
+            input_ids = new_input_ids
+            position_ids = new_position_ids
+        else:
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = get_cache_manager().kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
+
         return self.model.forward(
-            input_ids=batch.input_ids,
-            position_ids=batch.position_ids,
-            cu_seqlen_prefill=batch.cu_seqlen_prefill,
-            kv_cache=get_cache_manager().kv_cache,
-            block_tables=batch.block_tables_tensor,
-            slots=batch.slots[batch.slot_indices],
-            input_lengths=batch.input_lengths_tensor,
-            max_s=batch.max_seqlen,
-            lm_head_indices=batch.prefill_head_indices,
-            speculative_ids=batch.speculative_ids
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=cu_seqlen_prefill,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            input_lengths=input_lengths,
+            max_s=max_s,
+            lm_head_indices=lm_head_indices,
         )
 
     @tracer.start_as_current_span("generate_token")
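To make the expansion above concrete, here is a self-contained sketch with toy numbers (not tied to a real batch): for B sequences carrying S speculative ids each, every per-token tensor is repeated S + 1 times and offset by 0..S, so the verified token and the speculated ones are scored in a single forward pass.

import torch

B, speculative_length = 2, 2            # toy batch of 2 sequences, 2 speculated ids each
new_length = speculative_length + 1

input_ids = torch.tensor([11, 21])                  # last verified token per sequence
speculative_ids = torch.tensor([[12, 13], [22, 23]])
position_ids = torch.tensor([5, 9])                 # next position per sequence
slots = torch.tensor([100, 200], dtype=torch.int32)

arange = torch.arange(new_length).unsqueeze(0)

new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).reshape(-1)
new_slots = (slots.unsqueeze(-1).expand(B, new_length) + arange.to(torch.int32)).reshape(-1)

print(new_input_ids)     # tensor([11, 12, 13, 21, 22, 23])
print(new_position_ids)  # tensor([ 5,  6,  7,  9, 10, 11])
print(new_slots)         # tensor([100, 101, 102, 200, 201, 202], dtype=torch.int32)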
@@ -792,8 +825,9 @@ class FlashCausalLM(Model):
 
         # if next_token_logits.shape[0] == 3:
         #     import ipdb;ipdb.set_trace()
+        from text_generation_server.models import SPECULATE
         next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
-            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, batch.speculative_ids, speculative_logits
+            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, SPECULATE, batch.speculative_ids, speculative_logits
         )
 
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
@@ -807,14 +841,8 @@ class FlashCausalLM(Model):
             # When batch == 1, we will just use the batch.input_ids values directly
             prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
 
-            if speculative_ids is not None:
-                # length = len(batch) * (1 + speculative_length)
-                length = len(batch)
-            else:
-                length = len(batch)
-            # import ipdb;ipdb.set_trace()
+            length = len(batch)
             next_position_ids = batch.position_ids.new_empty(length)
-            # Keep only 1 slot index, TODO make sure we recover the speculated ids slots later
             batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
             # We do not need cu_seqlen_prefill anymore
             batch.cu_seqlen_prefill = None
@@ -885,19 +913,17 @@ class FlashCausalLM(Model):
         # if accepted_ids[0] > 1:
         #     import ipdb;ipdb.set_trace()
 
-        if len(accepted_ids) > 1:
-            raise Exception("Implemtent the batched behavior")
+        # if len(accepted_ids) > 1:
+        #     raise Exception("Implement the batched behavior")
 
         # Set values in batch
         # batch.input_ids = torch.cat([next_input_ids.unsqueeze(-1), speculative_ids], dim=1).view(-1)
 
-        for n_accepted_ids in accepted_ids:
-            # TODO Make this batched
-            batch.input_ids = next_input_ids[-1:]
-            batch.speculative_ids = speculative_ids
-            batch.position_ids = next_position_ids + n_accepted_ids
-            batch.input_lengths_tensor += n_accepted_ids
-            batch.slot_indices += n_accepted_ids
+        batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
+        batch.speculative_ids = speculative_ids
+        batch.position_ids = next_position_ids + accepted_ids
+        batch.input_lengths_tensor += accepted_ids
+        batch.slot_indices += accepted_ids
 
         if prefill and prefill_logprobs:
             # Get prefill logprobs
@@ -962,6 +988,7 @@ class FlashCausalLM(Model):
                 read_offset,
             )
             next_token_texts.append(next_token_text)
+            index += n_accepted_ids
 
             # Evaluate stopping criteria
 
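The `accepted_ids.cumsum(dim=-1) - 1` indexing above can be checked on made-up numbers (a standalone sketch): `next_input_ids` is the flat list of tokens kept for the whole batch, and the cumulative sum points at the last accepted token of each sequence.

import torch

# Flat tokens accepted across a batch of 3 sequences: 2, 1 and 3 tokens accepted.
next_input_ids = torch.tensor([101, 102, 201, 301, 302, 303])
accepted_ids = torch.tensor([2, 1, 3])

last_accepted = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
print(last_accepted)  # tensor([102, 201, 303]) -> one "current" token per sequence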
@@ -272,6 +272,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
             blocks=blocks,
             max_blocks=max_blocks,
             prefill_cache_indices=prefill_cache_indices,
+            speculative_ids=None
         )
 
 
@@ -132,6 +132,7 @@ def serve(
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
+    speculate: Optional[int],
     dtype: Optional[str],
     trust_remote_code: bool,
     uds_path: Path,
@@ -141,6 +142,7 @@ def serve(
         revision: Optional[str],
         sharded: bool = False,
         quantize: Optional[str] = None,
+        speculate: Optional[int] = None,
         dtype: Optional[str] = None,
         trust_remote_code: bool = False,
     ):
@@ -157,7 +159,7 @@ def serve(
 
         try:
             model = get_model(
-                model_id, revision, sharded, quantize, dtype, trust_remote_code
+                model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
             )
         except Exception:
             logger.exception("Error when initializing model")
@@ -205,5 +207,5 @@ def serve(
         await server.stop(0)
 
     asyncio.run(
-        serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
+        serve_inner(model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code)
     )
@@ -147,6 +147,49 @@ class StoppingCriteria:
         )
 
 
+def longest_match(input_ids: List[int]) -> Optional[int]:
+    # Return the index at which the longest earlier occurrence of the current
+    # suffix of input_ids ends; 0 if no match was found.
+    longest_match = 0
+    seed = input_ids[-1]
+    final_matches = []
+    current_matches = []
+    for i in range(1, len(input_ids)):
+        index = len(input_ids) - i - 1
+
+        _current_matches = []
+        for (_index, length) in current_matches:
+            if input_ids[index] == input_ids[len(input_ids) - length - 1]:
+                _current_matches.append((_index, length + 1))
+            elif length > longest_match:
+                longest_match = length
+                final_matches.append((_index, length))
+            else:
+                pass
+        current_matches = _current_matches
+
+        if input_ids[index] == seed:
+            current_matches.append((index, 1))
+    if not final_matches:
+        return 0
+    return final_matches[-1][0]
+
+
+def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int):
+    B = accepted_ids.shape[0]
+    device = input_ids.device
+    dtype = input_ids.dtype
+    speculative_ids = torch.zeros((B, speculate), device=device, dtype=dtype)
+    input_ids = input_ids.tolist()
+
+    index = 0
+    for i, (_input_ids, n_accepted_ids) in enumerate(zip(input_ids, accepted_ids.tolist())):
+        _input_ids.extend(next_ids[index : index + n_accepted_ids].tolist())
+        # Use a separate cursor for the match so the running offset into next_ids is preserved.
+        match_index = longest_match(_input_ids) + 1
+        ids = _input_ids[match_index : match_index + speculate]
+        speculative_ids[i, : len(ids)] = torch.tensor(ids, device=device, dtype=dtype)
+        index += n_accepted_ids
+    return speculative_ids
+
+
 class HeterogeneousNextTokenChooser:
     def __init__(
         self,
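As a rough illustration of the n-gram idea (a simplified standalone sketch, not the helpers above): the history is scanned for the longest earlier occurrence of its current suffix, and the tokens that followed that occurrence are proposed as speculative ids; the proposal costs no extra model compute.

from typing import List


def propose_ngram(history: List[int], speculate: int) -> List[int]:
    # Find the longest earlier occurrence of the current suffix of `history`
    # and return up to `speculate` tokens that followed it (possibly fewer).
    best_start, best_len = 0, 0
    for start in range(len(history) - 1):
        length = 0
        while (
            length <= start
            and history[start - length] == history[len(history) - 1 - length]
        ):
            length += 1
        if length > best_len:
            best_start, best_len = start, length
    if best_len == 0:
        return []
    return history[best_start + 1 : best_start + 1 + speculate]


# "1 2 3 4" already appeared earlier and was followed by "5 6",
# so after "... 1 2 3 4" we speculate [5, 6] for free.
print(propose_ngram([1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4], speculate=2))  # [5, 6]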
@@ -215,7 +258,7 @@ class HeterogeneousNextTokenChooser:
         self.dtype = dtype
         self.device = device
 
-    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None):
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculate: int, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None):
         if self.watermark_processor is not None:
             scores = self.watermark_processor(input_ids, scores)
         if self.repetition_processor is not None:
@@ -225,40 +268,51 @@ class HeterogeneousNextTokenChooser:
             scores = warper(input_ids, scores)
 
 
-        accepted_ids = []
         next_ids = self.choice(scores)
         if speculated_ids is not None:
-            validate_speculative = next_ids[:-1] == speculated_ids[0]
-            index = 1
-            for valid in validate_speculative.tolist():
-                if valid:
-                    index += 1
-                    # print(f"Validated {index - 1}")
-            next_ids = next_ids[:index]
-            scores = scores[:index]
-            speculative_scores = speculative_scores[index - 1:index]
-            accepted_ids.append(index)
+            accepted_ids = []
+            B = next_ids.shape[0] // (speculated_ids.shape[1] + 1)
+            S = speculated_ids.shape[1] + 1
+            indices = []
+            for i in range(B):
+                _next_ids = next_ids[i * S : (i + 1) * S]
+                _speculated_ids = speculated_ids[i]
+                validate_speculative = _next_ids[:-1] == _speculated_ids
+                index = i * S
+                accepted = 1
+                # First is always valid
+                indices.append(index)
+                for valid in validate_speculative.tolist():
+                    if valid:
+                        index += 1
+                        accepted += 1
+                        indices.append(index)
+                    else:
+                        break
+                accepted_ids.append(accepted)
+            accepted_ids = torch.tensor(accepted_ids, device=input_ids.device, dtype=input_ids.dtype)
+            next_ids = next_ids[indices]
+            scores = scores[indices]
+            indices = torch.arange(B, device=input_ids.device) * S
+            if speculative_scores is not None:
+                speculative_scores = speculative_scores[indices + accepted_ids - 1]
         else:
-            accepted_ids.append(1)
+            accepted_ids = torch.ones_like(next_ids)
 
 
         logprobs = torch.log_softmax(scores, -1)
         next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
 
-        if speculative_scores is not None:
-            # length, spec_length, vocab_size = speculative_scores.shape
-            # speculative_scores = speculative_scores.view((-1, vocab_size))
-            # if self.watermark_processor is not None:
-            #     speculative_scores = self.watermark_processor(input_ids, speculative_scores)
-            # if self.repetition_processor is not None:
-            #     speculative_scores = self.repetition_processor(input_ids, speculative_scores)
-
-            # speculative_scores = speculative_scores.view((length, spec_length, vocab_size))
-            # for warper in self.warpers:
-            #     speculative_scores = warper(input_ids, speculative_scores)
-            speculative_ids = Greedy()(speculative_scores)
-            # # Ignore first head, it seems to be a regular head.
-            # speculative_ids = speculative_ids[:, 1:]
+        if speculate > 0:
+            if speculative_scores is not None:
+                # TODO This will only speculate the top score
+                # Medusa provided some scores
+                speculative_ids = Greedy()(speculative_scores)
+            else:
+                # n-gram
+                speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate)
         else:
             speculative_ids = None
 
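The batched acceptance loop above can be exercised on toy tensors (a standalone sketch mirroring its logic): the first sampled token of each sequence is always kept, and each speculated token is kept only while it matches the token the model actually sampled at that position.

import torch


def count_accepted(next_ids: torch.Tensor, speculated_ids: torch.Tensor):
    # Sketch of the validation in HeterogeneousNextTokenChooser.__call__:
    # next_ids holds S = speculated + 1 sampled tokens per sequence, flattened.
    B, speculated = speculated_ids.shape
    S = speculated + 1
    accepted_ids, indices = [], []
    for i in range(B):
        _next_ids = next_ids[i * S : (i + 1) * S]
        valid = (_next_ids[:-1] == speculated_ids[i]).tolist()
        index = i * S
        accepted = 1
        indices.append(index)  # the first token is always kept
        for ok in valid:
            if not ok:
                break
            index += 1
            accepted += 1
            indices.append(index)
        accepted_ids.append(accepted)
    return torch.tensor(accepted_ids), next_ids[indices]


# Sequence 0: both speculated tokens were confirmed; sequence 1: the first one already diverged.
next_ids = torch.tensor([5, 6, 7, 9, 4, 2])      # B=2, S=3
speculated_ids = torch.tensor([[5, 6], [8, 4]])
print(count_accepted(next_ids, speculated_ids))  # (tensor([3, 1]), tensor([5, 6, 7, 9]))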