Medusa + ngram

This commit is contained in:
Nicolas Patry 2023-12-01 17:57:20 +00:00
parent b4d97d52cd
commit 657ccd8276
12 changed files with 740 additions and 91 deletions

View File

@ -0,0 +1,58 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "stop_sequence",
"generated_tokens": 5,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -10.0625,
"text": "Test"
},
{
"id": 2009,
"logprob": -12.28125,
"text": "request"
}
],
"seed": 0,
"tokens": [
{
"id": 5229,
"logprob": -1.7587891,
"special": false,
"text": " failed"
},
{
"id": 363,
"logprob": -0.5175781,
"special": false,
"text": " for"
},
{
"id": 1404,
"logprob": 0.0,
"special": false,
"text": " user"
},
{
"id": 376,
"logprob": 0.0,
"special": false,
"text": " \""
},
{
"id": 1688,
"logprob": -0.20422363,
"special": false,
"text": "test"
}
]
},
"generated_text": "Test request failed for user \"test"
}

View File

@ -0,0 +1,354 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -10.0625,
"text": "Test"
},
{
"id": 2009,
"logprob": -12.28125,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 363,
"logprob": -2.0878906,
"special": false,
"text": " for"
},
{
"id": 278,
"logprob": -3.4082031,
"special": false,
"text": " the"
},
{
"id": 376,
"logprob": -3.8457031,
"special": false,
"text": " \""
},
{
"id": 2577,
"logprob": -3.5605469,
"special": false,
"text": "Get"
},
{
"id": 599,
"logprob": -3.4707031,
"special": false,
"text": " all"
},
{
"id": 4160,
"logprob": -3.2421875,
"special": false,
"text": " users"
},
{
"id": 29908,
"logprob": -0.49072266,
"special": false,
"text": "\""
},
{
"id": 16248,
"logprob": -1.2353516,
"special": false,
"text": " endpoint"
},
{
"id": 29889,
"logprob": -0.8833008,
"special": false,
"text": "."
},
{
"id": 13,
"logprob": -0.42089844,
"special": false,
"text": "\n"
}
]
},
"generated_text": " for the \"Get all users\" endpoint.\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -10.0625,
"text": "Test"
},
{
"id": 2009,
"logprob": -12.28125,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 363,
"logprob": -2.0878906,
"special": false,
"text": " for"
},
{
"id": 278,
"logprob": -3.4082031,
"special": false,
"text": " the"
},
{
"id": 376,
"logprob": -3.8457031,
"special": false,
"text": " \""
},
{
"id": 2577,
"logprob": -3.5625,
"special": false,
"text": "Get"
},
{
"id": 599,
"logprob": -3.4726562,
"special": false,
"text": " all"
},
{
"id": 4160,
"logprob": -3.2382812,
"special": false,
"text": " users"
},
{
"id": 29908,
"logprob": -0.49047852,
"special": false,
"text": "\""
},
{
"id": 16248,
"logprob": -1.2412109,
"special": false,
"text": " endpoint"
},
{
"id": 29889,
"logprob": -0.87402344,
"special": false,
"text": "."
},
{
"id": 13,
"logprob": -0.41723633,
"special": false,
"text": "\n"
}
]
},
"generated_text": " for the \"Get all users\" endpoint.\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -10.0625,
"text": "Test"
},
{
"id": 2009,
"logprob": -12.28125,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 363,
"logprob": -2.0878906,
"special": false,
"text": " for"
},
{
"id": 278,
"logprob": -3.4082031,
"special": false,
"text": " the"
},
{
"id": 376,
"logprob": -3.8457031,
"special": false,
"text": " \""
},
{
"id": 2577,
"logprob": -3.5605469,
"special": false,
"text": "Get"
},
{
"id": 599,
"logprob": -3.4707031,
"special": false,
"text": " all"
},
{
"id": 4160,
"logprob": -3.2421875,
"special": false,
"text": " users"
},
{
"id": 29908,
"logprob": -0.49072266,
"special": false,
"text": "\""
},
{
"id": 16248,
"logprob": -1.2353516,
"special": false,
"text": " endpoint"
},
{
"id": 29889,
"logprob": -0.8833008,
"special": false,
"text": "."
},
{
"id": 13,
"logprob": -0.42089844,
"special": false,
"text": "\n"
}
]
},
"generated_text": " for the \"Get all users\" endpoint.\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -10.0625,
"text": "Test"
},
{
"id": 2009,
"logprob": -12.28125,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 363,
"logprob": -2.0878906,
"special": false,
"text": " for"
},
{
"id": 278,
"logprob": -3.4082031,
"special": false,
"text": " the"
},
{
"id": 376,
"logprob": -3.8457031,
"special": false,
"text": " \""
},
{
"id": 2577,
"logprob": -3.5605469,
"special": false,
"text": "Get"
},
{
"id": 599,
"logprob": -3.4707031,
"special": false,
"text": " all"
},
{
"id": 4160,
"logprob": -3.2421875,
"special": false,
"text": " users"
},
{
"id": 29908,
"logprob": -0.49072266,
"special": false,
"text": "\""
},
{
"id": 16248,
"logprob": -1.2353516,
"special": false,
"text": " endpoint"
},
{
"id": 29889,
"logprob": -0.8833008,
"special": false,
"text": "."
},
{
"id": 13,
"logprob": -0.42089844,
"special": false,
"text": "\n"
}
]
},
"generated_text": " for the \"Get all users\" endpoint.\n"
}
]

