diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index c15e6e4e..70b32e4a 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -465,6 +465,8 @@ class CausalLMBatch(Batch):
         requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)]
         max_input_length = max(r.data.truncate for r in requests)
+        if max_input_length < PAD_SEQUENCE_TO_MULTIPLE_OF:
+            max_input_length = PAD_SEQUENCE_TO_MULTIPLE_OF
         max_new_tokens = max(r.stopping_criteria.max_new_tokens for r in requests)
         # TODO: Add support for sparse batches
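
For context (not part of the patch): the constant's name suggests input sequences in this file are padded to a multiple of `PAD_SEQUENCE_TO_MULTIPLE_OF`, and the added check keeps the batch-wide `max_input_length` from falling below that constant. A minimal sketch of the intended behaviour, using a hypothetical value for the constant rather than the one defined in `causal_lm.py`:

```python
# Standalone sketch of the clamp added above; PAD_SEQUENCE_TO_MULTIPLE_OF here is
# a hypothetical value standing in for the constant defined in causal_lm.py.
PAD_SEQUENCE_TO_MULTIPLE_OF = 128

def clamp_max_input_length(truncate_lengths: list[int]) -> int:
    # Mirrors the patched logic: the batch-wide max input length never drops
    # below PAD_SEQUENCE_TO_MULTIPLE_OF, so a batch of very short requests
    # still fills one full padding bucket instead of yielding an undersized shape.
    max_input_length = max(truncate_lengths)
    if max_input_length < PAD_SEQUENCE_TO_MULTIPLE_OF:
        max_input_length = PAD_SEQUENCE_TO_MULTIPLE_OF
    return max_input_length

assert clamp_max_input_length([17, 42]) == 128    # clamped up to the minimum bucket
assert clamp_max_input_length([150, 300]) == 300  # longer inputs are unaffected
```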