Mirror of https://github.com/huggingface/text-generation-inference.git

commit 107fcfe9b6 (parent a2cf1bdb2f)
feat(server): flash attention v2
@@ -1,4 +1,4 @@
-flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec
+flash_att_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc
 
 flash-attention:
 	# Clone flash attention
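The v2 kernels are built from the commit pinned above and ship as a different CUDA extension module, flash_attn_2_cuda, replacing flash_attn_cuda. As a minimal sketch that is not part of this commit, the probe below reports which extension is importable in the current environment; only the two module names are taken from the diff.

import importlib.util

def detect_flash_attn_backend() -> str:
    # Probe the compiled extension modules without importing them
    # (importing would require a CUDA-capable runtime).
    if importlib.util.find_spec("flash_attn_2_cuda") is not None:
        return "flash-attention v2 (flash_attn_2_cuda)"
    if importlib.util.find_spec("flash_attn_cuda") is not None:
        return "flash-attention v1 (flash_attn_cuda)"
    return "no flash-attention extension installed"

if __name__ == "__main__":
    print(detect_flash_attn_backend())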
@@ -57,14 +57,13 @@ try:
         raise ImportError("CUDA is not available")
 
     major, minor = torch.cuda.get_device_capability()
-    is_sm75 = major == 7 and minor == 5
     is_sm8x = major == 8 and minor >= 0
     is_sm90 = major == 9 and minor == 0
 
-    supported = is_sm75 or is_sm8x or is_sm90
+    supported = is_sm8x or is_sm90
     if not supported:
         FLASH_ATT_ERROR_MESSAGE = (
-            "{} requires a CUDA device with capability 7.5, > 8.0 or 9.0. "
+            "{} requires a CUDA device with capability > 8.0 or 9.0. "
             "No compatible CUDA device found."
         )
         raise ImportError(
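The check above drops the SM 7.5 (Turing) branch: the v2 kernels cover SM 8.x (Ampere/Ada) and SM 9.0 (Hopper) only, as the new supported = is_sm8x or is_sm90 line encodes. Below is a standalone sketch of that gate, assuming only PyTorch; the function name is illustrative.

import torch

def check_flash_attn_v2_support() -> None:
    # Mirror of the import-time gate: flash-attention v2 needs an
    # SM 8.x or SM 9.0 GPU.
    if not torch.cuda.is_available():
        raise ImportError("CUDA is not available")

    major, minor = torch.cuda.get_device_capability()
    is_sm8x = major == 8 and minor >= 0
    is_sm90 = major == 9 and minor == 0

    if not (is_sm8x or is_sm90):
        raise ImportError(
            f"flash attention v2 requires a CUDA device with capability "
            f"8.x or 9.0; found {major}.{minor}"
        )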
@@ -26,7 +26,7 @@ from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
 
 # Flash attention imports
-import flash_attn_cuda
+import flash_attn_2_cuda
 import dropout_layer_norm
 
 # vllm imports
@@ -164,7 +164,7 @@ class FlashLlamaAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            flash_attn_cuda.fwd(
+            flash_attn_2_cuda.varlen_fwd(
                 qkv[:, 0],
                 qkv[:, 1],
                 qkv[:, 2],
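In the prefill branch the kernel gets the query, key and value slices of a packed projection tensor (qkv[:, 0], qkv[:, 1], qkv[:, 2]) plus cu_seqlen_prefill, the cumulative sequence lengths of the batch. The PyTorch-only sketch below shows that packed, variable-length layout; the shapes and tensor names are illustrative and not taken from the model code.

import torch

# Illustrative shapes only: 3 prompts of different lengths packed into one
# "ragged" batch, as the varlen kernels expect.
seq_lens = torch.tensor([5, 2, 7], dtype=torch.int32)
total_tokens, num_heads, head_size = int(seq_lens.sum()), 8, 64

# Packed projection output: one row per token, with Q/K/V stacked on dim 1,
# so qkv[:, 0], qkv[:, 1], qkv[:, 2] are the query/key/value tensors the
# attention call receives.
qkv = torch.randn(total_tokens, 3, num_heads, head_size)
query, key, value = qkv[:, 0], qkv[:, 1], qkv[:, 2]

# Cumulative sequence lengths ("cu_seqlens"): boundaries of each prompt
# inside the packed batch, i.e. [0, 5, 7, 14] here.
cu_seqlens = torch.nn.functional.pad(
    seq_lens.cumsum(dim=0, dtype=torch.int32), (1, 0)
)
max_seqlen = int(seq_lens.max())

# Tokens of prompt i live in query[cu_seqlens[i] : cu_seqlens[i + 1]].
print(cu_seqlens.tolist(), max_seqlen)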
@@ -28,7 +28,7 @@ from transformers.models.gpt_neox import GPTNeoXConfig
 from typing import Optional, List, Tuple
 
 # Flash attention imports
-import flash_attn_cuda
+import flash_attn_2_cuda
 
 # vllm imports
 import vllm_cache_ops
@@ -153,7 +153,7 @@ class FlashNeoxAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            flash_attn_cuda.fwd(
+            flash_attn_2_cuda.varlen_fwd(
                 qkv[:, 0],
                 qkv[:, 1],
                 qkv[:, 2],
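For reference, here is a pure-PyTorch sketch of what the variable-length forward computes during prefill: each prompt in the packed batch attends only to its own earlier tokens. The causal mask is assumed (it is the usual choice for decoder prefill and is not visible in these hunks), and this loop illustrates the semantics only; the fused varlen_fwd kernel does the same work in a single call without materializing the score matrices.

import math
import torch

def varlen_causal_attention_reference(query, key, value, cu_seqlens):
    # query / key / value: [total_tokens, num_heads, head_size]
    # cu_seqlens: [batch_size + 1] cumulative sequence lengths
    out = torch.empty_like(query)
    scale = 1.0 / math.sqrt(query.shape[-1])
    for i in range(len(cu_seqlens) - 1):
        start, end = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
        q = query[start:end].transpose(0, 1)  # [heads, seq, dim]
        k = key[start:end].transpose(0, 1)
        v = value[start:end].transpose(0, 1)
        scores = (q @ k.transpose(-1, -2)) * scale
        causal = torch.ones(
            end - start, end - start, dtype=torch.bool, device=query.device
        ).tril()
        scores = scores.masked_fill(~causal, float("-inf"))
        out[start:end] = (scores.softmax(dim=-1) @ v).transpose(0, 1)
    return out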
@@ -7,7 +7,7 @@ from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
 # Flash attention imports
-import flash_attn_cuda
+import flash_attn_2_cuda
 
 # vllm imports
 import vllm_cache_ops
@@ -187,7 +187,7 @@ class FlashRWAttention(torch.nn.Module):
             kv = kv.expand(-1, 2, self.num_heads, self.head_size)
 
             # flash attention
-            flash_attn_cuda.fwd(
+            flash_attn_2_cuda.varlen_fwd(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
@@ -322,7 +322,7 @@ class FlashRWLargeAttention(torch.nn.Module):
             )
 
             # flash attention
-            flash_attn_cuda.fwd(
+            flash_attn_2_cuda.varlen_fwd(
                 query,
                 torch.select(kv, dim=2, index=0),
                 torch.select(kv, dim=2, index=1),
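FlashRWAttention and FlashRWLargeAttention split K and V out of a packed kv tensor with torch.select, but on different dims (dim=1 versus dim=2), which simply reflects where the 2-wide K/V axis sits in each layout. The sketch below uses illustrative shapes; the extra grouping dim in the second case is inferred from the dim=2 select and is not shown in these hunks.

import torch

total_tokens, num_heads, head_size, num_groups = 14, 8, 64, 2

# FlashRWAttention-style packing: kv is [total_tokens, 2, num_heads, head_size],
# so K and V sit side by side on dim 1.
kv = torch.randn(total_tokens, 2, num_heads, head_size)
key = torch.select(kv, dim=1, index=0)    # [total_tokens, num_heads, head_size]
value = torch.select(kv, dim=1, index=1)

# FlashRWLargeAttention-style packing keeps an extra grouping dim in front,
# e.g. [total_tokens, num_groups, 2, num_heads, head_size], so the K/V split
# happens on dim 2 instead.
kv_grouped = torch.randn(total_tokens, num_groups, 2, num_heads, head_size)
key_g = torch.select(kv_grouped, dim=2, index=0)
value_g = torch.select(kv_grouped, dim=2, index=1)

# torch.select drops the selected dim and returns a view (no copy).
assert key.shape == (total_tokens, num_heads, head_size)
assert key_g.shape == (total_tokens, num_groups, num_heads, head_size)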
@@ -6,7 +6,7 @@ from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
 
 # Flash attention imports
-import flash_attn_cuda
+import flash_attn_2_cuda
 
 # vllm imports
 import vllm_cache_ops
@@ -275,7 +275,7 @@ class FlashMQAttention(torch.nn.Module):
             key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
 
             # flash attention
-            flash_attn_cuda.fwd(
+            flash_attn_2_cuda.varlen_fwd(
                 query,
                 torch.select(key_value, dim=1, index=0),
                 torch.select(key_value, dim=1, index=1),
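FlashMQAttention implements multi-query attention: a single shared K/V head is broadcast across all query heads via .expand before the same varlen_fwd call. The sketch below, with illustrative shapes, shows that the expand is a zero-copy view (stride 0 on the broadcast dim) rather than a materialized repeat.

import torch

total_tokens, num_heads, head_size = 14, 8, 64

# Multi-query attention: the model projects a single K/V head, packed here as
# [total_tokens, 2, 1, head_size] (illustrative shapes).
key_value = torch.randn(total_tokens, 2, 1, head_size)

# .expand broadcasts the singleton head dim across all query heads without
# copying memory, matching the expand shown in the hunk above.
key_value = key_value.expand(-1, 2, num_heads, head_size)

key = torch.select(key_value, dim=1, index=0)    # [total_tokens, num_heads, head_size]
value = torch.select(key_value, dim=1, index=1)

assert key.shape == (total_tokens, num_heads, head_size)
assert key_value.stride(2) == 0  # shared across heads, no copy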