Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-25 12:02:08 +00:00)
Fix warmup with SKIP_TOKENIZER_IN_TGI=true (#266)
parent 7d106477d6
commit 8de110ae9f
@@ -45,7 +45,6 @@ from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import (
     HeterogeneousNextTokenChooser,
     StoppingCriteria,
-    make_tokenizer_optional,
     is_tokenizer_transparent,
     pad_next_token_chooser_parameters,
 )
@@ -636,7 +635,6 @@ class CausalLM(Model):
             truncation_side="left",
             trust_remote_code=trust_remote_code,
         )
-        make_tokenizer_optional(tokenizer)

         # Create model
         world_size = int(os.getenv("WORLD_SIZE", "1"))
@@ -22,6 +22,7 @@ from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 from text_generation_server.models.globals import set_model_id
 from text_generation_server.models.globals import set_adapter_to_index
 from text_generation_server.utils.adapter import AdapterInfo
+from text_generation_server.utils.tokens import make_tokenizer_optional

 try:
     from text_generation_server.models.pali_gemma import PaliGemmaBatch
@@ -101,9 +102,16 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

     async def Warmup(self, request, context):

         max_supported_total_tokens = self.model.warmup(request)
-        return generate_pb2.WarmupResponse(max_supported_total_tokens=max_supported_total_tokens)
+        # W/A for the skip tokenizer path
+        # We need to call make_tokenizer_optional after the warmup,
+        # because router is not aware of that feature
+        make_tokenizer_optional(self.model.tokenizer)
+
+        return generate_pb2.WarmupResponse(
+            max_supported_total_tokens=max_supported_total_tokens
+        )

     async def Prefill(self, request, context):
         start = time.time_ns()
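Note: the fix defers make_tokenizer_optional(self.model.tokenizer) until the end of Warmup, since the router is unaware of the skip-tokenizer feature and warmup therefore has to run against the unmodified tokenizer. The helper itself lives in text_generation_server.utils.tokens and is not part of this diff; the snippet below is only a rough sketch, under the assumption that the helper patches the tokenizer in place into a pass-through for pre-tokenized input when SKIP_TOKENIZER_IN_TGI=true. The function name, patched methods, and comma-separated id format are illustrative assumptions, not the repository's actual implementation.

# Illustrative sketch only -- not the code from utils/tokens.py.
# Assumes SKIP_TOKENIZER_IN_TGI=true means requests already carry token ids
# as comma-separated strings, so encode/decode can become pass-throughs.
import os


def make_tokenizer_optional_sketch(tokenizer):
    if os.getenv("SKIP_TOKENIZER_IN_TGI", "false").lower() != "true":
        return  # real tokenizer stays untouched

    def encode(text, **kwargs):
        # Treat the incoming "text" as pre-tokenized ids, e.g. "12,512,7".
        return [int(t) for t in text.split(",") if t]

    def decode(token_ids, **kwargs):
        # Hand the ids back verbatim; no vocabulary lookup is performed.
        return ",".join(str(t) for t in token_ids)

    # Patch the instance in place, mirroring how the diff calls
    # make_tokenizer_optional(self.model.tokenizer) after warmup.
    tokenizer.encode = encode
    tokenizer.decode = decode

Deferring the patch this way keeps the warmup computation itself on the real tokenizer; only traffic handled after Warmup returns sees the transparent one.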