mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 15:32:08 +00:00
WIP debug Triton FA2
This commit is contained in:
parent
47e522a66a
commit
0ca83be883
@ -91,6 +91,13 @@ RUN pip install torch numpy --index-url https://download.pytorch.org/whl/test/ro
|
||||
|
||||
FROM base AS kernel-builder
|
||||
|
||||
# Build Triton
|
||||
FROM kernel-builder as triton-builder
|
||||
WORKDIR /usr/src
|
||||
|
||||
COPY server/Makefile-triton Makefile
|
||||
RUN make build-triton-rocm
|
||||
|
||||
# Build vllm kernels
|
||||
FROM kernel-builder AS vllm-builder
|
||||
WORKDIR /usr/src
|
||||
@ -136,6 +143,9 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
|
||||
HF_HUB_ENABLE_HF_TRANSFER=1 \
|
||||
PORT=80
|
||||
|
||||
# Copy builds artifacts from triton builder
|
||||
COPY --from=triton-builder /usr/src/triton/python/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
|
||||
# Copy builds artifacts from vllm builder
|
||||
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
|
||||
|
@ -4,6 +4,7 @@ import torch
|
||||
from loguru import logger
|
||||
|
||||
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
|
||||
from text_generation_server.utils.flash_attn_triton import triton_attention
|
||||
|
||||
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
|
||||
raise ImportError("`USE_FLASH_ATTENTION` is false.")
|
||||
@ -193,6 +194,25 @@ def attention(
|
||||
None,
|
||||
)
|
||||
elif IS_ROCM_SYSTEM and ROCM_USE_FLASH_ATTN_V2_TRITON:
|
||||
raise NotImplementedError("TODO")
|
||||
logger.info(f"q shape {q.shape} {q.dtype} {q.is_contiguous()}")
|
||||
logger.info(f"k shape {k.shape} {k.dtype} {k.is_contiguous()}")
|
||||
logger.info(f"v shape {v.shape} {v.dtype} {v.is_contiguous()}")
|
||||
logger.info(f"cu_seqlens {cu_seqlens}")
|
||||
logger.info(f"max_s {max_s}")
|
||||
output, _ = triton_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
None,
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
max_s,
|
||||
True,
|
||||
softmax_scale,
|
||||
)
|
||||
logger.info(f"output shape {output.shape} {output.dtype}")
|
||||
logger.info(f"output {output}")
|
||||
return output
|
||||
else:
|
||||
raise NotImplementedError(f"Flash attention is not installed (IS_CUDA_SYSTEM={IS_CUDA_SYSTEM}, IS_ROCM_SYSTEM={IS_ROCM_SYSTEM}, HAS_FLASH_ATTN_V2_CUDA={HAS_FLASH_ATTN_V2_CUDA}, HAS_FLASH_ATTN_V2_ROCM={HAS_FLASH_ATTN_V2_ROCM})")
|
||||
|
Loading…
Reference in New Issue
Block a user