text-generation-inference/server/text_generation_server/utils/flash_attn.py

import os
import torch
from loguru import logger
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
from text_generation_server.utils.flash_attn_triton import triton_attention

if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
    raise ImportError("`USE_FLASH_ATTENTION` is false.")

if not torch.cuda.is_available():
    raise ImportError("CUDA is not available")

major, minor = torch.cuda.get_device_capability()
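# Capability 7.5 = Turing, 8.x = Ampere/Ada, 9.0 = Hopper; on ROCm builds a reported
# capability of 9.4 appears to correspond to gfx942 (MI300-series) hardware.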
is_sm75 = major == 7 and minor == 5
is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0
is_sm94 = major == 9 and minor == 4
HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2_CUDA = False
HAS_FLASH_ATTN_V2_ROCM = False
ROCM_USE_FLASH_ATTN_V2_CK = False
ROCM_USE_FLASH_ATTN_V2_TRITON = False

if IS_ROCM_SYSTEM:
    if os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true":
        ROCM_USE_FLASH_ATTN_V2_TRITON = True
        logger.info("ROCm: using Flash Attention 2 Triton implementation.")
    else:
        ROCM_USE_FLASH_ATTN_V2_CK = True
        logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")

try:
    try:
        import flash_attn_2_cuda
    except ImportError:
        architecture_suffix = ""
        if IS_CUDA_SYSTEM:
            architecture_suffix = "-cuda"
        elif IS_ROCM_SYSTEM:
            architecture_suffix = "-rocm"
        raise ImportError(
            "Flash Attention V2 is not installed.\n"
            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
            f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
        )
    if IS_CUDA_SYSTEM and not (is_sm8x or is_sm90):
        raise ImportError(
            f"GPU with CUDA capability {major} {minor} is not supported for "
            "Flash Attention V2"
        )
    elif IS_ROCM_SYSTEM and not (is_sm8x or is_sm90 or is_sm94):
        raise ImportError(
            f"AMD GPU with compute capability {major} {minor} is not supported for "
            "Flash Attention V2"
        )
    HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
    HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
except ImportError as e:
    try:
        import flash_attn_cuda
    except ImportError:
        raise ImportError(
            "Flash Attention is not installed.\n"
            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
            "or install flash attention with `cd server && make install install-flash-attention`"
        ) from e

    if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90):
        raise ImportError(
            f"GPU with CUDA capability {major} {minor} is not supported"
        ) from e
    elif IS_ROCM_SYSTEM:
        for idx in range(torch.cuda.device_count()):
            if "MI210" not in torch.cuda.get_device_name(
                idx
            ) and "MI250" not in torch.cuda.get_device_name(idx):
                raise ImportError(
                    f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
                )

    logger.warning(f"Unable to use Flash Attention V2: {e}")
    HAS_FLASH_ATTN = True

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
total_tokens, num_key_value_heads, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
2024-04-23 15:05:39 +00:00
hidden_states = hidden_states[:, :, None, :].expand(
total_tokens, num_key_value_heads, n_rep, head_dim
)
return hidden_states.reshape(total_tokens, num_key_value_heads * n_rep, head_dim)
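

# Illustrative shape check for repeat_kv (a sketch only, not executed at import time):
# a (total_tokens=3, num_key_value_heads=2, head_dim=64) tensor repeated with n_rep=4
# becomes (3, 8, 64):
#
#     kv = torch.randn(3, 2, 64)
#     assert repeat_kv(kv, 4).shape == (3, 8, 64)
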
def attention(
    q,
    k,
    v,
    out,
    cu_seqlens,
    max_s,
    softmax_scale,
    window_size_left=-1,
):
    if window_size_left <= 0 and window_size_left != -1:
        raise ValueError("`window_size_left` must be > 0 or -1")
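
    # Dispatch order below: CUDA flash-attn v2, then CUDA flash-attn v1, then the
    # ROCm Composable Kernel build of flash-attn v2, then the ROCm Triton kernel.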
    if IS_CUDA_SYSTEM and HAS_FLASH_ATTN_V2_CUDA:
        # varlen_fwd is called positionally; the trailing scalars map to dropout (0.0),
        # softmax scale, zero_tensors, causal, window_size_left, window_size_right,
        # return_softmax and generator (descriptive labels, not upstream documentation).
        return flash_attn_2_cuda.varlen_fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            None,
            None,
            None,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            True,
            window_size_left,
            0,
            False,
            None,
        )
    elif IS_CUDA_SYSTEM and HAS_FLASH_ATTN:
        if window_size_left != -1:
            raise NotImplementedError(
                "window_size_left is only available with flash attn v2"
            )

        # Flash attention v1 requires q, k and v to have the same number of heads
        if k.shape[1] != q.shape[1]:
            # MQA expand
            if k.shape[1] == 1:
                k = k.expand(-1, q.shape[1], -1)
            # Grouped attention reshape
            else:
                original_shape = k.shape
                k = (
                    k.unsqueeze(2)
                    .expand(-1, -1, q.shape[1] // k.shape[1], -1)
                    .reshape(original_shape[0], -1, original_shape[2])
                )
        if v.shape[1] != q.shape[1]:
            # MQA expand
            if v.shape[1] == 1:
                v = v.expand(-1, q.shape[1], -1)
            # Grouped attention reshape
            else:
                original_shape = v.shape
                v = (
                    v.unsqueeze(2)
                    .expand(-1, -1, q.shape[1] // v.shape[1], -1)
                    .reshape(original_shape[0], -1, original_shape[2])
                )

        # flash-attn v1 fwd is called positionally; the trailing scalars map to dropout
        # (0.0), softmax scale, zero_tensors, causal, return_softmax, num_splits and
        # generator (descriptive labels, not upstream documentation).
        return flash_attn_cuda.fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            True,
            False,
            0,
            None,
        )
    elif IS_ROCM_SYSTEM and HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:
        if window_size_left != -1:
            raise ValueError(
                f"ROCm Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
            )

        # The ROCm flash API does not take the window_size_left and window_size_right arguments.
        return flash_attn_2_cuda.varlen_fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            True,
            False,
            None,
        )
    elif IS_ROCM_SYSTEM and ROCM_USE_FLASH_ATTN_V2_TRITON:
        # NOTE: the Triton kernel silently produces wrong results with MQA/GQA unless
        # the key/value heads are repeated first, hence the repeat_kv calls below.
        # TODO: this is a sketch; the `attention` function should be abstracted so the
        # implementation can be customized - sync with Nicolas on which one he'd prefer.
        num_heads = q.shape[1]
        num_kv_heads = k.shape[1]
        if num_kv_heads != num_heads:
            # Interleave for MQA workaround.
            k = repeat_kv(k, num_heads // num_kv_heads)
            v = repeat_kv(v, num_heads // num_kv_heads)

        output, _ = triton_attention(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            True,
            softmax_scale,
        )
        return output
    else:
        raise NotImplementedError(
            f"Flash attention is not installed (IS_CUDA_SYSTEM={IS_CUDA_SYSTEM}, IS_ROCM_SYSTEM={IS_ROCM_SYSTEM}, HAS_FLASH_ATTN_V2_CUDA={HAS_FLASH_ATTN_V2_CUDA}, HAS_FLASH_ATTN_V2_ROCM={HAS_FLASH_ATTN_V2_ROCM})"
        )
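

# Illustrative usage of `attention` (a sketch only, assuming the packed varlen layout
# used above): q, k, v and out are (total_tokens, num_heads, head_dim) GPU tensors and
# cu_seqlens holds int32 cumulative sequence lengths, e.g. for two sequences of
# lengths 5 and 3:
#
#     cu_seqlens = torch.tensor([0, 5, 8], dtype=torch.int32, device="cuda")
#     out = torch.empty_like(q)
#     attention(q, k, v, out, cu_seqlens, max_s=5, softmax_scale=q.shape[-1] ** -0.5)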