From 5f2e1f0d7ec9b3ad12d4b5d4bf25a4289682f2e2 Mon Sep 17 00:00:00 2001
From: Islam Almersawi
Date: Wed, 24 Jul 2024 17:22:44 +0400
Subject: [PATCH] fix: fix missing model id in rocm warmup

---
 server/text_generation_server/models/flash_causal_lm.py | 5 +++--
 server/text_generation_server/models/globals.py         | 4 ++++
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 5db62431..f0204ada 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -44,7 +44,7 @@ from text_generation_server.models.globals import (
     BLOCK_SIZE,
     CUDA_GRAPHS,
     get_adapter_to_index,
-    MODEL_ID,
+    get_model_id,
 )
 from text_generation_server.layers.attention import Seqlen
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
@@ -1155,9 +1155,10 @@ class FlashCausalLM(Model):
                     # For seqlen = 1, we dispatch to LLMM1 kernel.
                     tuning_sequences = [2, 3, 4, 5, 6, 7]
 
+                model_id = get_model_id()
                 tunableop_filepath = os.path.join(
                     HUGGINGFACE_HUB_CACHE,
-                    f"tunableop_{MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
+                    f"tunableop_{model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
                 )
 
                 log_master(
diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py
index ac42df30..25d10ec6 100644
--- a/server/text_generation_server/models/globals.py
+++ b/server/text_generation_server/models/globals.py
@@ -37,6 +37,10 @@ def set_model_id(model_id: str):
     global MODEL_ID
     MODEL_ID = model_id
 
+def get_model_id():
+    global MODEL_ID
+    return MODEL_ID
+
 
 # NOTE: eventually we should move this into the router and pass back the
 # index in all cases.