fix speculation

OlivierDehaene 2024-10-25 21:13:24 +02:00
parent 2b25e9a94e
commit b4ebfa52f4


@@ -2023,8 +2023,8 @@ class FlashCausalLM(Model):
             batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
             batch.speculative_ids = speculative_ids
             batch.position_ids += accepted_ids
-            batch.cache_lengths_tensor += batch.input_lengths_tensor
-            batch.input_lengths_tensor = accepted_ids.to(dtype=torch.int32)
+            batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1
+            batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
             batch.slot_indices += accepted_ids
 
         if prefill and prefill_logprobs:
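
A minimal sketch (not TGI code) of the tensor-level arithmetic in the hunk above, with made-up lengths: the fix folds the accepted speculative tokens into the cache-length accounting and resets the per-sequence input length to one, so the per-sequence total of cache length plus input length is unchanged; only the cache/input split moves.

# Minimal sketch, not TGI code: toy tensors with made-up values mirroring
# the names used in the hunk above.
import torch

cache_lengths_tensor = torch.tensor([10, 7], dtype=torch.int32)  # tokens already in the KV cache
input_lengths_tensor = torch.tensor([4, 4], dtype=torch.int32)   # tokens fed in this forward pass
accepted_ids = torch.tensor([3, 1], dtype=torch.int32)           # tokens accepted per sequence

# Old bookkeeping: cache grows by the fed tokens, next input length is the accepted count.
old_cache = cache_lengths_tensor + input_lengths_tensor
old_input = accepted_ids.to(dtype=torch.int32)

# New bookkeeping: accepted tokens (minus the one fed next) count as cached,
# and exactly one token is fed on the next step.
new_cache = cache_lengths_tensor + input_lengths_tensor + accepted_ids - 1
new_input = torch.ones_like(input_lengths_tensor)

# The per-sequence total is preserved; only the partition changes.
assert torch.equal(old_cache + old_input, new_cache + new_input)
print(new_cache, new_input)  # tensor([16, 11], dtype=torch.int32) tensor([1, 1], dtype=torch.int32)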
@@ -2202,8 +2202,10 @@ class FlashCausalLM(Model):
                 # processing
                 stopped = False
                 new_input_length = next_chunk_lengths[i]
+                new_cache_length = cache_length + input_length
             else:
-                new_input_length = n_accepted_ids
+                new_input_length = 1
+                new_cache_length = cache_length + input_length + n_accepted_ids - 1
             # Append next token to all tokens
             next_token_texts = []
             left = 0
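
This hunk does the same bookkeeping per request on the Python side. A minimal sketch of how the two branches appear to split the work, using hypothetical names (next_lengths, still_prefilling, next_chunk_length stand in for the surrounding loop variables): a request that still has prompt chunks left keeps feeding the next chunk, while a decoding request feeds a single token next step and folds the accepted speculative tokens into its cache length.

# Minimal sketch, not TGI code. The function and argument names are
# hypothetical stand-ins for the loop variables around this hunk.
def next_lengths(cache_length, input_length, n_accepted_ids,
                 still_prefilling, next_chunk_length=0):
    if still_prefilling:
        # More prompt chunks to prefill: feed the next chunk, cache what was just processed.
        new_input_length = next_chunk_length
        new_cache_length = cache_length + input_length
    else:
        # Decoding: only the freshly sampled token is fed next step; the other
        # accepted speculative tokens count as cached.
        new_input_length = 1
        new_cache_length = cache_length + input_length + n_accepted_ids - 1
    return new_cache_length, new_input_length

# Made-up example: a decode step that accepted two speculative tokens plus the main token.
print(next_lengths(cache_length=16, input_length=1, n_accepted_ids=3, still_prefilling=False))
# (19, 1)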
@@ -2315,12 +2317,10 @@ class FlashCausalLM(Model):
             # Update values
             index += n_accepted_ids
-            current_cache_length = cache_length + input_length
-            batch.cache_lengths[i] = current_cache_length
-            current_input_length = new_input_length
-            batch.max_input_length = max(batch.max_input_length, current_input_length)
-            batch.input_lengths[i] = current_input_length
-            current_length = current_cache_length + current_input_length
+            batch.cache_lengths[i] = new_cache_length
+            batch.max_input_length = max(batch.max_input_length, new_input_length)
+            batch.input_lengths[i] = new_input_length
+            current_length = new_cache_length + new_input_length
             batch.max_current_length = max(batch.max_current_length, current_length)
             batch.prefix_offsets[i] = prefix_offset
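
The last hunk then writes the precomputed new_cache_length / new_input_length back into the batch and refreshes the running maxima. A minimal sketch, assuming a toy stand-in for the batch object with only the fields touched here:

# Minimal sketch, not TGI code: ToyBatch is a hypothetical stand-in for the
# batch fields updated in this hunk.
from dataclasses import dataclass, field
from typing import List

@dataclass
class ToyBatch:
    cache_lengths: List[int] = field(default_factory=lambda: [0, 0])
    input_lengths: List[int] = field(default_factory=lambda: [0, 0])
    max_input_length: int = 0
    max_current_length: int = 0

batch = ToyBatch()
# (new_cache_length, new_input_length) per request, e.g. from the previous sketch.
updates = [(19, 1), (11, 1)]

for i, (new_cache_length, new_input_length) in enumerate(updates):
    batch.cache_lengths[i] = new_cache_length
    batch.max_input_length = max(batch.max_input_length, new_input_length)
    batch.input_lengths[i] = new_input_length
    current_length = new_cache_length + new_input_length
    batch.max_current_length = max(batch.max_current_length, current_length)

print(batch.max_current_length)  # 20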