ability to specify tunableop tuned lengths

This commit is contained in:
fxmarty 2024-05-02 16:05:55 +00:00
parent 6c385626eb
commit caf07decf0

View File

@ -834,7 +834,11 @@ class FlashCausalLM(Model):
if os.environ.get("PYTORCH_TUNABLEOP_TUNING", "1"): if os.environ.get("PYTORCH_TUNABLEOP_TUNING", "1"):
torch.cuda.tunable.tuning_enable(True) torch.cuda.tunable.tuning_enable(True)
tuning_sequences = list(range(1, 8)) if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS", False):
tuning_sequences = [int(val) for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")]
else:
tuning_sequences = list(range(1, 8))
tunableop_filepath = os.path.join("/data", f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv") tunableop_filepath = os.path.join("/data", f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv")
logger.info(f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])} (typical decoding lengths). The picked GEMMs are saved in the file {tunableop_filepath}.") logger.info(f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])} (typical decoding lengths). The picked GEMMs are saved in the file {tunableop_filepath}.")