Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)

Commit d186b13c59 ("fix"), parent bc2f351980.
@@ -1,4 +1,4 @@
-flash_att_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc
+flash_att_v2_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc
 
 flash-attention-v2:
 	# Clone flash attention
@@ -6,7 +6,7 @@ flash-attention-v2:
 	git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2
 
 build-flash-attention-v2: flash-attention-v2
-	cd flash-attention-v2 && git fetch && git checkout $(flash_att_commit)
+	cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit)
 	cd flash-attention-v2 && python setup.py build
 
 install-flash-attention-v2: build-flash-attention-v2
@@ -55,10 +55,8 @@ try:
         FlashSantacoderSharded,
     )
 
-except ImportError:
-    logger.opt(exception=True).warning(
-        "Could not import Flash Attention enabled models"
-    )
+except ImportError as e:
+    logger.warning(f"Could not import Flash Attention enabled models: {e}")
     FLASH_ATTENTION = False
 
 if FLASH_ATTENTION:
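For context, the guard this hunk edits is the usual optional-dependency pattern. A minimal sketch, assuming loguru is available as in the diff; the module path in the example import is assumed from the surrounding repository layout and is only illustrative:

```python
# Minimal sketch of the optional-import guard (module path is an assumption).
from loguru import logger

FLASH_ATTENTION = True
try:
    from text_generation_server.models.flash_santacoder import FlashSantacoderSharded
except ImportError as e:
    # Binding the exception as `e` lets the warning say *why* the import failed
    # (missing CUDA kernels, unsupported GPU, USE_FLASH_ATTENTION=false, ...)
    # instead of only reporting that it failed.
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False

if FLASH_ATTENTION:
    # Flash-attention model classes are only registered when the import worked.
    pass
```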
@@ -188,7 +188,7 @@ class FlashRWAttention(torch.nn.Module):
                 attn_output,
                 cu_seqlen_prefill,
                 max_s,
-                self.softmax_scale
+                self.softmax_scale,
             )
         # Decode
         else:
@@ -308,7 +308,7 @@ class FlashRWLargeAttention(torch.nn.Module):
                 attn_output,
                 cu_seqlen_prefill,
                 max_s,
-                self.softmax_scale
+                self.softmax_scale,
             )
         # Decode
         else:
@@ -1,6 +1,8 @@
 import os
 import torch
 
+from loguru import logger
+
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
 
@@ -18,10 +20,11 @@ try:
     try:
         import flash_attn_2_cuda
     except ImportError:
-        raise ImportError("Flash Attention V2 is not installed.\n"
-                          "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-                          "or install flash attention v2 with `cd server && make install install-flash-attention-v2`"
-                          )
+        raise ImportError(
+            "Flash Attention V2 is not installed.\n"
+            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+            "or install flash attention v2 with `cd server && make install install-flash-attention-v2`"
+        )
     if not (is_sm8x or is_sm90):
         raise ImportError(
             f"GPU with CUDA capability {major} {minor} is not supported for "
@@ -32,26 +35,28 @@ except ImportError as e:
     try:
         import flash_attn_cuda
     except ImportError:
-        raise ImportError("Flash Attention is not installed.\n"
-                          "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-                          "or install flash attention with `cd server && make install install-flash-attention`"
-                          ) from e
+        raise ImportError(
+            "Flash Attention is not installed.\n"
+            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+            "or install flash attention with `cd server && make install install-flash-attention`"
+        ) from e
 
     if not (is_sm75 or is_sm8x or is_sm90):
         raise ImportError(
             f"GPU with CUDA capability {major} {minor} is not supported"
         ) from e
+    logger.warning(f"Unable to use Flash Attention V2: {e}")
     HAS_FLASH_ATTN = True
 
 
 def attention(
     q,
     k,
     v,
     out,
     cu_seqlens,
     max_s,
     softmax_scale,
 ):
     if HAS_FLASH_ATTN_V2:
         return flash_attn_2_cuda.varlen_fwd(
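Taken together with the previous hunk, the kernel-detection logic boils down to the sketch below (error messages abbreviated, GPU capability checks omitted; `HAS_FLASH_ATTN` / `HAS_FLASH_ATTN_V2` are the flags visible in the diff):

```python
# Sketch of the flash-attention detection cascade: try the v2 CUDA kernels,
# fall back to v1, and (new in this commit) log why v2 was rejected.
from loguru import logger

HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2 = False

try:
    try:
        import flash_attn_2_cuda  # CUDA extension shipped by flash-attn v2
    except ImportError:
        raise ImportError("Flash Attention V2 is not installed.")
    HAS_FLASH_ATTN_V2 = True
except ImportError as e:
    try:
        import flash_attn_cuda  # fall back to the v1 kernels
    except ImportError:
        raise ImportError("Flash Attention is not installed.") from e
    # Say why V2 was unavailable instead of silently dropping to V1.
    logger.warning(f"Unable to use Flash Attention V2: {e}")
    HAS_FLASH_ATTN = True
```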
@@ -76,21 +81,27 @@ def attention(
         if k.shape[1] != q.shape[1]:
             # MQA expand
             if k.shape[1] == 1:
-                k = k.expand(-1, q.shape[1], -1, -1)
+                k = k.expand(-1, q.shape[1], -1)
             # Grouped attention reshape
             else:
                 original_shape = k.shape
-                k = k.unsqueeze(2).expand(-1, -1, q.shape[1], -1, -1) \
-                    .reshape(original_shape[0], -1, original_shape[1], original_shape[2])
+                k = (
+                    k.unsqueeze(2)
+                    .expand(-1, -1, q.shape[1] // k.shape[1], -1)
+                    .reshape(original_shape[0], -1, original_shape[2])
+                )
         if v.shape[1] != q.shape[1]:
             # MQA expand
             if v.shape[1] == 1:
-                v = v.expand(-1, q.shape[1], -1, -1)
+                v = v.expand(-1, q.shape[1], -1)
             # Grouped attention reshape
             else:
                 original_shape = v.shape
-                v = v.unsqueeze(2).expand(-1, -1, q.shape[1], -1, -1) \
-                    .reshape(original_shape[0], -1, original_shape[1], original_shape[2])
+                v = (
+                    v.unsqueeze(2)
+                    .expand(-1, -1, q.shape[1] // v.shape[1], -1)
+                    .reshape(original_shape[0], -1, original_shape[2])
+                )
 
         return flash_attn_cuda.fwd(
             q,
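The grouped-attention branch is easiest to follow with concrete shapes. A standalone sketch of the expansion the fixed code performs on the 3-D (total_tokens, num_heads, head_dim) layout; the sizes are illustrative:

```python
# Sketch: expand grouped KV heads so q, k, v all have num_heads heads, as the
# v1 flash-attention kernel expects. Tensor sizes are made up for illustration.
import torch

total_tokens, num_heads, num_kv_heads, head_dim = 5, 8, 2, 64
q = torch.randn(total_tokens, num_heads, head_dim)
k = torch.randn(total_tokens, num_kv_heads, head_dim)

if k.shape[1] != q.shape[1]:
    if k.shape[1] == 1:
        # MQA: a single shared KV head is broadcast across all query heads.
        k = k.expand(-1, q.shape[1], -1)
    else:
        # GQA: repeat each KV head (num_heads // num_kv_heads) times, so query
        # head i reads KV head i // group_size.
        original_shape = k.shape
        k = (
            k.unsqueeze(2)                                       # (T, kv, 1, D)
            .expand(-1, -1, q.shape[1] // k.shape[1], -1)        # (T, kv, group, D)
            .reshape(original_shape[0], -1, original_shape[2])   # (T, heads, D)
        )

assert k.shape == (total_tokens, num_heads, head_dim)
```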