WIP debug Triton FA2

fxmarty 2024-04-19 11:11:26 +00:00
parent 47e522a66a
commit 0ca83be883
2 changed files with 31 additions and 1 deletion

@@ -91,6 +91,13 @@ RUN pip install torch numpy --index-url https://download.pytorch.org/whl/test/ro
FROM base AS kernel-builder
# Build Triton
FROM kernel-builder AS triton-builder
WORKDIR /usr/src
COPY server/Makefile-triton Makefile
RUN make build-triton-rocm
# Build vllm kernels
FROM kernel-builder AS vllm-builder
WORKDIR /usr/src
@@ -136,6 +143,9 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
    HF_HUB_ENABLE_HF_TRANSFER=1 \
    PORT=80
# Copy build artifacts from the triton builder
COPY --from=triton-builder /usr/src/triton/python/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from the vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
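
As a quick sanity check (not part of this commit), the Triton build copied from the triton-builder stage can be verified inside the final image; the expected path below is inferred from the COPY destination above, not from the diff:

# Hypothetical sanity check, run inside the final image: the copied Triton
# build should be importable and should resolve from the conda site-packages
# directory targeted by the COPY above.
import triton

print(triton.__version__)
print(triton.__file__)  # expected under /opt/conda/lib/python3.10/site-packages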

@@ -4,6 +4,7 @@ import torch
from loguru import logger
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
from text_generation_server.utils.flash_attn_triton import triton_attention
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
    raise ImportError("`USE_FLASH_ATTENTION` is false.")
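
Note: the ROCM_USE_FLASH_ATTN_V2_TRITON flag checked in the hunk below is not defined in this diff. A minimal sketch of how such a flag could be derived from an environment variable follows; the variable name and parsing are assumptions, not taken from this commit:

# Sketch only: ROCM_USE_FLASH_ATTN_V2_TRITON is referenced below but defined
# elsewhere in the file; this is one plausible definition, not the actual one.
import os

ROCM_USE_FLASH_ATTN_V2_TRITON = (
    os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in ("true", "1")
)
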
@@ -193,6 +194,25 @@ def attention(
            None,
        )
    elif IS_ROCM_SYSTEM and ROCM_USE_FLASH_ATTN_V2_TRITON:
        # Temporary debug logging for the WIP Triton FA2 path.
        logger.info(f"q shape {q.shape} {q.dtype} {q.is_contiguous()}")
        logger.info(f"k shape {k.shape} {k.dtype} {k.is_contiguous()}")
        logger.info(f"v shape {v.shape} {v.dtype} {v.is_contiguous()}")
        logger.info(f"cu_seqlens {cu_seqlens}")
        logger.info(f"max_s {max_s}")
        output, _ = triton_attention(
            q,
            k,
            v,
            None,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            True,
            softmax_scale,
        )
        logger.info(f"output shape {output.shape} {output.dtype}")
        logger.info(f"output {output}")
        return output
    else:
        raise NotImplementedError(f"Flash attention is not installed (IS_CUDA_SYSTEM={IS_CUDA_SYSTEM}, IS_ROCM_SYSTEM={IS_ROCM_SYSTEM}, HAS_FLASH_ATTN_V2_CUDA={HAS_FLASH_ATTN_V2_CUDA}, HAS_FLASH_ATTN_V2_ROCM={HAS_FLASH_ATTN_V2_ROCM})")
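
For context (not part of the commit): triton_attention above is invoked with a varlen layout, where q/k/v pack all sequences along the first dimension and cu_seqlens holds cumulative sequence boundaries. A minimal sketch of how cu_seqlens and max_s are typically built from per-request lengths, with illustrative values:

# Sketch: building varlen metadata for packed q/k/v of shape
# (total_tokens, num_heads, head_dim). Lengths here are made up.
import torch

seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)  # per-request lengths
cu_seqlens = torch.nn.functional.pad(
    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
)  # tensor([0, 5, 8, 15], dtype=torch.int32)
max_s = int(seq_lens.max())  # 7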