From b4ebfa52f48a3564fbc503e4ef34d2aeb9cff460 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Fri, 25 Oct 2024 21:13:24 +0200
Subject: [PATCH] fix speculation

---
 .../models/flash_causal_lm.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index fc50d69a..ad6123d0 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -2023,8 +2023,8 @@ class FlashCausalLM(Model):
             batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
             batch.speculative_ids = speculative_ids
             batch.position_ids += accepted_ids
-            batch.cache_lengths_tensor += batch.input_lengths_tensor
-            batch.input_lengths_tensor = accepted_ids.to(dtype=torch.int32)
+            batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1
+            batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
             batch.slot_indices += accepted_ids

         if prefill and prefill_logprobs:
@@ -2202,8 +2202,10 @@ class FlashCausalLM(Model):
                 # processing
                 stopped = False
                 new_input_length = next_chunk_lengths[i]
+                new_cache_length = cache_length + input_length
             else:
-                new_input_length = n_accepted_ids
+                new_input_length = 1
+                new_cache_length = cache_length + input_length + n_accepted_ids - 1
                 # Append next token to all tokens
                 next_token_texts = []
                 left = 0
@@ -2315,12 +2317,10 @@ class FlashCausalLM(Model):

             # Update values
             index += n_accepted_ids
-            current_cache_length = cache_length + input_length
-            batch.cache_lengths[i] = current_cache_length
-            current_input_length = new_input_length
-            batch.max_input_length = max(batch.max_input_length, current_input_length)
-            batch.input_lengths[i] = current_input_length
-            current_length = current_cache_length + current_input_length
+            batch.cache_lengths[i] = new_cache_length
+            batch.max_input_length = max(batch.max_input_length, new_input_length)
+            batch.input_lengths[i] = new_input_length
+            current_length = new_cache_length + new_input_length
             batch.max_current_length = max(batch.max_current_length, current_length)

             batch.prefix_offsets[i] = prefix_offset
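
For readers following the diff, here is a minimal, self-contained sketch of the length bookkeeping the patch switches to. It assumes the same convention as the hunks above (per-sequence cache_lengths, input_lengths, and accepted_ids values); the update_lengths helper below is hypothetical and only illustrates the arithmetic, it is not a function in the repository.

import torch

def update_lengths(cache_lengths: torch.Tensor,
                   input_lengths: torch.Tensor,
                   accepted_ids: torch.Tensor):
    # After verification, the KV cache advances by the old input length plus
    # the extra accepted speculative tokens (accepted - 1), and the next
    # forward pass feeds exactly one input token per sequence: the last
    # accepted token, selected via next_input_ids[cu_accepted_ids[1:] - 1]
    # in the first hunk.
    new_cache_lengths = cache_lengths + input_lengths + accepted_ids - 1
    new_input_lengths = torch.ones_like(input_lengths)
    return new_cache_lengths, new_input_lengths

# Example: 10 tokens already cached, 1 input token, 3 accepted tokens
# (the target-model token plus 2 accepted speculative tokens).
cache, inp = update_lengths(
    torch.tensor([10], dtype=torch.int32),
    torch.tensor([1], dtype=torch.int32),
    torch.tensor([3], dtype=torch.int32),
)
assert cache.item() == 13 and inp.item() == 1

Under the old code, the accepted tokens were counted as the next step's input (input_lengths_tensor = accepted_ids) while the cache advanced only by the previous input length; the patch instead folds the accepted tokens into the cache length and keeps the next step's input at a single token, matching the per-request updates in the later hunks.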