From f8d37c14d95ea9dae38a908d1cee831e608716c4 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Thu, 16 May 2024 09:00:37 +0000 Subject: [PATCH] apply suggestions --- Dockerfile_amd | 10 ++-- .../models/flash_causal_lm.py | 47 ++++++++++--------- .../utils/flash_attn_triton.py | 10 ++++ 3 files changed, 41 insertions(+), 26 deletions(-) diff --git a/Dockerfile_amd b/Dockerfile_amd index 1d19c921..68561bdf 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -117,13 +117,15 @@ RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install ARG GITHUB_TOKEN RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends wget && \ rm -rf /var/lib/apt/lists/* && \ - wget --header "Authorization: token ${GITHUB_TOKEN}" https://raw.githubusercontent.com/fxmarty/patched_hipruntime/main/libamdhip64.so.6.2.41130 + wget --header "Authorization: token ${GITHUB_TOKEN}" https://raw.githubusercontent.com/fxmarty/patched_hipruntime/main/libamdhip64.so.6 -ENV LD_PRELOAD="/libamdhip64.so.6.2.41130" +ENV LD_PRELOAD="/libamdhip64.so.6" # Set as recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm -# Disabled for now as it is currently not stable with ROCm 6.1. 
-# ENV HIP_FORCE_DEV_KERNARG=1 +ENV HIP_FORCE_DEV_KERNARG=1 + +# On MI300, performance for flash with Triton FA is very competitive (actually better than CK) +ENV ROCM_USE_FLASH_ATTN_V2_TRITON=1 FROM base AS kernel-builder diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index a7f3470d..69fca1b5 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -813,30 +813,33 @@ class FlashCausalLM(Model): self.device, ) - if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False): - if os.environ.get("PYTORCH_TUNABLEOP_TUNING", "1"): - torch.cuda.tunable.tuning_enable(True) + if SYSTEM == "rocm": + if os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False): + if os.environ.get("PYTORCH_TUNABLEOP_TUNING", "1"): + torch.cuda.tunable.tuning_enable(True) - if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS", False): - tuning_sequences = [int(val) for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")] + if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS", False): + tuning_sequences = [int(val) for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")] + else: + tuning_sequences = [1, 2, 4, 8, 16, 32] + + tunableop_filepath = os.path.join("/data", f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv") + + logger.info(f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])} (typical decoding lengths). 
The picked GEMMs are saved in the file {tunableop_filepath}.") + + if os.path.isfile(tunableop_filepath): + logger.info(f"The file {tunableop_filepath} already exists and will be reused.") + torch.cuda.tunable.read_file(tunableop_filepath) + + os.makedirs("/data", exist_ok=True) + + for seqlen in tuning_sequences: + logger.info(f"Warming up TunableOp for seqlen={seqlen}") + self.tunableop_warmup(seqlen) + torch.cuda.tunable.write_file(tunableop_filepath) + torch.cuda.tunable.tuning_enable(False) else: - tuning_sequences = list(range(1, 8)) - - tunableop_filepath = os.path.join("/data", f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv") - - logger.info(f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])} (typical decoding lengths). The picked GEMMs are saved in the file {tunableop_filepath}.") - - if os.path.isfile(tunableop_filepath): - logger.info(f"The file {tunableop_filepath} already exists and will be reused.") - torch.cuda.tunable.read_file(tunableop_filepath) - - os.makedirs("/data", exist_ok=True) - - for seqlen in tuning_sequences: - logger.info(f"Warming up TunableOp for seqlen={seqlen}") - self.tunableop_warmup(seqlen) - torch.cuda.tunable.write_file(tunableop_filepath) - torch.cuda.tunable.tuning_enable(False) + logger.info("PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. 
If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.") if CUDA_GRAPHS: try: diff --git a/server/text_generation_server/utils/flash_attn_triton.py b/server/text_generation_server/utils/flash_attn_triton.py index 8378c06e..9167b1f4 100644 --- a/server/text_generation_server/utils/flash_attn_triton.py +++ b/server/text_generation_server/utils/flash_attn_triton.py @@ -292,6 +292,16 @@ def _attn_fwd_inner( num_stages=1, num_warps=4, ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 1, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), ], key=["IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"], )