mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
fix: remove global model id
This commit is contained in:
parent
c2413a0153
commit
2ce1476e52
@@ -43,9 +43,7 @@ from text_generation_server.models.globals import (
|
|||||||
BLOCK_SIZE,
|
BLOCK_SIZE,
|
||||||
CUDA_GRAPHS,
|
CUDA_GRAPHS,
|
||||||
get_adapter_to_index,
|
get_adapter_to_index,
|
||||||
MODEL_ID,
|
|
||||||
)
|
)
|
||||||
import text_generation_server.models.globals as globals_vars
|
|
||||||
from text_generation_server.layers.attention import Seqlen
|
from text_generation_server.layers.attention import Seqlen
|
||||||
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
|
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
|
||||||
from text_generation_server.utils.dist import MEMORY_FRACTION
|
from text_generation_server.utils.dist import MEMORY_FRACTION
|
||||||
@@ -1157,7 +1155,7 @@ class FlashCausalLM(Model):
|
|||||||
|
|
||||||
tunableop_filepath = os.path.join(
|
tunableop_filepath = os.path.join(
|
||||||
HUGGINGFACE_HUB_CACHE,
|
HUGGINGFACE_HUB_CACHE,
|
||||||
f"tunableop_{globals_vars.MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
|
f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
|
||||||
)
|
)
|
||||||
|
|
||||||
log_master(
|
log_master(
|
||||||
|
@@ -29,14 +29,6 @@ if cuda_graphs is not None:
|
|||||||
|
|
||||||
CUDA_GRAPHS = cuda_graphs
|
CUDA_GRAPHS = cuda_graphs
|
||||||
|
|
||||||
# This is overridden at model loading.
|
|
||||||
MODEL_ID = None
|
|
||||||
|
|
||||||
|
|
||||||
def set_model_id(model_id: str):
|
|
||||||
global MODEL_ID
|
|
||||||
MODEL_ID = model_id
|
|
||||||
|
|
||||||
|
|
||||||
# NOTE: eventually we should move this into the router and pass back the
|
# NOTE: eventually we should move this into the router and pass back the
|
||||||
# index in all cases.
|
# index in all cases.
|
||||||
|
@@ -30,7 +30,7 @@ except (ImportError, NotImplementedError):
|
|||||||
|
|
||||||
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
|
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
|
||||||
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
|
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
|
||||||
from text_generation_server.models.globals import set_model_id, set_adapter_to_index
|
from text_generation_server.models.globals import set_adapter_to_index
|
||||||
|
|
||||||
|
|
||||||
class SignalHandler:
|
class SignalHandler:
|
||||||
@@ -271,7 +271,6 @@ def serve(
|
|||||||
while signal_handler.KEEP_PROCESSING:
|
while signal_handler.KEEP_PROCESSING:
|
||||||
await asyncio.sleep(0.5)
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
set_model_id(model_id)
|
|
||||||
asyncio.run(
|
asyncio.run(
|
||||||
serve_inner(
|
serve_inner(
|
||||||
model_id,
|
model_id,
|
||||||
|
Loading…
Reference in New Issue
Block a user