mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 22:02:06 +00:00
# What does this PR do? <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one has reviewed your PR after a week, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a GitHub issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil -->
244 lines
6.8 KiB
Python
244 lines
6.8 KiB
Python
import os
|
|
import torch
|
|
|
|
from loguru import logger
|
|
|
|
from text_generation_server.utils.import_utils import SYSTEM
|
|
|
|
# Opt-out switch: when USE_FLASH_ATTENTION is "false" (case-insensitive),
# importing this module fails with ImportError, exactly as if the flash
# attention dependency were missing.
if os.environ.get("USE_FLASH_ATTENTION", "").lower() == "false":
    raise ImportError("`USE_FLASH_ATTENTION` is false.")

# Default capability flags. The CUDA/ROCm probing further down the module
# resets and recomputes all three on those platforms.
HAS_FLASH_ATTN = True
HAS_FLASH_ATTN_V2_CUDA = HAS_FLASH_ATTN_V2_ROCM = False
|
|
|
|
if SYSTEM == "xpu":
    import intel_extension_for_pytorch as ipex

    def attention(
        q,
        k,
        v,
        out,
        cu_seqlens,
        max_s,
        softmax_scale,
        window_size_left=-1,
    ):
        """Run variable-length attention through Intel Extension for PyTorch.

        Args:
            q, k, v: query/key/value tensors, forwarded unchanged to IPEX.
            out: output tensor, forwarded unchanged to IPEX.
            cu_seqlens: sequence offsets (presumably cumulative lengths),
                used for both the query and the key side.
            max_s: maximum sequence length, used for both query and key.
            softmax_scale: softmax scaling factor forwarded to IPEX.
            window_size_left: must be -1; sliding-window attention is not
                supported on XPU.

        Returns:
            Whatever ``ipex.llm.functional.varlen_attention`` returns.

        Raises:
            ValueError: if ``window_size_left`` is anything other than -1.
        """
        if window_size_left != -1:
            # Report an invalid value differently from a valid-but-unsupported one.
            if window_size_left <= 0:
                raise ValueError("`window_size_left` must be > 0 or -1")
            raise ValueError(
                f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
            )

        # Query and key sides share the same offsets and maximum length.
        seqlens_q = seqlens_k = cu_seqlens
        max_q = max_k = max_s
        # NOTE(review): the trailing positionals are presumably dropout=0.0,
        # softmax_scale, zero_tensors=False, is_causal=True,
        # return_softmax=False, generator=None -- confirm against IPEX docs.
        return ipex.llm.functional.varlen_attention(
            q,
            k,
            v,
            out,
            seqlens_q,
            seqlens_k,
            max_q,
            max_k,
            0.0,
            softmax_scale,
            False,
            True,
            False,
            None,
        )
|
|
|
|
|
|
# On CUDA/ROCm, probe which flash-attention extension is usable and set the
# capability flags accordingly. Every unrecoverable condition raises
# ImportError, which makes importing this module fail.
if SYSTEM in {"cuda", "rocm"}:
    if not torch.cuda.is_available():
        raise ImportError("CUDA is not available")

    # Capability of the current device as reported by torch.
    # NOTE(review): on ROCm this is torch's mapping of the AMD arch -- confirm.
    major, minor = torch.cuda.get_device_capability()
    is_sm75 = major == 7 and minor == 5
    is_sm8x = major == 8 and minor >= 0  # any 8.x device
    is_sm90 = major == 9 and minor == 0

    # Reset the module-level defaults; only the probe below may set them again.
    HAS_FLASH_ATTN = False
    HAS_FLASH_ATTN_V2_CUDA = False
    HAS_FLASH_ATTN_V2_ROCM = False
    try:
        # First choice: Flash Attention V2.
        try:
            import flash_attn_2_cuda
        except ImportError:
            architecture_suffix = f"-{SYSTEM}"
            raise ImportError(
                "Flash Attention V2 is not installed.\n"
                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
                f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
            )
        # V2 is only enabled on 8.x and 9.0 capability devices.
        if not (is_sm8x or is_sm90):
            raise ImportError(
                f"GPU with CUDA capability {major} {minor} is not supported for "
                "Flash Attention V2"
            )
        HAS_FLASH_ATTN_V2_CUDA = SYSTEM == "cuda"
        HAS_FLASH_ATTN_V2_ROCM = SYSTEM == "rocm"
    except ImportError as e:
        # V2 unusable: fall back to Flash Attention V1. The V2 failure ``e`` is
        # chained as the cause and logged below.
        try:
            import flash_attn_cuda
        except ImportError:
            raise ImportError(
                "Flash Attention is not installed.\n"
                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
                "or install flash attention with `cd server && make install install-flash-attention`"
            ) from e

        # V1 on CUDA requires capability 7.5, 8.x or 9.0.
        if SYSTEM == "cuda" and not (is_sm75 or is_sm8x or is_sm90):
            raise ImportError(
                f"GPU with CUDA capability {major} {minor} is not supported"
            ) from e
        # V1 on ROCm: every visible device's name must contain "MI210" or "MI250".
        elif SYSTEM == "rocm":
            for idx in range(torch.cuda.device_count()):
                if "MI210" not in torch.cuda.get_device_name(
                    idx
                ) and "MI250" not in torch.cuda.get_device_name(idx):
                    raise ImportError(
                        f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
                    )

        logger.warning(f"Unable to use Flash Attention V2: {e}")
        HAS_FLASH_ATTN = True
