diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 37d74279..ca45020c 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1418,9 +1418,9 @@ class FlashCausalLM(Model):
                 )
                 max_total_tokens = available_blocks
             else:
-                max_total_tokens = len(batch.input_ids)
+                max_total_tokens = sum(len(input_ids) for input_ids in batch.input_ids)
         max_input_tokens = (
-            batch.num_blocks - 1
+            max_total_tokens - 1
             if max_input_tokens is None
             else max_input_tokens
         )
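
Note on the change above: the new generator expression implies that batch.input_ids holds one token-id list per request, so the old len(batch.input_ids) would count requests rather than total tokens, and the default max_input_tokens was tied to the block count instead of the token budget. A minimal sketch of the difference, with hypothetical values that are not from the repository:

    # Hypothetical batch: two requests with 5 and 3 prompt tokens each.
    input_ids = [[101, 7592, 2088, 999, 102], [101, 7592, 102]]

    # Old computation: len() of a list of lists counts requests -> 2.
    assert len(input_ids) == 2

    # New computation: sum the per-request lengths to get total tokens -> 8.
    assert sum(len(ids) for ids in input_ids) == 8

With the corrected total, deriving the fallback as max_total_tokens - 1 keeps the input limit consistent with the actual token budget rather than with batch.num_blocks.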