mirror of https://github.com/huggingface/text-generation-inference.git
fix speculation

commit b4ebfa52f4
parent 2b25e9a94e
@@ -2023,8 +2023,8 @@ class FlashCausalLM(Model):
             batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
             batch.speculative_ids = speculative_ids
             batch.position_ids += accepted_ids
-            batch.cache_lengths_tensor += batch.input_lengths_tensor
-            batch.input_lengths_tensor = accepted_ids.to(dtype=torch.int32)
+            batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1
+            batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
             batch.slot_indices += accepted_ids

             if prefill and prefill_logprobs:
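The two changed lines fix the post-step bookkeeping for speculative decoding: the KV cache must advance by every accepted token except the one being fed into the next step, and the next decode step always feeds exactly one token per sequence, so the per-sequence input lengths collapse to ones rather than tracking accepted_ids. A minimal runnable sketch of that arithmetic, with names taken from the diff but batch values that are purely illustrative:

import torch

accepted_ids = torch.tensor([3, 1])            # tokens accepted this step (illustrative)
input_lengths_tensor = torch.tensor([1, 1])    # tokens fed into this step
cache_lengths_tensor = torch.tensor([10, 20])  # KV entries before this step

# Old: the cache advanced by input_lengths only, losing the extra accepted
# speculative tokens. New: it also advances by accepted_ids - 1, so every
# accepted token except the one fed next step is counted in the cache.
cache_lengths_tensor += input_lengths_tensor + accepted_ids - 1
# The next decode step always feeds exactly one token per sequence.
input_lengths_tensor = torch.ones_like(input_lengths_tensor)

assert cache_lengths_tensor.tolist() == [13, 21]
assert input_lengths_tensor.tolist() == [1, 1]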
@@ -2202,8 +2202,10 @@ class FlashCausalLM(Model):
                 # processing
                 stopped = False
                 new_input_length = next_chunk_lengths[i]
+                new_cache_length = cache_length + input_length
             else:
-                new_input_length = n_accepted_ids
+                new_input_length = 1
+                new_cache_length = cache_length + input_length + n_accepted_ids - 1
             # Append next token to all tokens
             next_token_texts = []
             left = 0
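The same fix appears here as per-request scalars: the prefill branch now records the cache length it hands off, and the decode branch advances the cache by input_length + n_accepted_ids - 1 while the next input length becomes 1 instead of n_accepted_ids. A small worked sketch of the decode branch under assumed values (cache_length=10, input_length=1, three accepted ids; none of these numbers are from the commit):

# Decode branch of the per-request loop, illustrative values only.
cache_length, input_length, n_accepted_ids = 10, 1, 3
new_input_length = 1
new_cache_length = cache_length + input_length + n_accepted_ids - 1
# Matches the batched tensor update in the first hunk.
assert (new_cache_length, new_input_length) == (13, 1)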
@@ -2315,12 +2317,10 @@ class FlashCausalLM(Model):

             # Update values
             index += n_accepted_ids
-            current_cache_length = cache_length + input_length
-            batch.cache_lengths[i] = current_cache_length
-            current_input_length = new_input_length
-            batch.max_input_length = max(batch.max_input_length, current_input_length)
-            batch.input_lengths[i] = current_input_length
-            current_length = current_cache_length + current_input_length
+            batch.cache_lengths[i] = new_cache_length
+            batch.max_input_length = max(batch.max_input_length, new_input_length)
+            batch.input_lengths[i] = new_input_length
+            current_length = new_cache_length + new_input_length
             batch.max_current_length = max(batch.max_current_length, current_length)

             batch.prefix_offsets[i] = prefix_offset
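The update block then simply writes back the new_cache_length and new_input_length computed in the branch above instead of re-deriving current_* values from cache_length + input_length, which kept the pre-fix accounting. A sketch of the simplified write-back under the same assumed values (the plain lists and max trackers here are stand-ins, not the real FlashCausalLMBatch fields):

# Stand-in batch state, illustrative only.
cache_lengths = [0, 0]
input_lengths = [0, 0]
max_input_length = 0
max_current_length = 0

# (new_cache_length, new_input_length) per request, from the branch above.
for i, (new_cache_length, new_input_length) in enumerate([(13, 1), (21, 1)]):
    cache_lengths[i] = new_cache_length
    max_input_length = max(max_input_length, new_input_length)
    input_lengths[i] = new_input_length
    current_length = new_cache_length + new_input_length
    max_current_length = max(max_current_length, current_length)

assert (max_input_length, max_current_length) == (1, 22)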