From 95ff267043c1297ea9f5e7e3b8f50c6696a123a9 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 29 Jul 2024 16:30:53 +0000 Subject: [PATCH] fix: remove global model id --- server/text_generation_server/models/flash_causal_lm.py | 4 +--- server/text_generation_server/models/globals.py | 9 --------- server/text_generation_server/server.py | 3 +-- 3 files changed, 2 insertions(+), 14 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 1e871e11..36bb2662 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -43,9 +43,7 @@ from text_generation_server.models.globals import ( BLOCK_SIZE, CUDA_GRAPHS, get_adapter_to_index, - MODEL_ID, ) -import text_generation_server.models.globals as globals_vars from text_generation_server.layers.attention import Seqlen from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils.dist import MEMORY_FRACTION @@ -1157,7 +1155,7 @@ class FlashCausalLM(Model): tunableop_filepath = os.path.join( HUGGINGFACE_HUB_CACHE, - f"tunableop_{globals_vars.MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv", + f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv", ) log_master( diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index ac42df30..8d2431db 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -29,15 +29,6 @@ if cuda_graphs is not None: CUDA_GRAPHS = cuda_graphs -# This is overridden at model loading. -MODEL_ID = None - - -def set_model_id(model_id: str): - global MODEL_ID - MODEL_ID = model_id - - # NOTE: eventually we should move this into the router and pass back the # index in all cases. ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 22bd759f..b92ab572 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -30,7 +30,7 @@ except (ImportError, NotImplementedError): from text_generation_server.pb import generate_pb2_grpc, generate_pb2 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor -from text_generation_server.models.globals import set_model_id, set_adapter_to_index +from text_generation_server.models.globals import set_adapter_to_index class SignalHandler: @@ -271,7 +271,6 @@ def serve( while signal_handler.KEEP_PROCESSING: await asyncio.sleep(0.5) - set_model_id(model_id) asyncio.run( serve_inner( model_id,