View File

@ -0,0 +1,88 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -10.0625,
"text": "Test"
},
{
"id": 2009,
"logprob": -12.28125,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 363,
"logprob": -2.0878906,
"special": false,
"text": " for"
},
{
"id": 278,
"logprob": -3.4121094,
"special": false,
"text": " the"
},
{
"id": 376,
"logprob": -3.8457031,
"special": false,
"text": " \""
},
{
"id": 2577,
"logprob": -3.5566406,
"special": false,
"text": "Get"
},
{
"id": 599,
"logprob": -3.4746094,
"special": false,
"text": " all"
},
{
"id": 4160,
"logprob": -3.2363281,
"special": false,
"text": " users"
},
{
"id": 29908,
"logprob": -0.49023438,
"special": false,
"text": "\""
},
{
"id": 16248,
"logprob": -1.2402344,
"special": false,
"text": " endpoint"
},
{
"id": 29889,
"logprob": -0.88134766,
"special": false,
"text": "."
},
{
"id": 13,
"logprob": -0.41870117,
"special": false,
"text": "\n"
}
]
},
"generated_text": " for the \"Get all users\" endpoint.\n"
}

View File

@ -0,0 +1,59 @@
import pytest
@pytest.fixture(scope="module")
def flash_medusa_handle(launcher):
with launcher("FasterDecoding/medusa-vicuna-7b-v1.3", num_shard=2) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_medusa(flash_medusa_handle):
await flash_medusa_handle.health(300)
return flash_medusa_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_medusa_simple(flash_medusa, response_snapshot):
response = await flash_medusa.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
response = await flash_medusa.generate(
"Test request",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["test"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 5
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot):
responses = await generate_load(flash_medusa, "Test request", max_new_tokens=10, n=4)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}"
assert responses[0].generated_text == ' for the "Get all users" endpoint.\n'
assert responses == response_snapshot

View File

