diff --git a/server/requirements.txt b/server/requirements.txt
index d941a894..a940574f 100644
--- a/server/requirements.txt
+++ b/server/requirements.txt
@@ -74,7 +74,6 @@ six==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
 sympy==1.12.1 ; python_version >= "3.9" and python_version < "3.13"
 threadpoolctl==3.5.0 ; python_version >= "3.9" and python_version < "3.13"
 tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
-torch==2.4.0a0+git74cd574 ; python_version >= "3.9" and python_version < "3.13"
 tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
 transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
 transformers[sentencepiece]==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 999c0bb6..8b3df5e0 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -494,9 +494,12 @@ class CausalLMBatch(Batch):
             inputs.append(concat_text_chunks(r.input_chunks.chunks))
             top_n_tokens.append(r.top_n_tokens)
             max_truncation = max(max_truncation, r.truncate)
-
-        max_new_tokens = max(r.stopping_criteria.max_new_tokens for r in requests)
+        max_input_length = max_truncation
+        if max_input_length < PAD_SEQUENCE_TO_MULTIPLE_OF:
+            max_input_length = PAD_SEQUENCE_TO_MULTIPLE_OF
+        max_new_tokens = max(r.stopping_criteria.max_new_tokens for r in requests)
+
         # TODO: by tokenizing all inputs at once we loose information on actual input lengths
         # this means that we cannot shift inputs to the left after a long input sequence
         # was filtered out
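
The `causal_lm.py` hunk no longer derives the batch's maximum input length from truncation alone; it clamps it so it never falls below `PAD_SEQUENCE_TO_MULTIPLE_OF`. Below is a minimal sketch of that padding rule, not the server code itself: the value 128 for `PAD_SEQUENCE_TO_MULTIPLE_OF` and the helper name `batch_max_input_length` are illustrative assumptions.

```python
# Illustrative sketch of the padding rule introduced in the hunk above.
# PAD_SEQUENCE_TO_MULTIPLE_OF = 128 is an assumed value; the server reads
# its own configured setting.
PAD_SEQUENCE_TO_MULTIPLE_OF = 128


def batch_max_input_length(truncate_values):
    """Return the input length a batch is padded to, given per-request truncation."""
    max_truncation = max(truncate_values, default=0)
    max_input_length = max_truncation
    # Never pad to less than PAD_SEQUENCE_TO_MULTIPLE_OF.
    if max_input_length < PAD_SEQUENCE_TO_MULTIPLE_OF:
        max_input_length = PAD_SEQUENCE_TO_MULTIPLE_OF
    return max_input_length


# Requests truncated to 20, 50 and 70 tokens are all padded up to 128,
# while a 300-token truncation keeps 300 as the batch input length.
assert batch_max_input_length([20, 50, 70]) == 128
assert batch_max_input_length([20, 50, 300]) == 300
```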