mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00

fix

This commit is contained in:
    parent bc2f351980
    commit d186b13c59
server/Makefile-flash-att-v2

@@ -1,4 +1,4 @@
-flash_att_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc
+flash_att_v2_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc
 
 flash-attention-v2:
 	# Clone flash attention
@@ -6,7 +6,7 @@ flash-attention-v2:
 	git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2
 
 build-flash-attention-v2: flash-attention-v2
-	cd flash-attention-v2 && git fetch && git checkout $(flash_att_commit)
+	cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit)
 	cd flash-attention-v2 && python setup.py build
 
 install-flash-attention-v2: build-flash-attention-v2
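The rename matters because `flash_att_commit` is the pin used elsewhere for the v1 flash-attention build; the v2 checkout now follows its own `flash_att_v2_commit` instead of silently reusing the v1 hash. Per the install hint that appears later in this commit, the target chain is driven with `cd server && make install install-flash-attention-v2`.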
server/text_generation_server/models/__init__.py

@@ -55,10 +55,8 @@ try:
         FlashSantacoderSharded,
     )
 
-except ImportError:
-    logger.opt(exception=True).warning(
-        "Could not import Flash Attention enabled models"
-    )
+except ImportError as e:
+    logger.warning(f"Could not import Flash Attention enabled models: {e}")
     FLASH_ATTENTION = False
 
 if FLASH_ATTENTION:
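For context, a minimal self-contained sketch of the guard pattern this hunk settles on: a plain `logger.warning` carrying the exception text, rather than `logger.opt(exception=True)`, which in loguru attaches the full traceback to the log record. The module name below is a placeholder, not part of the repository.

from loguru import logger

FLASH_ATTENTION = True
try:
    import some_optional_ext  # placeholder for the flash-attention model imports
except ImportError as e:
    # One-line warning with the reason; logger.opt(exception=True).warning(...)
    # would have logged the entire stack trace instead.
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False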
server/text_generation_server/models/custom_modeling/flash_rw_modeling.py

@@ -188,7 +188,7 @@ class FlashRWAttention(torch.nn.Module):
                attn_output,
                cu_seqlen_prefill,
                max_s,
-                self.softmax_scale
+                self.softmax_scale,
            )
        # Decode
        else:
@@ -308,7 +308,7 @@ class FlashRWLargeAttention(torch.nn.Module):
                attn_output,
                cu_seqlen_prefill,
                max_s,
-                self.softmax_scale
+                self.softmax_scale,
            )
        # Decode
        else:
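The only change in both hunks is the trailing comma after the final argument. In Python this is cosmetic — a call spelled across several lines parses identically with or without it — but it keeps any future argument addition to a one-line diff. A tiny illustration:

def scaled(x, scale):
    return x * scale

# Identical calls; the trailing comma in the second only affects future diffs.
a = scaled(
    2.0,
    0.5
)
b = scaled(
    2.0,
    0.5,
)
assert a == b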
server/text_generation_server/utils/flash_attn.py

@@ -1,6 +1,8 @@
+import os
 import torch
 
 from loguru import logger
 
+if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
+    raise ImportError("`USE_FLASH_ATTENTION` is false.")
 
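This guard gives operators a kill switch: setting `USE_FLASH_ATTENTION=false` makes the module raise at import time, which the caller in models/__init__.py above already treats as "flash attention unavailable". A sketch of the behavior — the import path is assumed from the repository layout, not shown in this diff:

import os

os.environ["USE_FLASH_ATTENTION"] = "false"

# Importing the module now fails fast (path assumed for illustration):
try:
    from text_generation_server.utils import flash_attn  # noqa: F401
except ImportError as e:
    print(e)  # -> `USE_FLASH_ATTENTION` is false.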
@@ -18,7 +20,8 @@ try:
     try:
         import flash_attn_2_cuda
     except ImportError:
-        raise ImportError("Flash Attention V2 is not installed.\n"
+        raise ImportError(
+            "Flash Attention V2 is not installed.\n"
             "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
             "or install flash attention v2 with `cd server && make install install-flash-attention-v2`"
         )
@@ -32,7 +35,8 @@ except ImportError as e:
     try:
         import flash_attn_cuda
     except ImportError:
-        raise ImportError("Flash Attention is not installed.\n"
+        raise ImportError(
+            "Flash Attention is not installed.\n"
             "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
             "or install flash attention with `cd server && make install install-flash-attention`"
         ) from e
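Both raise-reformat hunks rely on implicit concatenation of adjacent string literals, so moving the opening literal onto its own line leaves the message byte-for-byte identical:

msg = (
    "Flash Attention V2 is not installed.\n"
    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
    "or install flash attention v2 with `cd server && make install install-flash-attention-v2`"
)
assert msg.count("\n") == 1  # the adjacent literals fuse into one string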
@@ -41,6 +45,7 @@ except ImportError as e:
         raise ImportError(
             f"GPU with CUDA capability {major} {minor} is not supported"
         ) from e
+    logger.warning(f"Unable to use Flash Attention V2: {e}")
     HAS_FLASH_ATTN = True
 
 
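Taken with the earlier hunks, the control flow this lands on is: probe the V2 kernel first, and on failure warn and fall back to V1 rather than raising. A rough, partially reconstructed sketch — the surrounding code is not fully visible in this diff, so treat it as illustrative only:

from loguru import logger

HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2 = False
try:
    import flash_attn_2_cuda  # V2 kernel
    HAS_FLASH_ATTN_V2 = True
except ImportError as e:
    # New in this commit: surface why V2 is unavailable before falling back.
    logger.warning(f"Unable to use Flash Attention V2: {e}")
    HAS_FLASH_ATTN = True  # V1 path (its own import checks elided in this sketch)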
@@ -76,21 +81,27 @@ def attention(
     if k.shape[1] != q.shape[1]:
         # MQA expand
         if k.shape[1] == 1:
-            k = k.expand(-1, q.shape[1], -1, -1)
+            k = k.expand(-1, q.shape[1], -1)
         # Grouped attention reshape
         else:
             original_shape = k.shape
-            k = k.unsqueeze(2).expand(-1, -1, q.shape[1], -1, -1) \
-                .reshape(original_shape[0], -1, original_shape[1], original_shape[2])
+            k = (
+                k.unsqueeze(2)
+                .expand(-1, -1, q.shape[1] // k.shape[1], -1)
+                .reshape(original_shape[0], -1, original_shape[2])
+            )
     if v.shape[1] != q.shape[1]:
         # MQA expand
         if v.shape[1] == 1:
-            v = v.expand(-1, q.shape[1], -1, -1)
+            v = v.expand(-1, q.shape[1], -1)
         # Grouped attention reshape
         else:
             original_shape = v.shape
-            v = v.unsqueeze(2).expand(-1, -1, q.shape[1], -1, -1) \
-                .reshape(original_shape[0], -1, original_shape[1], original_shape[2])
+            v = (
+                v.unsqueeze(2)
+                .expand(-1, -1, q.shape[1] // v.shape[1], -1)
+                .reshape(original_shape[0], -1, original_shape[2])
+            )
 
     return flash_attn_cuda.fwd(
         q,
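The rewritten expand/reshape is the substantive fix: `k` and `v` here are laid out as `(total_tokens, num_heads, head_dim)`, so the old four- and five-argument `expand` calls targeted the wrong tensor rank. The new code repeats each key/value head `q.shape[1] // k.shape[1]` times, which covers grouped-query attention and reduces to a plain broadcast in the multi-query case (a single kv head of size 1). A standalone shape check of the grouped path, under the layout assumed above:

import torch

total_tokens, num_heads, num_kv_heads, head_dim = 5, 8, 2, 64
q = torch.randn(total_tokens, num_heads, head_dim)
k = torch.randn(total_tokens, num_kv_heads, head_dim)

group = q.shape[1] // k.shape[1]  # 4 query heads per kv head
original_shape = k.shape
expanded = (
    k.unsqueeze(2)                                       # (tokens, kv_heads, 1, dim)
    .expand(-1, -1, group, -1)                           # (tokens, kv_heads, group, dim)
    .reshape(original_shape[0], -1, original_shape[2])   # (tokens, heads, dim)
)
assert expanded.shape == q.shape
# kv head i is duplicated into query heads [i*group, (i+1)*group)
assert torch.equal(expanded[:, 0], k[:, 0])
assert torch.equal(expanded[:, group - 1], k[:, 0])
assert torch.equal(expanded[:, group], k[:, 1])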