apply suggestions

This commit is contained in:
fxmarty 2024-05-16 09:00:37 +00:00
parent b7e98ba635
commit f8d37c14d9
3 changed files with 41 additions and 26 deletions

View File

@ -117,13 +117,15 @@ RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install
ARG GITHUB_TOKEN
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends wget && \
rm -rf /var/lib/apt/lists/* && \
wget --header "Authorization: token ${GITHUB_TOKEN}" https://raw.githubusercontent.com/fxmarty/patched_hipruntime/main/libamdhip64.so.6.2.41130
wget --header "Authorization: token ${GITHUB_TOKEN}" https://raw.githubusercontent.com/fxmarty/patched_hipruntime/main/libamdhip64.so.6
ENV LD_PRELOAD="/libamdhip64.so.6.2.41130"
ENV LD_PRELOAD="/libamdhip64.so.6"
# Set as recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
# Disabled for now as it is currently not stable with ROCm 6.1.
# ENV HIP_FORCE_DEV_KERNARG=1
ENV HIP_FORCE_DEV_KERNARG=1
# On MI300, performance of flash attention with the Triton FA kernel is very competitive (actually better than CK)
ENV ROCM_USE_FLASH_ATTN_V2_TRITON=1
FROM base AS kernel-builder

View File

@ -813,14 +813,15 @@ class FlashCausalLM(Model):
self.device,
)
if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
if SYSTEM == "rocm":
if os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
if os.environ.get("PYTORCH_TUNABLEOP_TUNING", "1"):
torch.cuda.tunable.tuning_enable(True)
if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS", False):
tuning_sequences = [int(val) for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")]
else:
tuning_sequences = list(range(1, 8))
tuning_sequences = [1, 2, 4, 8, 16, 32]
tunableop_filepath = os.path.join("/data", f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv")
@ -837,6 +838,8 @@ class FlashCausalLM(Model):
self.tunableop_warmup(seqlen)
torch.cuda.tunable.write_file(tunableop_filepath)
torch.cuda.tunable.tuning_enable(False)
else:
logger.info("PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.")
if CUDA_GRAPHS:
try:

View File

@ -292,6 +292,16 @@ def _attn_fwd_inner(
num_stages=1,
num_warps=4,
),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 64,
"waves_per_eu": 1,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=4,
),
],
key=["IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"],
)