From 0ca83be88345a6b1936a788c94b7263b6d8dc9ad Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Fri, 19 Apr 2024 11:11:26 +0000
Subject: [PATCH] WIP debug Triton FA2

---
 Dockerfile_amd          | 10 +++++++++
 .../utils/flash_attn.py | 22 ++++++++++++++++++-
 2 files changed, 31 insertions(+), 1 deletion(-)

diff --git a/Dockerfile_amd b/Dockerfile_amd
index 7d7e1913..c532bae9 100644
--- a/Dockerfile_amd
+++ b/Dockerfile_amd
@@ -91,6 +91,13 @@ RUN pip install torch numpy --index-url https://download.pytorch.org/whl/test/ro
 
 FROM base AS kernel-builder
 
+# Build Triton
+FROM kernel-builder as triton-builder
+WORKDIR /usr/src
+
+COPY server/Makefile-triton Makefile
+RUN make build-triton-rocm
+
 # Build vllm kernels
 FROM kernel-builder AS vllm-builder
 WORKDIR /usr/src
@@ -136,6 +143,9 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
     HF_HUB_ENABLE_HF_TRANSFER=1 \
     PORT=80
 
+# Copy builds artifacts from triton builder
+COPY --from=triton-builder /usr/src/triton/python/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+
 # Copy builds artifacts from vllm builder
 COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 
diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py
index 87a9311b..71577306 100644
--- a/server/text_generation_server/utils/flash_attn.py
+++ b/server/text_generation_server/utils/flash_attn.py
@@ -4,6 +4,7 @@ import torch
 from loguru import logger
 
 from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
+from text_generation_server.utils.flash_attn_triton import triton_attention
 
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
@@ -193,6 +194,25 @@ def attention(
             None,
         )
     elif IS_ROCM_SYSTEM and ROCM_USE_FLASH_ATTN_V2_TRITON:
-        raise NotImplementedError("TODO")
+        logger.info(f"q shape {q.shape} {q.dtype} {q.is_contiguous()}")
+        logger.info(f"k shape {k.shape} {k.dtype} {k.is_contiguous()}")
+        logger.info(f"v shape {v.shape} {v.dtype} {v.is_contiguous()}")
+        logger.info(f"cu_seqlens {cu_seqlens}")
+        logger.info(f"max_s {max_s}")
+        output, _ = triton_attention(
+            q,
+            k,
+            v,
+            None,
+            cu_seqlens,
+            cu_seqlens,
+            max_s,
+            max_s,
+            True,
+            softmax_scale,
+        )
+        logger.info(f"output shape {output.shape} {output.dtype}")
+        logger.info(f"output {output}")
+        return output
     else:
         raise NotImplementedError(f"Flash attention is not installed (IS_CUDA_SYSTEM={IS_CUDA_SYSTEM}, IS_ROCM_SYSTEM={IS_ROCM_SYSTEM}, HAS_FLASH_ATTN_V2_CUDA={HAS_FLASH_ATTN_V2_CUDA}, HAS_FLASH_ATTN_V2_ROCM={HAS_FLASH_ATTN_V2_ROCM})")
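
Note (not part of the patch): below is a minimal sketch of how the new Triton attention path might be exercised in isolation. The packed (total_tokens, num_heads, head_dim) layout, the example shapes, and the positional meaning of the triton_attention arguments are assumptions inferred from the hunk above, not confirmed by the TGI source; running it requires a ROCm GPU with the server package and its flash_attn_triton module installed.

# Hypothetical smoke test for the Triton FA2 path wired up in the patch above.
# Assumptions (inferred, not confirmed): q/k/v are packed as
# (total_tokens, num_heads, head_dim), cu_seqlens holds int32 cumulative
# sequence boundaries, and triton_attention takes the same positional
# arguments as in the diff: (q, k, v, out, cu_seqlens_q, cu_seqlens_k,
# max_seqlen_q, max_seqlen_k, causal, softmax_scale).
import torch

from text_generation_server.utils.flash_attn_triton import triton_attention

num_heads, head_dim = 32, 128
seq_lens = [5, 9]                 # two variable-length prefill sequences
total_tokens = sum(seq_lens)      # 14 tokens packed back to back
max_s = max(seq_lens)

# Cumulative boundaries: sequence i spans tokens cu_seqlens[i]:cu_seqlens[i+1].
cu_seqlens = torch.tensor([0, 5, 14], dtype=torch.int32, device="cuda")

q = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
softmax_scale = head_dim ** -0.5

output, _ = triton_attention(
    q,
    k,
    v,
    None,          # output buffer; None mirrors the patch
    cu_seqlens,    # query sequence boundaries
    cu_seqlens,    # key sequence boundaries (identical during prefill)
    max_s,         # max query length
    max_s,         # max key length
    True,          # causal masking
    softmax_scale,
)
print(output.shape, output.dtype)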