mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-23 16:02:10 +00:00
Adds support for AMD Instinct MI300 in TGI.
Most changes are:
* Support PyTorch TunableOp to pick the GEMM/GEMV kernels for decoding
https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable.
TunableOp is disabled by default, and can be enabled with
`PYTORCH_TUNABLEOP_ENABLED=1`.
* Update ROCm dockerfile to PyTorch 2.3 (actually patched with changes
from https://github.com/pytorch/pytorch/pull/124362)
* Support SILU & Linear custom kernels contributed by AMD
* Update vLLM paged attention to https://github.com/fxmarty/rocm-vllm/,
branching out of a much more recent commit
3489ce7936
* Support FA2 Triton kernel as recommended by AMD. Can be used by
specifying `ROCM_USE_FLASH_ATTN_V2_TRITON=1`.
* Update dockerfile to ROCm 6.1
By default, TunableOp tuning results are saved in `/data` (e.g.
`/data/tunableop_meta-llama-Llama-2-70b-chat-hf_tp1_rank0.csv`) in order
to avoid having to rerun the tuning at each `docker run`.
Example:
```
Validator,PT_VERSION,2.3.0
Validator,ROCM_VERSION,6.1.0.0-82-5fabb4c
Validator,HIPBLASLT_VERSION,0.7.0-1549b021
Validator,GCN_ARCH_NAME,gfx942:sramecc+:xnack-
Validator,ROCBLAS_VERSION,4.1.0-cefa4a9b-dirty
GemmTunableOp_Half_TN,tn_8192_7_28672,Gemm_Rocblas_45475,0.132098
GemmTunableOp_Half_TN,tn_10240_4_8192,Gemm_Rocblas_45546,0.0484431
GemmTunableOp_Half_TN,tn_32000_6_8192,Default,0.149546
GemmTunableOp_Half_TN,tn_32000_3_8192,Gemm_Rocblas_45520,0.147119
GemmTunableOp_Half_TN,tn_8192_3_28672,Gemm_Rocblas_45475,0.132645
GemmTunableOp_Half_TN,tn_10240_3_8192,Gemm_Rocblas_45546,0.0482971
GemmTunableOp_Half_TN,tn_57344_5_8192,Gemm_Rocblas_45520,0.255694
GemmTunableOp_Half_TN,tn_10240_7_8192,Gemm_Rocblas_45517,0.0482522
GemmTunableOp_Half_TN,tn_8192_3_8192,Gemm_Rocblas_45546,0.0444671
GemmTunableOp_Half_TN,tn_8192_5_8192,Gemm_Rocblas_45546,0.0445834
GemmTunableOp_Half_TN,tn_57344_7_8192,Gemm_Rocblas_45520,0.25622
GemmTunableOp_Half_TN,tn_8192_2_28672,Gemm_Rocblas_45475,0.132122
GemmTunableOp_Half_TN,tn_8192_4_8192,Gemm_Rocblas_45517,0.0453191
GemmTunableOp_Half_TN,tn_10240_5_8192,Gemm_Rocblas_45517,0.0482514
GemmTunableOp_Half_TN,tn_8192_5_28672,Gemm_Rocblas_45542,0.133914
GemmTunableOp_Half_TN,tn_8192_2_8192,Gemm_Rocblas_45517,0.0446516
GemmTunableOp_Half_TN,tn_8192_1_28672,Gemm_Hipblaslt_TN_10814,0.131953
GemmTunableOp_Half_TN,tn_10240_2_8192,Gemm_Rocblas_45546,0.0481043
GemmTunableOp_Half_TN,tn_32000_4_8192,Gemm_Rocblas_45520,0.147497
GemmTunableOp_Half_TN,tn_8192_6_28672,Gemm_Rocblas_45529,0.134895
GemmTunableOp_Half_TN,tn_57344_2_8192,Gemm_Rocblas_45520,0.254716
GemmTunableOp_Half_TN,tn_57344_4_8192,Gemm_Rocblas_45520,0.255731
GemmTunableOp_Half_TN,tn_10240_6_8192,Gemm_Rocblas_45517,0.0484816
GemmTunableOp_Half_TN,tn_57344_3_8192,Gemm_Rocblas_45520,0.254701
GemmTunableOp_Half_TN,tn_8192_4_28672,Gemm_Rocblas_45475,0.132159
GemmTunableOp_Half_TN,tn_32000_2_8192,Default,0.147524
GemmTunableOp_Half_TN,tn_32000_5_8192,Default,0.147074
GemmTunableOp_Half_TN,tn_8192_6_8192,Gemm_Rocblas_45546,0.0454045
GemmTunableOp_Half_TN,tn_57344_6_8192,Gemm_Rocblas_45520,0.255582
GemmTunableOp_Half_TN,tn_32000_7_8192,Default,0.146705
GemmTunableOp_Half_TN,tn_8192_7_8192,Gemm_Rocblas_45546,0.0445489
```
---------
Co-authored-by: Mohit Sharma <mohit21sharma.ms@gmail.com>
293 lines
8.2 KiB
Python
293 lines
8.2 KiB
Python
import os
|
|
import torch
|
|
|
|
from loguru import logger
|
|
import math
|
|
|
|
from text_generation_server.utils.import_utils import SYSTEM
|
|
from text_generation_server.utils.flash_attn_triton import triton_attention
|
|
|
|
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
|
|
raise ImportError("`USE_FLASH_ATTENTION` is false.")
|
|
HAS_FLASH_ATTN = False
|
|
HAS_FLASH_ATTN_V2_CUDA = False
|
|
HAS_FLASH_ATTN_V2_ROCM = False
|
|
ROCM_USE_FLASH_ATTN_V2_CK = False
|
|
ROCM_USE_FLASH_ATTN_V2_TRITON = False
|
|
|
|
if SYSTEM == "xpu":
|
|
import intel_extension_for_pytorch as ipex
|
|
|
|
def attention(
|
|
q,
|
|
k,
|
|
v,
|
|
out,
|
|
cu_seqlens,
|
|
max_s,
|
|
softmax_scale,
|
|
window_size_left=-1,
|
|
):
|
|
if window_size_left <= 0 and window_size_left != -1:
|
|
raise ValueError("`window_size_left` must be > 0 or -1")
|
|
|
|
if window_size_left != -1:
|
|
raise ValueError(
|
|
f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
|
|
)
|
|
return ipex.llm.functional.varlen_attention(
|
|
q,
|
|
k,
|
|
v,
|
|
out,
|
|
cu_seqlens,
|
|
cu_seqlens,
|
|
max_s,
|
|
max_s,
|
|
0.0,
|
|
softmax_scale,
|
|
False,
|
|
True,
|
|
False,
|
|
None,
|
|
)
|
|
|
|
|
|
if SYSTEM in {"cuda", "rocm"}:
|
|
if not torch.cuda.is_available():
|
|
raise ImportError("CUDA is not available")
|
|
|
|
major, minor = torch.cuda.get_device_capability()
|
|
is_sm75 = major == 7 and minor == 5
|
|
is_sm8x = major == 8 and minor >= 0
|
|
is_sm90 = major == 9 and minor == 0
|
|
is_sm94 = major == 9 and minor == 4
|
|
|
|
if SYSTEM == "rocm":
|
|
if (
|
|
os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true"
|
|
or os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "0") == "1"
|
|
):
|
|
ROCM_USE_FLASH_ATTN_V2_TRITON = True
|
|
logger.info("ROCm: using Flash Attention 2 Triton implementation.")
|
|
else:
|
|
ROCM_USE_FLASH_ATTN_V2_CK = True
|
|
logger.info(
|
|
"ROCm: using Flash Attention 2 Composable Kernel implementation."
|
|
)
|
|
|
|
try:
|
|
try:
|
|
import flash_attn_2_cuda
|
|
except ImportError:
|
|
architecture_suffix = f"-{SYSTEM}"
|
|
raise ImportError(
|
|
"Flash Attention V2 is not installed.\n"
|
|
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
|
|
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
|
|
)
|
|
if SYSTEM == "cuda" and not (is_sm8x or is_sm90):
|
|
raise ImportError(
|
|
f"GPU with CUDA capability {major} {minor} is not supported for "
|
|
"Flash Attention V2"
|
|
)
|
|
elif SYSTEM == "rocm" and not (is_sm8x or is_sm90 or is_sm94):
|
|
raise ImportError(
|
|
f"AMD GPU with compute capability {major} {minor} is not supported for "
|
|
"Flash Attention V2"
|
|
)
|
|
HAS_FLASH_ATTN_V2_CUDA = SYSTEM == "cuda"
|
|
HAS_FLASH_ATTN_V2_ROCM = SYSTEM == "rocm"
|
|
except ImportError as e:
|
|
try:
|
|
import flash_attn_cuda
|
|
except ImportError:
|
|
raise ImportError(
|
|
"Flash Attention is not installed.\n"
|
|
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
|
|
"or install flash attention with `cd server && make install install-flash-attention`"
|
|
) from e
|
|
|
|
if SYSTEM == "cuda" and not (is_sm75 or is_sm8x or is_sm90):
|
|
raise ImportError(
|
|
f"GPU with CUDA capability {major} {minor} is not supported"
|
|
) from e
|
|
elif SYSTEM == "rocm":
|
|
for idx in range(torch.cuda.device_count()):
|
|
if "MI210" not in torch.cuda.get_device_name(
|
|
idx
|
|
) and "MI250" not in torch.cuda.get_device_name(idx):
|
|
raise ImportError(
|
|
f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
|
|
)
|
|
|
|
logger.warning(f"Unable to use Flash Attention V2: {e}")
|
|
HAS_FLASH_ATTN = True
|
|
|
|
|
|
if HAS_FLASH_ATTN_V2_CUDA:

    def attention(
        q,
        k,
        v,
        out,
        cu_seqlens,
        max_s,
        softmax_scale,
        window_size_left=-1,
        causal=True,
    ):
        """Flash Attention v2 (CUDA) over variable-length packed sequences.

        q and k share the same `cu_seqlens` packing and `max_s` bound; `out`
        is filled by the kernel and its return value is forwarded.
        """
        # -1 disables the sliding window; any other non-positive value is invalid.
        if window_size_left != -1 and window_size_left <= 0:
            raise ValueError("`window_size_left` must be > 0 or -1")
        # Positional layout follows the flash-attn v2 varlen_fwd binding; the
        # three None slots are optional auxiliary tensors (version-dependent
        # — confirm against the installed flash-attn build).
        return flash_attn_2_cuda.varlen_fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,  # cu_seqlens_q
            cu_seqlens,  # cu_seqlens_k (same packing as q)
            None,
            None,
            None,
            max_s,  # max_seqlen_q
            max_s,  # max_seqlen_k
            0.0,  # dropout disabled at inference
            softmax_scale,
            False,  # zero_tensors
            causal,
            window_size_left,
            0,  # window_size_right
            False,  # return_softmax
            None,  # generator
        )
elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:
|
|
|
|
def attention(
|
|
q,
|
|
k,
|
|
v,
|
|
out,
|
|
cu_seqlens,
|
|
max_s,
|
|
softmax_scale,
|
|
window_size_left=-1,
|
|
causal=True,
|
|
):
|
|
if window_size_left <= 0 and window_size_left != -1:
|
|
raise ValueError("`window_size_left` must be > 0 or -1")
|
|
if window_size_left != -1:
|
|
raise ValueError(
|
|
f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
|
|
)
|
|
|
|
# RoCm flash API does not take the window_size_left and window_size_right arguments.
|
|
return flash_attn_2_cuda.varlen_fwd(
|
|
q,
|
|
k,
|
|
v,
|
|
out,
|
|
cu_seqlens,
|
|
cu_seqlens,
|
|
max_s,
|
|
max_s,
|
|
0.0,
|
|
softmax_scale,
|
|
False,
|
|
causal,
|
|
False,
|
|
None,
|
|
)
|
|
|
|
elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON:
|
|
|
|
def attention(
|
|
q,
|
|
k,
|
|
v,
|
|
out,
|
|
cu_seqlens,
|
|
max_s,
|
|
softmax_scale,
|
|
window_size_left=-1,
|
|
causal=True,
|
|
):
|
|
output, _ = triton_attention(
|
|
q,
|
|
k,
|
|
v,
|
|
out,
|
|
cu_seqlens,
|
|
cu_seqlens,
|
|
max_s,
|
|
max_s,
|
|
causal,
|
|
softmax_scale,
|
|
)
|
|
return output
|
|
|
|
elif HAS_FLASH_ATTN:
|
|
|
|
def attention(
|
|
q,
|
|
k,
|
|
v,
|
|
out,
|
|
cu_seqlens,
|
|
max_s,
|
|
softmax_scale,
|
|
window_size_left=-1,
|
|
):
|
|
if window_size_left != -1:
|
|
raise NotImplementedError(
|
|
"window_size_left is only available with flash attn v2"
|
|
)
|
|
|
|
# Flash attention v1 requires q, k and v to have the same number of heads
|
|
if k.shape[1] != q.shape[1]:
|
|
# MQA expand
|
|
if k.shape[1] == 1:
|
|
k = k.expand(-1, q.shape[1], -1)
|
|
# Grouped attention reshape
|
|
else:
|
|
original_shape = k.shape
|
|
k = (
|
|
k.unsqueeze(2)
|
|
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
|
|
.reshape(original_shape[0], -1, original_shape[2])
|
|
)
|
|
if v.shape[1] != q.shape[1]:
|
|
# MQA expand
|
|
if v.shape[1] == 1:
|
|
v = v.expand(-1, q.shape[1], -1)
|
|
# Grouped attention reshape
|
|
else:
|
|
original_shape = v.shape
|
|
v = (
|
|
v.unsqueeze(2)
|
|
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
|
|
.reshape(original_shape[0], -1, original_shape[2])
|
|
)
|
|
|
|
return flash_attn_cuda.fwd(
|
|
q,
|
|
k,
|
|
v,
|
|
out,
|
|
cu_seqlens,
|
|
cu_seqlens,
|
|
max_s,
|
|
max_s,
|
|
0.0,
|
|
softmax_scale,
|
|
False,
|
|
True,
|
|
False,
|
|
0,
|
|
None,
|
|
)
|
|
|
|
else:
|
|
raise NotImplementedError("flash attention is not installed")
|