mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
fix speculation
This commit is contained in:
parent
2b25e9a94e
commit
b4ebfa52f4
@@ -2023,8 +2023,8 @@ class FlashCausalLM(Model):
|
||||
batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
|
||||
batch.speculative_ids = speculative_ids
|
||||
batch.position_ids += accepted_ids
|
||||
- batch.cache_lengths_tensor += batch.input_lengths_tensor
|
||||
- batch.input_lengths_tensor = accepted_ids.to(dtype=torch.int32)
|
||||
+ batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1
|
||||
+ batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
|
||||
batch.slot_indices += accepted_ids
|
||||
|
||||
if prefill and prefill_logprobs:
|
||||
@@ -2202,8 +2202,10 @@ class FlashCausalLM(Model):
|
||||
# processing
|
||||
stopped = False
|
||||
new_input_length = next_chunk_lengths[i]
|
||||
+ new_cache_length = cache_length + input_length
|
||||
else:
|
||||
- new_input_length = n_accepted_ids
|
||||
+ new_input_length = 1
|
||||
+ new_cache_length = cache_length + input_length + n_accepted_ids - 1
|
||||
# Append next token to all tokens
|
||||
next_token_texts = []
|
||||
left = 0
|
||||
@@ -2315,12 +2317,10 @@ class FlashCausalLM(Model):
|
||||
|
||||
# Update values
|
||||
index += n_accepted_ids
|
||||
- current_cache_length = cache_length + input_length
|
||||
- batch.cache_lengths[i] = current_cache_length
|
||||
- current_input_length = new_input_length
|
||||
- batch.max_input_length = max(batch.max_input_length, current_input_length)
|
||||
- batch.input_lengths[i] = current_input_length
|
||||
- current_length = current_cache_length + current_input_length
|
||||
+ batch.cache_lengths[i] = new_cache_length
|
||||
+ batch.max_input_length = max(batch.max_input_length, new_input_length)
|
||||
+ batch.input_lengths[i] = new_input_length
|
||||
+ current_length = new_cache_length + new_input_length
|
||||
batch.max_current_length = max(batch.max_current_length, current_length)
|
||||
|
||||
batch.prefix_offsets[i] = prefix_offset
|
||||
|
Loading…
Reference in New Issue
Block a user