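"""Flash Attention kernel selection for text-generation-server.

Importing this module probes the runtime (CUDA/ROCm system, GPU compute
capability, installed flash-attn packages) and exposes an `attention()`
wrapper around whichever kernel is usable.
"""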
import os
import torch
from loguru import logger
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
    raise ImportError("`USE_FLASH_ATTENTION` is false.")
if not torch.cuda.is_available():
    raise ImportError("CUDA is not available")

major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0
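# (sm75 = Turing, sm8x = Ampere/Ada, sm90 = Hopper GPUs.)
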
HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2_CUDA = False
HAS_FLASH_ATTN_V2_ROCM = False
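
# Probe for Flash Attention v2 first; if it is missing or the GPU is not
# supported, fall back to Flash Attention v1 below.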
try:
    try:
        import flash_attn_2_cuda
    except ImportError:
        raise ImportError(
            "Flash Attention V2 is not installed.\n"
            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
            "or install flash attention v2 with `cd server && make install install-flash-attention-v2`"
        )
    if not (is_sm8x or is_sm90):
        raise ImportError(
            f"GPU with CUDA capability {major} {minor} is not supported for "
            "Flash Attention V2"
        )
    HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
    HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
except ImportError as e:
    try:
        import flash_attn_cuda
    except ImportError:
        raise ImportError(
            "Flash Attention is not installed.\n"
            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
            "or install flash attention with `cd server && make install install-flash-attention`"
        ) from e

    if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90):
        raise ImportError(
            f"GPU with CUDA capability {major} {minor} is not supported"
        ) from e
    elif IS_ROCM_SYSTEM:
        for idx in range(torch.cuda.device_count()):
            if "MI210" not in torch.cuda.get_device_name(
                idx
            ) and "MI250" not in torch.cuda.get_device_name(idx):
                raise ImportError(
                    f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
                )

    logger.warning(f"Unable to use Flash Attention V2: {e}")
    HAS_FLASH_ATTN = True


def attention(
    q,
    k,
    v,
    out,
    cu_seqlens,
    max_s,
    softmax_scale,
    window_size_left=-1,
):
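    """Compute varlen flash attention over the packed q/k/v tensors into `out`.

    Dispatches to Flash Attention v2 (CUDA or ROCm build) when available and
    falls back to Flash Attention v1 otherwise. Inputs are expected in the
    flash-attn varlen layout: q/k/v of shape (total_tokens, num_heads,
    head_size), `cu_seqlens` holding cumulative sequence lengths and `max_s`
    the longest sequence in the batch.
    """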
    if HAS_FLASH_ATTN_V2_CUDA:
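        # Positional args are passed straight through to the
        # flash_attn_2_cuda.varlen_fwd binding: cu_seqlens/max_s are given for
        # both query and key, followed by dropout 0.0, the softmax scale, and
        # the causal / sliding-window settings (window_size_left,
        # window_size_right=0).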
        return flash_attn_2_cuda.varlen_fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            True,
            window_size_left,
            0,
            False,
            None,
        )
    elif HAS_FLASH_ATTN_V2_ROCM:
        if window_size_left != -1:
            raise ValueError(
                f"RoCm version of Flash Attention v2 does not support window attention "
                f"(window_size_left != -1, got window_size_left={window_size_left})."
            )

        # RoCm flash API does not take the window_size_left and window_size_right arguments.
        return flash_attn_2_cuda.varlen_fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            True,
            False,
            None,
        )
    elif HAS_FLASH_ATTN:
        if window_size_left != -1:
            raise NotImplementedError(
                "window_size_left is only available with flash attn v2"
            )
        # Flash attention v1 requires q, k and v to have the same number of heads
        if k.shape[1] != q.shape[1]:
            # MQA expand
            if k.shape[1] == 1:
                k = k.expand(-1, q.shape[1], -1)
            # Grouped attention reshape
            else:
                original_shape = k.shape
                k = (
                    k.unsqueeze(2)
                    .expand(-1, -1, q.shape[1] // k.shape[1], -1)
                    .reshape(original_shape[0], -1, original_shape[2])
                )
        if v.shape[1] != q.shape[1]:
            # MQA expand
            if v.shape[1] == 1:
                v = v.expand(-1, q.shape[1], -1)
            # Grouped attention reshape
            else:
                original_shape = v.shape
                v = (
                    v.unsqueeze(2)
                    .expand(-1, -1, q.shape[1] // v.shape[1], -1)
                    .reshape(original_shape[0], -1, original_shape[2])
                )
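
        # flash_attn_cuda.fwd is the v1 kernel; it has no sliding-window
        # arguments, which is why window_size_left was rejected above.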
        return flash_attn_cuda.fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            True,
            False,
            0,
            None,
        )

    raise NotImplementedError("flash attention is not installed")