From caf07decf0d058e6f574c5c4d07cc5f4576cc814 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Thu, 2 May 2024 16:05:55 +0000 Subject: [PATCH] ability to specify tunableop tuned lengths --- server/text_generation_server/models/flash_causal_lm.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 311242c6..a17c47ab 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -834,7 +834,11 @@ class FlashCausalLM(Model): if os.environ.get("PYTORCH_TUNABLEOP_TUNING", "1"): torch.cuda.tunable.tuning_enable(True) - tuning_sequences = list(range(1, 8)) + if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS", False): + tuning_sequences = [int(val) for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")] + else: + tuning_sequences = list(range(1, 8)) + tunableop_filepath = os.path.join("/data", f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv") logger.info(f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])} (typical decoding lengths). The picked GEMMs are saved in the file {tunableop_filepath}.")