Fix warmup with SKIP_TOKENIZER_IN_TGI=true (#266)

This commit is contained in:
Karol Damaszke 2025-01-21 10:09:49 +01:00 committed by GitHub
parent 7d106477d6
commit 8de110ae9f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 10 additions and 4 deletions

View File

@ -45,7 +45,6 @@ from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
HeterogeneousNextTokenChooser,
StoppingCriteria,
make_tokenizer_optional,
is_tokenizer_transparent,
pad_next_token_chooser_parameters,
)
@ -636,7 +635,6 @@ class CausalLM(Model):
truncation_side="left",
trust_remote_code=trust_remote_code,
)
make_tokenizer_optional(tokenizer)
# Create model
world_size = int(os.getenv("WORLD_SIZE", "1"))

View File

@ -22,6 +22,7 @@ from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.globals import set_model_id
from text_generation_server.models.globals import set_adapter_to_index
from text_generation_server.utils.adapter import AdapterInfo
from text_generation_server.utils.tokens import make_tokenizer_optional
try:
from text_generation_server.models.pali_gemma import PaliGemmaBatch
@ -101,9 +102,16 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
async def Warmup(self, request, context):
max_supported_total_tokens = self.model.warmup(request)
return generate_pb2.WarmupResponse(max_supported_total_tokens=max_supported_total_tokens)
# W/A for the skip tokenizer path
# We need to call make_tokenizer_optional after the warmup,
# because router is not aware of that feature
make_tokenizer_optional(self.model.tokenizer)
return generate_pb2.WarmupResponse(
max_supported_total_tokens=max_supported_total_tokens
)
async def Prefill(self, request, context):
start = time.time_ns()