diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 05bad924..7ebe3dea 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1618,7 +1618,7 @@ class FlashCausalLM(Model):
                 input_lengths=input_lengths,
                 cache_lengths=cache_lengths_tensor,
                 cu_seqlen_q=cu_seqlen_prefill,
-                max_q=max_s,
+                max_q=batch.max_input_length,
                 max_k=batch.max_current_length,
             )
             logits, speculative_logits = self.model.forward(
@@ -2236,8 +2236,6 @@ class FlashCausalLM(Model):
             use_prefill_with_paged_kv_state,
         )

-        # has_cache_lengths = any(cache_length > 0 for cache_length in cache_lengths)
-
         if cu_seqlen_prefill is not None:
             return use_prefill_with_paged_kv_state(
                 state=(
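For reviewers, a minimal standalone sketch of the reasoning behind the `max_q` change (plain Python with made-up lengths, not TGI code; the reading of `batch.max_input_length` as "new tokens this step" and `batch.max_current_length` as "cached prefix plus new tokens" is inferred from the diff, not confirmed): once a sequence carries a cached prefix, the query spans only the new tokens while the keys span the cached prefix plus the new tokens, so the two maxima diverge and reusing the overall max sequence length (`max_s`) for `max_q` overstates the query length.

    # Sketch: why max_q != max_k under prefix caching (hypothetical numbers).
    cache_lengths = [100, 40]   # tokens already held in the KV cache
    input_lengths = [12, 30]    # new tokens to prefill this step

    # Longest query span: only the new tokens are queried...
    max_q = max(input_lengths)                                        # 30
    # ...but keys span the cached prefix plus the new tokens.
    max_k = max(c + i for c, i in zip(cache_lengths, input_lengths))  # 112

    # Using the overall max sequence length (the old max_s) for max_q
    # would overstate the query length whenever any prefix is cached.
    assert max_q <= max_k
    print(max_q, max_k)  # -> 30 112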