feat(server): flash attention v2

OlivierDehaene 2023-07-17 17:34:55 +02:00
parent a2cf1bdb2f
commit 107fcfe9b6
6 changed files with 12 additions and 13 deletions

View File

@@ -1,4 +1,4 @@
-flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec
+flash_att_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc

 flash-attention:
 	# Clone flash attention

View File

@@ -57,14 +57,13 @@ try:
         raise ImportError("CUDA is not available")

     major, minor = torch.cuda.get_device_capability()
-    is_sm75 = major == 7 and minor == 5
     is_sm8x = major == 8 and minor >= 0
     is_sm90 = major == 9 and minor == 0

-    supported = is_sm75 or is_sm8x or is_sm90
+    supported = is_sm8x or is_sm90
     if not supported:
         FLASH_ATT_ERROR_MESSAGE = (
-            "{} requires a CUDA device with capability 7.5, > 8.0 or 9.0. "
+            "{} requires a CUDA device with capability > 8.0 or 9.0. "
            "No compatible CUDA device found."
         )
         raise ImportError(
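
For context, after this change the gate only admits devices with CUDA compute capability 8.x (Ampere/Ada) or 9.0 (Hopper); flash attention v2 does not ship kernels for capability 7.5 (Turing), which v1 still supported. A minimal standalone sketch of the check, assuming it runs at import time as in the hunk above:

    import torch

    if not torch.cuda.is_available():
        raise ImportError("CUDA is not available")

    # get_device_capability() returns (major, minor), e.g. (8, 0) for an A100.
    major, minor = torch.cuda.get_device_capability()
    is_sm8x = major == 8 and minor >= 0
    is_sm90 = major == 9 and minor == 0

    # Flash attention v2 kernels are built for sm_8x and sm_90 only, so the
    # former sm_75 branch is gone.
    if not (is_sm8x or is_sm90):
        raise ImportError(
            f"Flash attention v2 requires a CUDA device with capability "
            f"> 8.0 or 9.0, got {major}.{minor}"
        )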

View File

@@ -26,7 +26,7 @@ from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple

 # Flash attention imports
-import flash_attn_cuda
+import flash_attn_2_cuda
 import dropout_layer_norm

 # vllm imports
@@ -164,7 +164,7 @@ class FlashLlamaAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            flash_attn_cuda.fwd(
+            flash_attn_2_cuda.varlen_fwd(
                 qkv[:, 0],
                 qkv[:, 1],
                 qkv[:, 2],
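
In the prefill path the model passes a packed qkv tensor of shape [total_tokens, 3, num_heads, head_size] together with cumulative sequence lengths (cu_seqlen_prefill); the varlen kernel operates on such unpadded, concatenated token layouts rather than on a padded [batch, seq] tensor, and qkv[:, 0], qkv[:, 1], qkv[:, 2] are the query, key and value slices. A rough sketch of that layout with made-up sizes (variable names here are illustrative, not taken from the repository):

    import torch

    # Hypothetical batch of three prompts with different lengths.
    seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)

    num_heads, head_size = 32, 128
    total_tokens = int(seq_lens.sum())

    # Tokens of all sequences are concatenated along dim 0 (no padding);
    # dim 1 packs query / key / value, matching the qkv[:, i] slices above.
    qkv = torch.randn(total_tokens, 3, num_heads, head_size, dtype=torch.float16)

    # Cumulative sequence lengths with a leading zero: sequence i occupies
    # rows cu_seqlens[i]:cu_seqlens[i + 1] of the packed tensor.
    cu_seqlens = torch.nn.functional.pad(seq_lens.cumsum(0, dtype=torch.int32), (1, 0))
    max_s = int(seq_lens.max())

    print(cu_seqlens)  # tensor([ 0,  5,  8, 15], dtype=torch.int32)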

View File

@@ -28,7 +28,7 @@ from transformers.models.gpt_neox import GPTNeoXConfig
 from typing import Optional, List, Tuple

 # Flash attention imports
-import flash_attn_cuda
+import flash_attn_2_cuda

 # vllm imports
 import vllm_cache_ops
@@ -153,7 +153,7 @@ class FlashNeoxAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            flash_attn_cuda.fwd(
+            flash_attn_2_cuda.varlen_fwd(
                 qkv[:, 0],
                 qkv[:, 1],
                 qkv[:, 2],

View File

@@ -7,7 +7,7 @@ from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple

 # Flash attention imports
-import flash_attn_cuda
+import flash_attn_2_cuda

 # vllm imports
 import vllm_cache_ops
@@ -187,7 +187,7 @@ class FlashRWAttention(torch.nn.Module):
             kv = kv.expand(-1, 2, self.num_heads, self.head_size)

             # flash attention
-            flash_attn_cuda.fwd(
+            flash_attn_2_cuda.varlen_fwd(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
@@ -322,7 +322,7 @@ class FlashRWLargeAttention(torch.nn.Module):
             )

             # flash attention
-            flash_attn_cuda.fwd(
+            flash_attn_2_cuda.varlen_fwd(
                 query,
                 torch.select(kv, dim=2, index=0),
                 torch.select(kv, dim=2, index=1),
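
The FlashRWAttention hunk above (and the FlashMQAttention hunk that follows) shares one pattern worth noting: a single key/value head is broadcast to the full query-head count before the kernel call via .expand(), and torch.select(kv, dim=1, index=0) / index=1 then pick out the key and value views. A small illustration with assumed shapes (not taken from the repository):

    import torch

    num_heads, head_size, total_tokens = 32, 64, 10

    # One shared key/value head per token, as a multi-query projection yields.
    kv = torch.randn(total_tokens, 2, 1, head_size)

    # expand() broadcasts the singleton head dimension to num_heads without
    # copying; it returns a view with stride 0 on that dimension.
    kv = kv.expand(-1, 2, num_heads, head_size)

    # torch.select(kv, dim=1, index=0) is equivalent to kv[:, 0] (the keys);
    # index=1 gives the values.
    k = torch.select(kv, dim=1, index=0)
    v = torch.select(kv, dim=1, index=1)

    print(k.shape, v.shape)  # torch.Size([10, 32, 64]) torch.Size([10, 32, 64])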

View File

@@ -6,7 +6,7 @@ from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple

 # Flash attention imports
-import flash_attn_cuda
+import flash_attn_2_cuda

 # vllm imports
 import vllm_cache_ops
@@ -275,7 +275,7 @@ class FlashMQAttention(torch.nn.Module):
             key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)

             # flash attention
-            flash_attn_cuda.fwd(
+            flash_attn_2_cuda.varlen_fwd(
                 query,
                 torch.select(key_value, dim=1, index=0),
                 torch.select(key_value, dim=1, index=1),