fix: fix missing model id i rocm warmup

This commit is contained in:
Islam Almersawi 2024-07-24 17:22:44 +04:00
parent 8642250602
commit 5f2e1f0d7e
2 changed files with 7 additions and 2 deletions

View File

@ -44,7 +44,7 @@ from text_generation_server.models.globals import (
BLOCK_SIZE, BLOCK_SIZE,
CUDA_GRAPHS, CUDA_GRAPHS,
get_adapter_to_index, get_adapter_to_index,
MODEL_ID, get_model_id,
) )
from text_generation_server.layers.attention import Seqlen from text_generation_server.layers.attention import Seqlen
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
@ -1155,9 +1155,10 @@ class FlashCausalLM(Model):
# For seqlen = 1, we dispatch to LLMM1 kernel. # For seqlen = 1, we dispatch to LLMM1 kernel.
tuning_sequences = [2, 3, 4, 5, 6, 7] tuning_sequences = [2, 3, 4, 5, 6, 7]
model_id = get_model_id()
tunableop_filepath = os.path.join( tunableop_filepath = os.path.join(
HUGGINGFACE_HUB_CACHE, HUGGINGFACE_HUB_CACHE,
f"tunableop_{MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv", f"tunableop_{model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
) )
log_master( log_master(

View File

@ -37,6 +37,10 @@ def set_model_id(model_id: str):
global MODEL_ID global MODEL_ID
MODEL_ID = model_id MODEL_ID = model_id
def get_model_id():
global MODEL_ID
return MODEL_ID
# NOTE: eventually we should move this into the router and pass back the # NOTE: eventually we should move this into the router and pass back the
# index in all cases. # index in all cases.