@ -155,6 +155,14 @@ struct Args {
#[clap(long, env, value_enum)]
quantize: Option<Quantization>,
/// The number of input_ids to speculate on
/// If using a medusa model, the heads will be picked up automatically
/// Other wise, it will use n-gram speculation which is relatively free
/// in terms of compute, but the speedup heavily depends on the task.
#[clap(long, env)]
speculate: Option<usize>,
/// The dtype to be forced upon the model. This option cannot be used with `--quantize`.
#[clap(long, env, value_enum)]
dtype: Option<Dtype>,
@ -432,6 +440,11 @@ fn shard_manager(
shard_args.push(quantize.to_string())
}
if let Some(speculate) = speculate {
shard_args.push("--speculate".to_string());
shard_args.push(speculate.to_string())
}
if let Some(dtype) = dtype {
shard_args.push("--dtype".to_string());
shard_args.push(dtype.to_string())

View File

@ -32,6 +32,7 @@ def serve(
revision: Optional[str] = None,
sharded: bool = False,
quantize: Optional[Quantization] = None,
speculate: Optional[int] = None,
dtype: Optional[Dtype] = None,
trust_remote_code: bool = False,
uds_path: Path = "/tmp/text-generation-server",
@ -81,7 +82,7 @@ def serve(
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
)
server.serve(
model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code, uds_path
)

View File

@ -77,15 +77,19 @@ except ImportError as e:
if MISTRAL:
__all__.append(FlashMistral)
SPECULATE = None
def get_model(
model_id: str,
revision: Optional[str],
sharded: bool,
quantize: Optional[str],
speculate: Optional[int],
dtype: Optional[str],
trust_remote_code: bool,
) -> Model:
global SPECULATE
if dtype is None:
# Keep it as default for now and let
# every model resolve their own default dtype.
@ -138,9 +142,18 @@ def get_model(
medusa_config = config_dict
model_id = config_dict["base_model_name_or_path"]
revision = "main"
SPECULATE = config_dict["medusa_num_heads"]
config_dict, _ = PretrainedConfig.get_config_dict(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
method = "medusa"
else:
if speculate is not None:
SPECULATE = speculate
else:
SPECULATE = 2
method = "n-gram"
logger.info(f"Using speculation {method} with {SPECULATE} input ids.")
model_type = config_dict["model_type"]

View File

@ -450,26 +450,7 @@ class FlashLlamaModel(torch.nn.Module):
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
speculative_ids: Optional[torch.Tensor]
) -> torch.Tensor:
if speculative_ids is not None:
speculative_length = speculative_ids.shape[1]
new_length = speculative_length + 1
new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).squeeze(0)
new_position_ids = (position_ids.view((1, -1)).expand(new_length, 1) + torch.arange(new_length).unsqueeze(1).to(device=position_ids.device)).squeeze(0).squeeze(-1)
# Add an extra block just in case
block_tables = torch.cat([block_tables, block_tables[:, -1:] + 1], dim=1)
# Add Copy the block tables for all members
block_tables = block_tables.expand(new_length, -1).contiguous()
slots = slots.expand(new_length) + torch.arange(new_length, dtype=slots.dtype).to(device=slots.device)
input_lengths = input_lengths.expand(new_length) + torch.arange(new_length, dtype=input_lengths.dtype).to(device=input_lengths.device)
max_s = max_s + speculative_length
input_ids = new_input_ids
position_ids = new_position_ids
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
@ -520,7 +501,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
input_lengths: torch.Tensor,
max_s: int,
lm_head_indices: Optional[torch.Tensor] = None,
speculative_ids: Optional[torch.Tensor] = None
) -> torch.Tensor:
hidden_states = self.model(
input_ids,
@ -531,7 +511,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
slots,
input_lengths,
max_s,
speculative_ids,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -46,7 +46,6 @@ class FlashCausalLMBatch(Batch):
# tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
cu_seqlen_prefill: Optional[torch.Tensor]
cu_seqlen_speculative: Optional[torch.Tensor]
# Paged Attention values
@ -123,7 +122,6 @@ class FlashCausalLMBatch(Batch):
position_ids = []
speculative_ids = []
cu_seqlen_prefill = [0]
cu_seqlen_speculative = [0]
needed_blocks_slots = []
start_slots = []
slot_indices = []
@ -163,10 +161,6 @@ class FlashCausalLMBatch(Batch):
tokenized_input = tokenized_input[-r.truncate :]
# # TODO remove this
# # Scaffolding to speculate some ids
# speculate_ids = [1, 2]
# tokenized_input.extend([1, 2])
speculate_ids = []
@ -186,7 +180,6 @@ class FlashCausalLMBatch(Batch):
# Add cumulative lengths of all previous inputs
cu_seqlen_prefill.append(cumulative_length + input_length)
cu_seqlen_speculative.append(cumulative_length + input_length - len(speculate_ids))
next_token_chooser_parameters.append(r.parameters)
@ -199,7 +192,9 @@ class FlashCausalLMBatch(Batch):
# Paged attention
# Remove one as the first token des not have a past
total_tokens = input_length + max_new_tokens - 1
from text_generation_server.models import SPECULATE
speculative_length = SPECULATE
total_tokens = input_length + max_new_tokens - 1 + speculative_length
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
blocks += needed_blocks
needed_blocks_slots.append((needed_blocks, total_tokens))
@ -268,10 +263,6 @@ class FlashCausalLMBatch(Batch):
cu_seqlen_prefill = torch.tensor(
cu_seqlen_prefill, device=device, dtype=torch.int32
)
cu_seqlen_speculative = torch.tensor(
cu_seqlen_speculative, device=device, dtype=torch.int32
)
position_ids = position_ids.to(device)
slot_indices = slot_indices.to(device)
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
@ -303,7 +294,6 @@ class FlashCausalLMBatch(Batch):
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
cu_seqlen_speculative=cu_seqlen_speculative,
start_slots=start_slots,
slot_indices=slot_indices,
needed_blocks_slots=needed_blocks_slots,
@ -437,6 +427,7 @@ class FlashCausalLMBatch(Batch):
slots = self.slots[slot_filtering_indices]
next_token_chooser = self.next_token_chooser.filter(indices)
top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
speculative_ids = self.speculative_ids[indices]
start_slots = torch.tensor(start_slots, dtype=torch.int64)
@ -472,6 +463,7 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks,
max_blocks=max_blocks,
speculative_ids=speculative_ids,
)
@classmethod
@ -595,6 +587,8 @@ class FlashCausalLMBatch(Batch):
device=batches[0].next_token_chooser.device,
)
speculative_ids = None if batches[0].speculative_ids is None else torch.cat([b.speculative_ids for b in batches], dim=0)
# Needed to avoid dropping blocks when the batches will go out of scope
for b in batches:
b.block_tables = None
@ -629,6 +623,7 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks,
max_blocks=max_blocks,
speculative_ids=speculative_ids
)
def __del__(self):
@ -732,17 +727,55 @@ class FlashCausalLM(Model):
def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
# Model Forward
return self.model.forward(
input_ids=batch.input_ids,
position_ids=batch.position_ids,
cu_seqlen_prefill=batch.cu_seqlen_prefill,
kv_cache=get_cache_manager().kv_cache,
block_tables=batch.block_tables_tensor,
slots=batch.slots[batch.slot_indices],
input_lengths=batch.input_lengths_tensor,
max_s=batch.max_seqlen,
lm_head_indices=batch.prefill_head_indices,
if batch.speculative_ids is not None:
input_ids=batch.input_ids
position_ids=batch.position_ids
cu_seqlen_prefill=batch.cu_seqlen_prefill
kv_cache=get_cache_manager().kv_cache
block_tables=batch.block_tables_tensor
slots=batch.slots[batch.slot_indices]
input_lengths=batch.input_lengths_tensor
max_s=batch.max_seqlen
lm_head_indices=batch.prefill_head_indices
speculative_ids = batch.speculative_ids
B, speculative_length = speculative_ids.shape
new_length = speculative_length + 1
new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
arange_int = arange.to(dtype=torch.int32)
new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
# Add Copy the block tables for all members
block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B* new_length, -1).contiguous()
max_s = max_s + speculative_length
input_ids = new_input_ids
position_ids = new_position_ids
else:
input_ids=batch.input_ids
position_ids=batch.position_ids
cu_seqlen_prefill=batch.cu_seqlen_prefill
kv_cache=get_cache_manager().kv_cache
block_tables=batch.block_tables_tensor
slots=batch.slots[batch.slot_indices]
input_lengths=batch.input_lengths_tensor
max_s=batch.max_seqlen
lm_head_indices=batch.prefill_head_indices
return self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
lm_head_indices=lm_head_indices,
)
@tracer.start_as_current_span("generate_token")
@ -792,8 +825,9 @@ class FlashCausalLM(Model):
# if next_token_logits.shape[0] == 3:
# import ipdb;ipdb.set_trace()
from text_generation_server.models import SPECULATE
next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, batch.speculative_ids, speculative_logits
batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, SPECULATE, batch.speculative_ids, speculative_logits
)
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
@ -807,14 +841,8 @@ class FlashCausalLM(Model):
# When batch == 1, we will just use the batch.input_ids values directly
prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
if speculative_ids is not None:
# length = len(batch) * (1 + speculative_length)
length = len(batch)
else:
length = len(batch)
# import ipdb;ipdb.set_trace()
next_position_ids = batch.position_ids.new_empty(length)
# Keep only 1 slot index, TODO make sure we recover the speculated ids slots later
batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
# We do not need cu_seqlen_prefill anymore
batch.cu_seqlen_prefill = None
@ -885,19 +913,17 @@ class FlashCausalLM(Model):
# if accepted_ids[0] > 1:
# import ipdb;ipdb.set_trace()
if len(accepted_ids) > 1:
raise Exception("Implemtent the batched behavior")
# if len(accepted_ids) > 1:
# raise Exception("Implemtent the batched behavior")
# Set values in batch
# batch.input_ids = torch.cat([next_input_ids.unsqueeze(-1), speculative_ids], dim=1).view(-1)
for n_accepted_ids in accepted_ids:
# TODO Make this batched
batch.input_ids = next_input_ids[-1:]
batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
batch.speculative_ids = speculative_ids
batch.position_ids = next_position_ids + n_accepted_ids
batch.input_lengths_tensor += n_accepted_ids
batch.slot_indices += n_accepted_ids
batch.position_ids = next_position_ids + accepted_ids
batch.input_lengths_tensor += accepted_ids
batch.slot_indices += accepted_ids
if prefill and prefill_logprobs:
# Get prefill logprobs
@ -962,6 +988,7 @@ class FlashCausalLM(Model):
read_offset,
)
next_token_texts.append(next_token_text)
index += n_accepted_ids
# Evaluate stopping criteria

