From d73c5c634dea446ca53b4a9ca3e1e6d28961c2a6 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Wed, 9 Oct 2024 19:39:14 +0200
Subject: [PATCH] max input length

---
 server/text_generation_server/models/flash_causal_lm.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 05bad924..7ebe3dea 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1618,7 +1618,7 @@ class FlashCausalLM(Model):
                 input_lengths=input_lengths,
                 cache_lengths=cache_lengths_tensor,
                 cu_seqlen_q=cu_seqlen_prefill,
-                max_q=max_s,
+                max_q=batch.max_input_length,
                 max_k=batch.max_current_length,
             )
             logits, speculative_logits = self.model.forward(
@@ -2236,8 +2236,6 @@ class FlashCausalLM(Model):
             use_prefill_with_paged_kv_state,
         )
 
-        # has_cache_lengths = any(cache_length > 0 for cache_length in cache_lengths)
-
         if cu_seqlen_prefill is not None:
             return use_prefill_with_paged_kv_state(
                 state=(
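
A minimal sketch of the length bookkeeping this patch touches, assuming the usual prefix-caching setup where each request has a cache_length (tokens already in the KV cache) and an input_length (new tokens prefilled this step); the variable values below are illustrative only and not taken from the patch:

    # Illustrative sketch; the example numbers are assumptions, not part of the patch.
    # With prefix caching, part of a prompt may already sit in the KV cache, so the
    # number of query tokens prefilled this step (input_length) can be smaller than
    # the total sequence length visible to attention (current_length).
    cache_lengths = [100, 0]    # tokens already cached, per request
    input_lengths = [28, 64]    # new tokens prefilled this step, per request
    current_lengths = [c + i for c, i in zip(cache_lengths, input_lengths)]  # [128, 64]

    max_input_length = max(input_lengths)      # 64  -> bounds the query side (max_q)
    max_current_length = max(current_lengths)  # 128 -> bounds the key/value side (max_k)

    # Using the overall max sequence length (max_s) for max_q overstates the query
    # length whenever any prompt tokens are cached; the patch switches max_q to
    # batch.max_input_length while keeping max_k = batch.max_current_length.
    assert max_input_length <= max_current_length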