OlivierDehaene 2023-07-18 12:36:27 +02:00
parent bc2f351980
commit d186b13c59
4 changed files with 38 additions and 29 deletions

View File

@@ -1,4 +1,4 @@
-flash_att_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc
+flash_att_v2_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc
 
 flash-attention-v2:
 	# Clone flash attention
@@ -6,7 +6,7 @@ flash-attention-v2:
 	git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2
 
 build-flash-attention-v2: flash-attention-v2
-	cd flash-attention-v2 && git fetch && git checkout $(flash_att_commit)
+	cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit)
 	cd flash-attention-v2 && python setup.py build
 
 install-flash-attention-v2: build-flash-attention-v2

View File

@@ -55,10 +55,8 @@ try:
         FlashSantacoderSharded,
     )
-except ImportError:
-    logger.opt(exception=True).warning(
-        "Could not import Flash Attention enabled models"
-    )
+except ImportError as e:
+    logger.warning(f"Could not import Flash Attention enabled models: {e}")
     FLASH_ATTENTION = False
 
 if FLASH_ATTENTION:
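
The net effect of this hunk is that a missing flash-attention install now produces a single warning line carrying the exception message instead of a full traceback via logger.opt(exception=True). A minimal sketch of the resulting guard, with the import path assumed rather than shown in this hunk:

from loguru import logger

FLASH_ATTENTION = True

try:
    # Stand-in for the guarded imports (FlashSantacoderSharded and the other
    # flash-enabled models); the exact module path is assumed, not shown above.
    from text_generation_server.models.flash_santacoder import FlashSantacoderSharded
except ImportError as e:
    # After this commit: one warning line that embeds the exception message,
    # replacing logger.opt(exception=True).warning(...) and its traceback.
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False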

View File

@@ -188,7 +188,7 @@ class FlashRWAttention(torch.nn.Module):
                 attn_output,
                 cu_seqlen_prefill,
                 max_s,
-                self.softmax_scale
+                self.softmax_scale,
             )
         # Decode
         else:
@@ -308,7 +308,7 @@ class FlashRWLargeAttention(torch.nn.Module):
                 attn_output,
                 cu_seqlen_prefill,
                 max_s,
-                self.softmax_scale
+                self.softmax_scale,
             )
         # Decode
         else:

View File

@@ -1,6 +1,8 @@
 import os
 import torch
 
+from loguru import logger
+
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
@@ -18,10 +20,11 @@ try:
     try:
         import flash_attn_2_cuda
     except ImportError:
-        raise ImportError("Flash Attention V2 is not installed.\n"
-                          "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-                          "or install flash attention v2 with `cd server && make install install-flash-attention-v2`"
-        )
+        raise ImportError(
+            "Flash Attention V2 is not installed.\n"
+            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+            "or install flash attention v2 with `cd server && make install install-flash-attention-v2`"
+        )
     if not (is_sm8x or is_sm90):
         raise ImportError(
             f"GPU with CUDA capability {major} {minor} is not supported for "
@@ -32,26 +35,28 @@ except ImportError as e:
     try:
         import flash_attn_cuda
     except ImportError:
-        raise ImportError("Flash Attention is not installed.\n"
-                          "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-                          "or install flash attention with `cd server && make install install-flash-attention`"
-        ) from e
+        raise ImportError(
+            "Flash Attention is not installed.\n"
+            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+            "or install flash attention with `cd server && make install install-flash-attention`"
+        ) from e
     if not (is_sm75 or is_sm8x or is_sm90):
         raise ImportError(
             f"GPU with CUDA capability {major} {minor} is not supported"
         ) from e
+    logger.warning(f"Unable to use Flash Attention V2: {e}")
     HAS_FLASH_ATTN = True
 
 def attention(
     q,
     k,
     v,
     out,
     cu_seqlens,
     max_s,
     softmax_scale,
 ):
     if HAS_FLASH_ATTN_V2:
         return flash_attn_2_cuda.varlen_fwd(
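
Taken together, the hunks above give this module a graceful fallback: honor the USE_FLASH_ATTENTION kill switch, try the V2 extension first, and if that fails log the new "Unable to use Flash Attention V2" warning and fall back to V1. A condensed sketch of that flow, assuming the is_sm* flags are derived from torch.cuda.get_device_capability() (that derivation is not part of this diff):

import os

import torch
from loguru import logger

HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2 = False

if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
    raise ImportError("`USE_FLASH_ATTENTION` is false.")

# Assumed derivation of the capability flags used in the error messages above.
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
is_sm8x = major == 8
is_sm90 = major == 9 and minor == 0

try:
    # Prefer Flash Attention V2; it only supports SM 8.x and SM 9.0 GPUs.
    import flash_attn_2_cuda  # noqa: F401

    if not (is_sm8x or is_sm90):
        raise ImportError(
            f"GPU with CUDA capability {major} {minor} is not supported for Flash Attention V2"
        )
    HAS_FLASH_ATTN_V2 = True
except ImportError as e:
    # New in this commit: say why V2 was skipped before falling back to V1.
    logger.warning(f"Unable to use Flash Attention V2: {e}")
    import flash_attn_cuda  # noqa: F401

    if not (is_sm75 or is_sm8x or is_sm90):
        raise ImportError(
            f"GPU with CUDA capability {major} {minor} is not supported"
        ) from e
    HAS_FLASH_ATTN = True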
@@ -76,21 +81,27 @@ def attention(
     if k.shape[1] != q.shape[1]:
         # MQA expand
         if k.shape[1] == 1:
-            k = k.expand(-1, q.shape[1], -1, -1)
+            k = k.expand(-1, q.shape[1], -1)
         # Grouped attention reshape
         else:
             original_shape = k.shape
-            k = k.unsqueeze(2).expand(-1, -1, q.shape[1], -1, -1) \
-                .reshape(original_shape[0], -1, original_shape[1], original_shape[2])
+            k = (
+                k.unsqueeze(2)
+                .expand(-1, -1, q.shape[1] // k.shape[1], -1)
+                .reshape(original_shape[0], -1, original_shape[2])
+            )
     if v.shape[1] != q.shape[1]:
         # MQA expand
         if v.shape[1] == 1:
-            v = v.expand(-1, q.shape[1], -1, -1)
+            v = v.expand(-1, q.shape[1], -1)
         # Grouped attention reshape
         else:
             original_shape = v.shape
-            v = v.unsqueeze(2).expand(-1, -1, q.shape[1], -1, -1) \
-                .reshape(original_shape[0], -1, original_shape[1], original_shape[2])
+            v = (
+                v.unsqueeze(2)
+                .expand(-1, -1, q.shape[1] // v.shape[1], -1)
+                .reshape(original_shape[0], -1, original_shape[2])
+            )
     return flash_attn_cuda.fwd(
         q,
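
The final hunk is the substantive fix: in this kernel path q, k and v are 3-D tensors laid out as [total_tokens, num_heads, head_dim], so the removed expand/reshape pattern (four -1s and a 4-D reshape back to the KV-head count) targeted the wrong rank. A small shape walk-through of the new grouped-attention branch, using made-up head counts:

import torch

# Hypothetical sizes: 5 packed tokens, 8 query heads, 2 KV heads, head_dim 64.
total_tokens, num_q_heads, num_kv_heads, head_dim = 5, 8, 2, 64
q = torch.randn(total_tokens, num_q_heads, head_dim)
k = torch.randn(total_tokens, num_kv_heads, head_dim)

if k.shape[1] != q.shape[1]:
    if k.shape[1] == 1:
        # MQA: broadcast the single KV head across every query head.
        k = k.expand(-1, q.shape[1], -1)
    else:
        # Grouped attention: repeat each KV head num_q_heads // num_kv_heads times.
        original_shape = k.shape
        k = (
            k.unsqueeze(2)                                      # [T, H_kv, 1, D]
            .expand(-1, -1, q.shape[1] // k.shape[1], -1)       # [T, H_kv, G, D]
            .reshape(original_shape[0], -1, original_shape[2])  # [T, H_kv * G, D]
        )

assert k.shape == q.shape  # (5, 8, 64): KV heads now line up with the query heads

The same transformation is applied to v; the result stays 3-D, whereas the old code expanded with a 4-D pattern and reshaped back to the original head count, which is what this commit corrects.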