|
|
|
|
|
|
# Select the ``attention`` implementation for this platform.
#
# On XPU the IPEX-based ``attention`` is defined earlier in this module under
# ``if SYSTEM == "xpu"``. There ``HAS_FLASH_ATTN`` keeps its default of True
# while both V2 flags stay False, so without this first branch the V1 fallback
# below would replace the XPU implementation with one that calls
# ``flash_attn_cuda`` -- a module that is never imported on XPU.
if SYSTEM == "xpu":
    pass

elif HAS_FLASH_ATTN_V2_CUDA:

    def attention(
        q,
        k,
        v,
        out,
        cu_seqlens,
        max_s,
        softmax_scale,
        window_size_left=-1,
    ):
        """Variable-length attention via the Flash Attention V2 CUDA kernel.

        ``cu_seqlens`` and ``max_s`` are used for both the query and the key
        side. ``window_size_left`` must be a positive window size, or -1 for
        no window.

        Raises:
            ValueError: if ``window_size_left`` is neither -1 nor positive.
        """
        if window_size_left <= 0 and window_size_left != -1:
            raise ValueError("`window_size_left` must be > 0 or -1")
        # NOTE(review): the positional layout presumably follows
        # flash_attn_2_cuda.varlen_fwd (three optional None tensors, dropout=0.0,
        # softmax_scale, zero_tensors=False, is_causal=True, window left/right,
        # return_softmax=False, generator=None) -- confirm against the
        # installed flash-attention version.
        return flash_attn_2_cuda.varlen_fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            None,
            None,
            None,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            True,
            window_size_left,
            0,
            False,
            None,
        )

elif HAS_FLASH_ATTN_V2_ROCM:

    def attention(
        q,
        k,
        v,
        out,
        cu_seqlens,
        max_s,
        softmax_scale,
        window_size_left=-1,
    ):
        """Variable-length attention via the ROCm Flash Attention V2 kernel.

        Sliding-window attention is not supported here, so ``window_size_left``
        must be -1.

        Raises:
            ValueError: if ``window_size_left`` is not -1.
        """
        if window_size_left <= 0 and window_size_left != -1:
            raise ValueError("`window_size_left` must be > 0 or -1")
        if window_size_left != -1:
            raise ValueError(
                f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
            )

        # RoCm flash API does not take the window_size_left and window_size_right arguments.
        return flash_attn_2_cuda.varlen_fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            True,
            False,
            None,
        )

elif HAS_FLASH_ATTN:

    def _expand_kv_heads(x, num_heads):
        """Return ``x`` with its head dimension (dim 1) grown to ``num_heads``.

        Flash Attention V1 requires q, k and v to have the same number of
        heads. A single KV head (MQA) is broadcast with ``expand``. Otherwise
        (grouped attention) each KV head is repeated
        ``num_heads // x.shape[1]`` times.
        """
        if x.shape[1] == num_heads:
            return x
        # MQA expand
        if x.shape[1] == 1:
            return x.expand(-1, num_heads, -1)
        # Grouped attention reshape: add a repeat axis after the head axis,
        # expand it, then fold it back into the head axis.
        original_shape = x.shape
        return (
            x.unsqueeze(2)
            .expand(-1, -1, num_heads // x.shape[1], -1)
            .reshape(original_shape[0], -1, original_shape[2])
        )

    def attention(
        q,
        k,
        v,
        out,
        cu_seqlens,
        max_s,
        softmax_scale,
        window_size_left=-1,
    ):
        """Variable-length attention via the Flash Attention V1 kernel.

        k and v are expanded to q's head count before the call, because V1
        does not handle differing head counts itself.

        Raises:
            NotImplementedError: if a sliding window is requested
                (``window_size_left != -1``).
        """
        if window_size_left != -1:
            raise NotImplementedError(
                "window_size_left is only available with flash attn v2"
            )

        num_heads = q.shape[1]
        k = _expand_kv_heads(k, num_heads)
        v = _expand_kv_heads(v, num_heads)

        return flash_attn_cuda.fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            True,
            False,
            0,
            None,
        )

else:
    raise NotImplementedError("flash attention is not installed")
|