diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 6eb914f8..37d74279 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1418,7 +1418,7 @@ class FlashCausalLM(Model): ) max_total_tokens = available_blocks else: - max_total_tokens = batch.num_blocks + max_total_tokens = len(batch.input_ids) max_input_tokens = ( batch.num_blocks - 1 if max_input_tokens is None