From cb8a1680fe45c3b898add8de2347ddeb961d2b04 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 5 Dec 2023 15:29:34 +0000
Subject: [PATCH] Fix.

---
 server/text_generation_server/models/flash_causal_lm.py | 2 +-
 server/text_generation_server/models/flash_mistral.py   | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 4aa3637e..855061e5 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -833,7 +833,7 @@ class FlashCausalLM(Model):
             batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
         )
 
-        speculative_length = speculative_ids.shape[1]
+        speculative_length = 0 if speculative_ids is None else speculative_ids.shape[1]
         if prefill:
             if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index e6ada2c9..e103d9fc 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -21,6 +21,7 @@ from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
     FlashMistralForCausalLM,
     MistralConfig,
 )
+from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
@@ -132,8 +133,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
 
             # Paged attention
             # Remove one as the first token does not have a past
-            from text_generation_server.models import SPECULATE
-            speculative_length = SPECULATE
+            speculative_length = get_speculate()
             total_tokens = input_length + max_new_tokens - 1 + speculative_length
 
             # Needed blocks can not go over SLIDING_WINDOW_BLOCKS
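
Note on the flash_causal_lm.py hunk: when speculative decoding is disabled,
speculative_ids is None, so the old speculative_ids.shape[1] raised
AttributeError ('NoneType' object has no attribute 'shape') inside
generate_token. A minimal sketch of the guard the patch introduces;
speculative_len is a hypothetical helper written for illustration, and the
(batch_size, n_speculative_tokens) tensor layout is an assumption, not
quoted from the codebase:

    import torch
    from typing import Optional

    def speculative_len(speculative_ids: Optional[torch.Tensor]) -> int:
        # Speculation disabled: no speculative ids were produced.
        if speculative_ids is None:
            return 0
        # Assumed layout: (batch_size, n_speculative_tokens).
        return speculative_ids.shape[1]

    assert speculative_len(None) == 0
    assert speculative_len(torch.zeros(2, 3, dtype=torch.long)) == 3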
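
Note on the flash_mistral.py hunks: the removed code imported the SPECULATE
constant from the text_generation_server.models package inside from_pb,
creating a circular dependency between the models package and one of its own
model modules. The patch reads the value through the get_speculate() accessor
from utils instead. A sketch of the accessor pattern being adopted; the exact
contents of text_generation_server/utils/speculate.py are assumed here, not
quoted:

    # speculate.py -- assumed shape of the module the patch imports from.
    SPECULATE = 0  # number of speculative tokens; default value assumed

    def set_speculate(speculate: int) -> None:
        # Assumed to be called once during server initialization.
        global SPECULATE
        SPECULATE = speculate

    def get_speculate() -> int:
        # Always returns the live module global.
        return SPECULATE

Reading through a function also avoids the usual pitfall of
"from module import SPECULATE" at module top level, which copies the integer
at import time and would never observe a later set_speculate() call.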