diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index dd52f2db..f1c9c0bf 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -45,7 +45,6 @@ from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import (
     HeterogeneousNextTokenChooser,
     StoppingCriteria,
-    make_tokenizer_optional,
     is_tokenizer_transparent,
     pad_next_token_chooser_parameters,
 )
@@ -636,7 +635,6 @@ class CausalLM(Model):
             truncation_side="left",
             trust_remote_code=trust_remote_code,
         )
-        make_tokenizer_optional(tokenizer)
 
         # Create model
         world_size = int(os.getenv("WORLD_SIZE", "1"))
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index bc458a91..159e6af1 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -22,6 +22,7 @@ from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 from text_generation_server.models.globals import set_model_id
 from text_generation_server.models.globals import set_adapter_to_index
 from text_generation_server.utils.adapter import AdapterInfo
+from text_generation_server.utils.tokens import make_tokenizer_optional
 
 try:
     from text_generation_server.models.pali_gemma import PaliGemmaBatch
@@ -101,8 +102,16 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
 
     async def Warmup(self, request, context):
         max_supported_total_tokens = self.model.warmup(request)
-        return generate_pb2.WarmupResponse(max_supported_total_tokens=max_supported_total_tokens)
+
+        # W/A for the skip tokenizer path
+        # We need to call make_tokenizer_optional after the warmup,
+        # because the router is not aware of that feature
+        make_tokenizer_optional(self.model.tokenizer)
+
+        return generate_pb2.WarmupResponse(
+            max_supported_total_tokens=max_supported_total_tokens
+        )
 
     async def Prefill(self, request, context):
         start = time.time_ns()
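
Why the ordering matters: one plausible reading of the workaround comment above is that the router's warmup requests still carry raw text, so the tokenizer may only be wrapped after warmup has finished. The toy sketch below illustrates that failure mode; every name in it (ToyTokenizer, make_optional) is a hypothetical stand-in for illustration, not a text-generation-inference API.

    # Toy illustration of the ordering this patch enforces. All names here
    # are hypothetical stand-ins, not text-generation-inference APIs.

    class ToyTokenizer:
        def encode(self, text: str) -> list[int]:
            # Stand-in for real tokenization: one id per character.
            return [ord(c) for c in text]

    def make_optional(tokenizer) -> None:
        # Stand-in for make_tokenizer_optional: after this, "text" is
        # expected to be a comma-separated list of pre-computed token ids.
        tokenizer.encode = lambda text: [int(t) for t in text.split(",")]

    # Wrong order: wrapping before warmup breaks a plain-text request.
    tok = ToyTokenizer()
    make_optional(tok)
    try:
        tok.encode("hello world")  # simulates a raw-text warmup request
    except ValueError:
        print("warmup fails once the tokenizer is already optional")

    # Right order (what the patch does): warm up first, then wrap.
    tok = ToyTokenizer()
    tok.encode("hello world")      # warmup succeeds with real tokenization
    make_optional(tok)             # switch to the skip-tokenizer path after
    assert tok.encode("1,2,3") == [1, 2, 3]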