mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
ability to specify tunableop tuned lengths
This commit is contained in:
parent
6c385626eb
commit
caf07decf0
@ -834,7 +834,11 @@ class FlashCausalLM(Model):
|
|||||||
if os.environ.get("PYTORCH_TUNABLEOP_TUNING", "1"):
|
if os.environ.get("PYTORCH_TUNABLEOP_TUNING", "1"):
|
||||||
torch.cuda.tunable.tuning_enable(True)
|
torch.cuda.tunable.tuning_enable(True)
|
||||||
|
|
||||||
tuning_sequences = list(range(1, 8))
|
if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS", False):
|
||||||
|
tuning_sequences = [int(val) for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")]
|
||||||
|
else:
|
||||||
|
tuning_sequences = list(range(1, 8))
|
||||||
|
|
||||||
tunableop_filepath = os.path.join("/data", f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv")
|
tunableop_filepath = os.path.join("/data", f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv")
|
||||||
|
|
||||||
logger.info(f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])} (typical decoding lengths). The picked GEMMs are saved in the file {tunableop_filepath}.")
|
logger.info(f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])} (typical decoding lengths). The picked GEMMs are saved in the file {tunableop_filepath}.")
|
||||||
|
Loading…
Reference in New Issue
Block a user