View File

@ -272,6 +272,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
blocks=blocks,
max_blocks=max_blocks,
prefill_cache_indices=prefill_cache_indices,
speculative_ids=None
)

View File

@ -132,6 +132,7 @@ def serve(
revision: Optional[str],
sharded: bool,
quantize: Optional[str],
speculate: Optional[int],
dtype: Optional[str],
trust_remote_code: bool,
uds_path: Path,
@ -141,6 +142,7 @@ def serve(
revision: Optional[str],
sharded: bool = False,
quantize: Optional[str] = None,
speculate: Optional[int] = None,
dtype: Optional[str] = None,
trust_remote_code: bool = False,
):
@ -157,7 +159,7 @@ def serve(
try:
model = get_model(
model_id, revision, sharded, quantize, dtype, trust_remote_code
model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
)
except Exception:
logger.exception("Error when initializing model")
@ -205,5 +207,5 @@ def serve(
await server.stop(0)
asyncio.run(
serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
serve_inner(model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code)
)

View File

@ -147,6 +147,49 @@ class StoppingCriteria:
)
def longest_match(input_ids: List[int]) -> Optional[int]:
longest_match = 0
seed = input_ids[-1]
final_matches = []
current_matches = []
for i in range(1, len(input_ids)):
index = len(input_ids) - i - 1
_current_matches = []
for (_index, length) in current_matches:
if input_ids[index] == input_ids[len(input_ids) - length - 1]:
_current_matches.append((_index, length + 1))
elif length > longest_match:
longest_match = length
final_matches.append((_index, length))
else:
pass
current_matches = _current_matches
if input_ids[index] == seed:
current_matches.append( (index, 1) )
if not final_matches:
return 0
return final_matches[-1][0]
def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int):
B = accepted_ids.shape[0]
device = input_ids.device
dtype = input_ids.dtype
speculative_ids = torch.zeros((B, speculate), device=device, dtype=dtype)
input_ids = input_ids.tolist()
index = 0
for i, (_input_ids, n_accepted_ids) in enumerate(zip(input_ids, accepted_ids.tolist())):
_input_ids.extend(next_ids[index: index + n_accepted_ids].tolist())
index = longest_match(_input_ids) + 1
ids = _input_ids[index:index+speculate]
speculative_ids[i, :len(ids)] = torch.tensor(ids, device=device, dtype=dtype)
index += n_accepted_ids
return speculative_ids
class HeterogeneousNextTokenChooser:
def __init__(
self,
@ -215,7 +258,7 @@ class HeterogeneousNextTokenChooser:
self.dtype = dtype
self.device = device
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None):
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculate: int, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None):
if self.watermark_processor is not None:
scores = self.watermark_processor(input_ids, scores)
if self.repetition_processor is not None:
@ -225,40 +268,51 @@ class HeterogeneousNextTokenChooser:
scores = warper(input_ids, scores)
accepted_ids = []
next_ids = self.choice(scores)
if speculated_ids is not None:
validate_speculative = next_ids[:-1] == speculated_ids[0]
index = 1
accepted_ids = []
B = next_ids.shape[0] // (speculated_ids.shape[1] + 1)
S = speculated_ids.shape[1] + 1
indices = []
for i in range(B):
_next_ids = next_ids[i*S: (i + 1)*S]
_speculated_ids = speculated_ids[i]
validate_speculative = _next_ids[:-1] == _speculated_ids
index = i * S
accepted = 1
# First is always valid
indices.append(index)
for valid in validate_speculative.tolist():
if valid:
index += 1
# print(f"Validated {index - 1}")
next_ids = next_ids[:index]
scores = scores[:index]
speculative_scores = speculative_scores[index - 1:index]
accepted_ids.append(index)
accepted += 1
indices.append(index)
else:
accepted_ids.append(1)
break
# if accepted > 1:
# import ipdb;ipdb.set_trace()
accepted_ids.append(accepted)
accepted_ids = torch.tensor(accepted_ids, device=input_ids.device, dtype=input_ids.dtype)
next_ids = next_ids[indices]
scores = scores[indices]
indices = torch.arange(B, device=input_ids.device) * S
if speculative_scores is not None:
speculative_scores = speculative_scores[indices + accepted_ids - 1]
else:
accepted_ids = torch.ones_like(next_ids)
logprobs = torch.log_softmax(scores, -1)
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
if speculate > 0:
if speculative_scores is not None:
# length, spec_length, vocab_size = speculative_scores.shape
# speculative_scores = speculative_scores.view((-1, vocab_size))
# if self.watermark_processor is not None:
# speculative_scores = self.watermark_processor(input_ids, speculative_scores)
# if self.repetition_processor is not None:
# speculative_scores = self.repetition_processor(input_ids, speculative_scores)
# speculative_scores = speculative_scores.view((length, spec_length, vocab_size))
# for warper in self.warpers:
# speculative_scores = warper(input_ids, speculative_scores)
# TODO This will only speculate the top score
# Medusa provided some scores
speculative_ids = Greedy()(speculative_scores)
# # Ignore first head, it seems to be a regular head.
# speculative_ids = speculative_ids[:, 1:]
else:
# n-gram
speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate)
else:
speculative_ids = None