Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-22 23:42:06 +00:00)
WIP debug Triton FA2
This commit is contained in:
parent 47e522a66a
commit 0ca83be883
@@ -91,6 +91,13 @@ RUN pip install torch numpy --index-url https://download.pytorch.org/whl/test/ro
 FROM base AS kernel-builder
 
+# Build Triton
+FROM kernel-builder as triton-builder
+
+WORKDIR /usr/src
+COPY server/Makefile-triton Makefile
+RUN make build-triton-rocm
+
 # Build vllm kernels
 FROM kernel-builder AS vllm-builder
 
 WORKDIR /usr/src
@@ -136,6 +143,9 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
     HF_HUB_ENABLE_HF_TRANSFER=1 \
     PORT=80
 
+# Copy builds artifacts from triton builder
+COPY --from=triton-builder /usr/src/triton/python/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+
 # Copy builds artifacts from vllm builder
 COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 
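For a quick check that this stage wiring works, a throwaway snippet (not part of this commit) can be run inside the final image to confirm that the Triton build copied into site-packages is the one the interpreter actually resolves:

    import triton

    # Confirm Python picks up the copied build and report version and location.
    print("triton", triton.__version__, "from", triton.__file__)

If the import fails or resolves to a different path, the COPY destination and the Python version baked into the build directory name (cpython-310) are the first things to re-check.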
@@ -4,6 +4,7 @@ import torch
 from loguru import logger
 
 from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
+from text_generation_server.utils.flash_attn_triton import triton_attention
 
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
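The new import pulls triton_attention in unconditionally, alongside the existing IS_CUDA_SYSTEM / IS_ROCM_SYSTEM flags. For context, such flags are commonly derived from the torch build itself; a minimal sketch (the real text_generation_server.utils.import_utils may differ):

    import torch

    # A ROCm/HIP build of torch exposes a HIP version; a CUDA build exposes a CUDA version.
    IS_ROCM_SYSTEM = torch.version.hip is not None
    IS_CUDA_SYSTEM = torch.version.cuda is not None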
@@ -193,6 +194,25 @@ def attention(
             None,
         )
     elif IS_ROCM_SYSTEM and ROCM_USE_FLASH_ATTN_V2_TRITON:
-        raise NotImplementedError("TODO")
+        logger.info(f"q shape {q.shape} {q.dtype} {q.is_contiguous()}")
+        logger.info(f"k shape {k.shape} {k.dtype} {k.is_contiguous()}")
+        logger.info(f"v shape {v.shape} {v.dtype} {v.is_contiguous()}")
+        logger.info(f"cu_seqlens {cu_seqlens}")
+        logger.info(f"max_s {max_s}")
+        output, _ = triton_attention(
+            q,
+            k,
+            v,
+            None,
+            cu_seqlens,
+            cu_seqlens,
+            max_s,
+            max_s,
+            True,
+            softmax_scale,
+        )
+        logger.info(f"output shape {output.shape} {output.dtype}")
+        logger.info(f"output {output}")
+        return output
     else:
         raise NotImplementedError(f"Flash attention is not installed (IS_CUDA_SYSTEM={IS_CUDA_SYSTEM}, IS_ROCM_SYSTEM={IS_ROCM_SYSTEM}, HAS_FLASH_ATTN_V2_CUDA={HAS_FLASH_ATTN_V2_CUDA}, HAS_FLASH_ATTN_V2_ROCM={HAS_FLASH_ATTN_V2_ROCM})")
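The replaced NotImplementedError branch now logs every input, calls the Triton flash-attention kernel, and logs the output, which is the point of this WIP debug commit. If the logging is kept for a while, a small helper keeps it uniform; this is a hypothetical refactor sketch, not code from the commit:

    import torch
    from loguru import logger

    def log_tensor(name: str, t: torch.Tensor) -> None:
        # Log the metadata that matters when debugging a Triton kernel:
        # shape, dtype, and whether the tensor is contiguous in memory.
        logger.info(f"{name}: shape={tuple(t.shape)} dtype={t.dtype} contiguous={t.is_contiguous()}")

    # Usage inside the ROCm Triton branch:
    # log_tensor("q", q); log_tensor("k", k); log_tensor("v", v)

Note that logger.info(f"output {output}") prints the full tensor, which can be very slow for real batch sizes, so it is presumably meant to be removed once the kernel path is verified.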