Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-25 12:02:08 +00:00)
Fix warmup with SKIP_TOKENIZER_IN_TGI=true (#266)
parent 7d106477d6
commit 8de110ae9f
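Context for the fix: with SKIP_TOKENIZER_IN_TGI=true, make_tokenizer_optional (from text_generation_server.utils.tokens) patches the tokenizer so that request input is treated as pre-tokenized, comma-separated token ids rather than natural-language text. The sketch below is a minimal illustration of that idea under those assumed semantics, not the actual TGI implementation; both helper names are hypothetical.

# Illustrative sketch only (hypothetical helpers, assumed semantics):
# with the "optional" tokenizer active, input is parsed as comma-separated
# token ids; a real tokenizer expects natural-language text instead.
def encode_pretokenized(text: str) -> list[int]:
    # "1,15043,29991" -> [1, 15043, 29991]
    return [int(tok) for tok in text.split(",") if tok.strip()]

def encode_text(text: str) -> list[str]:
    # Toy stand-in for a real tokenizer. The router's warmup requests carry
    # text like this, which the pre-tokenized parser above cannot handle.
    return text.split()

print(encode_pretokenized("1,15043,29991"))  # [1, 15043, 29991]
print(encode_text("warmup prompt"))          # ['warmup', 'prompt']

Because the router still sends plain-text warmup requests, the tokenizer must stay real until warmup completes; the diff below moves the make_tokenizer_optional call accordingly.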
server/text_generation_server/models/causal_lm.py

@@ -45,7 +45,6 @@ from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import (
     HeterogeneousNextTokenChooser,
     StoppingCriteria,
-    make_tokenizer_optional,
     is_tokenizer_transparent,
     pad_next_token_chooser_parameters,
 )
@@ -636,7 +635,6 @@ class CausalLM(Model):
             truncation_side="left",
             trust_remote_code=trust_remote_code,
         )
-        make_tokenizer_optional(tokenizer)

         # Create model
         world_size = int(os.getenv("WORLD_SIZE", "1"))
server/text_generation_server/server.py

@@ -22,6 +22,7 @@ from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 from text_generation_server.models.globals import set_model_id
 from text_generation_server.models.globals import set_adapter_to_index
 from text_generation_server.utils.adapter import AdapterInfo
+from text_generation_server.utils.tokens import make_tokenizer_optional

 try:
     from text_generation_server.models.pali_gemma import PaliGemmaBatch
@@ -101,9 +102,16 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

     async def Warmup(self, request, context):
         max_supported_total_tokens = self.model.warmup(request)
-        return generate_pb2.WarmupResponse(max_supported_total_tokens=max_supported_total_tokens)
+
+        # W/A for the skip tokenizer path
+        # We need to call make_tokenizer_optional after the warmup,
+        # because router is not aware of that feature
+        make_tokenizer_optional(self.model.tokenizer)
+
+        return generate_pb2.WarmupResponse(
+            max_supported_total_tokens=max_supported_total_tokens
+        )

     async def Prefill(self, request, context):
         start = time.time_ns()
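The resulting ordering in the Warmup handler: run warmup with the real tokenizer first, then swap in the pass-through behaviour. Below is a minimal sketch of that ordering; DummyModel and PassThroughTokenizer are illustrative stand-ins, not actual TGI classes.

# Minimal ordering sketch (illustrative stand-ins, not TGI classes).
import os

class PassThroughTokenizer:
    # Mimics the patched tokenizer: input is comma-separated token ids.
    def __call__(self, text: str) -> list[int]:
        return [int(tok) for tok in text.split(",") if tok.strip()]

class DummyModel:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def warmup(self, request_text: str) -> int:
        # Warmup input arrives as plain text from the router, so it must be
        # tokenized by the real tokenizer -- hence the swap happens later.
        return len(self.tokenizer(request_text))

model = DummyModel(lambda text: text.split())  # toy "real" tokenizer
max_supported_total_tokens = model.warmup("some warmup prompt")

# Mirroring the Warmup handler above: only after warmup succeeds is the
# tokenizer made optional (here gated on the same environment variable).
if os.getenv("SKIP_TOKENIZER_IN_TGI", "false").lower() == "true":
    model.tokenizer = PassThroughTokenizer()

print(max_supported_total_tokens)  # 3

Doing the swap server-side after warmup keeps the router untouched, which matches the in-diff comment that the router is not aware of the skip-tokenizer feature.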