mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
fix: fix missing model id in ROCm warmup
This commit is contained in:
parent
8642250602
commit
5f2e1f0d7e
@@ -44,7 +44,7 @@ from text_generation_server.models.globals import (
     BLOCK_SIZE,
     CUDA_GRAPHS,
     get_adapter_to_index,
-    MODEL_ID,
+    get_model_id,
 )
 from text_generation_server.layers.attention import Seqlen
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
@@ -1155,9 +1155,10 @@ class FlashCausalLM(Model):
             # For seqlen = 1, we dispatch to LLMM1 kernel.
             tuning_sequences = [2, 3, 4, 5, 6, 7]

+            model_id = get_model_id()
             tunableop_filepath = os.path.join(
                 HUGGINGFACE_HUB_CACHE,
-                f"tunableop_{MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
+                f"tunableop_{model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
             )

             log_master(
|
@@ -37,6 +37,10 @@ def set_model_id(model_id: str):
     global MODEL_ID
     MODEL_ID = model_id


+def get_model_id():
+    global MODEL_ID
+    return MODEL_ID
+
+
 # NOTE: eventually we should move this into the router and pass back the
 # index in all cases.
|
Loading…
Reference in New Issue
Block a user