Address comments + fix 2nd path in falcon.

Nicolas Patry 2024-05-31 12:43:13 +00:00
parent c67539fbcc
commit d44688b6ac
2 changed files with 38 additions and 42 deletions

View File

@@ -9,7 +9,6 @@ _PARTITION_SIZE = 512
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
ENGINE = "triton" if use_triton else "ck"
from .flash_attn_triton import triton_attention
try:
from vllm._C import cache_ops
@@ -122,44 +121,42 @@ def paged_attention(
)
try:
import flash_attn_2_cuda
if ENGINE == "triton":
logger.info("ROCm: using Flash Attention 2 Triton implementation.")
elif ENGINE == "ck":
logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
except ImportError:
if ENGINE != "triton":
try:
import flash_attn_cuda
import flash_attn_2_cuda
ENGINE = "v1"
logger.info("ROCm: using Flash Attention 1")
except ImportError as e:
if major >= 8:
architecture_suffix = f"-{SYSTEM}"
raise ImportError(
"Flash Attention V2 is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
)
elif is_sm75:
raise ImportError(
"Flash Attention is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
"or install flash attention with `cd server && make install install-flash-attention`"
) from e
else:
logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
except ImportError:
try:
import flash_attn_cuda
for idx in range(torch.cuda.device_count()):
name = torch.cuda.get_device_name(idx)
if "MI210" not in name and "MI250" not in name:
raise ImportError(
f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
)
raise ImportError(
f"AMD GPU with Rocm capability {major} {minor} is not supported"
) from e
ENGINE = "v1"
logger.info("ROCm: using Flash Attention 1")
except ImportError as e:
if major >= 8:
architecture_suffix = f"-{SYSTEM}"
raise ImportError(
"Flash Attention V2 is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
)
elif is_sm75:
raise ImportError(
"Flash Attention is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
"or install flash attention with `cd server && make install install-flash-attention`"
) from e
else:
for idx in range(torch.cuda.device_count()):
name = torch.cuda.get_device_name(idx)
if "MI210" not in name and "MI250" not in name:
raise ImportError(
f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
)
raise ImportError(
f"AMD GPU with ROCm capability {major} {minor} is not supported"
) from e
SUPPORTS_WINDOWING = ENGINE != "v1"
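
For readability, here is a simplified standalone sketch of the ROCm engine selection this hunk converges on: skip the native imports entirely when the Triton engine was requested, otherwise try the Flash Attention 2 Composable Kernel build and fall back to Flash Attention 1, raising an informative ImportError when neither is available. Error branches are condensed and the logger import is an assumption; this is not the module verbatim.

import os

import torch
from loguru import logger

use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
ENGINE = "triton" if use_triton else "ck"

if ENGINE == "triton":
    logger.info("ROCm: using Flash Attention 2 Triton implementation.")
else:
    try:
        import flash_attn_2_cuda  # Composable Kernel build of Flash Attention 2

        logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
    except ImportError:
        try:
            import flash_attn_cuda  # Flash Attention 1 fallback

            ENGINE = "v1"
            logger.info("ROCm: using Flash Attention 1")
        except ImportError as e:
            # Give a device-specific hint first; only MI210/MI250 are expected
            # to have a working flash-attention build in this setup.
            for idx in range(torch.cuda.device_count()):
                name = torch.cuda.get_device_name(idx)
                if "MI210" not in name and "MI250" not in name:
                    raise ImportError(
                        f"AMD GPU {name} does not support flash-attention"
                    ) from e
            raise ImportError("Flash Attention is not installed") from e

SUPPORTS_WINDOWING = ENGINE != "v1"
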
@@ -180,7 +177,7 @@ if ENGINE == "ck":
raise ValueError("`window_size_left` must be > 0 or -1")
if window_size_left != -1:
raise ValueError(
f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
f"ROCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
)
return flash_attn_2_cuda.varlen_fwd(
q,
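
For reference, the window_size_left guard touched just above exists because the ROCm Composable Kernel build of Flash Attention v2 has no sliding-window support, so only the sentinel value -1 is accepted. A minimal standalone sketch of that guard follows; the helper name is illustrative and the leading bounds check is reconstructed from its error message rather than shown in the hunk.

def _check_window_size(window_size_left: int) -> None:
    # ROCm's CK build of Flash Attention v2 cannot do sliding-window attention.
    if window_size_left <= 0 and window_size_left != -1:
        raise ValueError("`window_size_left` must be > 0 or -1")
    if window_size_left != -1:
        raise ValueError(
            "ROCm version of Flash Attention v2 does not support window attention "
            f"(window_size_left != -1, got window_size_left={window_size_left})."
        )
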
@@ -205,6 +202,7 @@ if ENGINE == "ck":
)
elif ENGINE == "triton":
from .flash_attn_triton import triton_attention
def attention(
q,

View File

@@ -198,9 +198,7 @@ class FlashRWAttention(torch.nn.Module):
# Inplace rotary
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
paged_attention.reshape_and_cache(
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# output
attn_output = torch.empty_like(query)
@@ -208,7 +206,7 @@ class FlashRWAttention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
@@ -219,7 +217,7 @@ class FlashRWAttention(torch.nn.Module):
)
# Decode
else:
paged_attention.attention(
paged_attention(
attn_output,
query,
kv_cache[0],
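
The falcon hunks above switch FlashRWAttention from submodule attribute calls (paged_attention.reshape_and_cache, flash_attn.attention, paged_attention.attention) to the flat helpers. Below is a rough sketch of the resulting prefill/decode dispatch; the stub signatures are only inferred from the call sites in this diff and are assumptions, not the server's actual definitions.

import torch

# Illustrative stand-ins for the shared attention helpers; signatures inferred
# from the call sites above, not taken from the server package.
def reshape_and_cache(key, value, key_cache, value_cache, slots): ...

def attention(q, k, v, out, cu_seqlen_prefill, max_s, softmax_scale): ...

def paged_attention(out, q, key_cache, value_cache, kv_head_mapping,
                    softmax_scale, block_tables, input_lengths, max_s): ...

def flash_rw_attention_step(query, kv, kv_cache, slots, cu_seqlen_prefill,
                            block_tables, input_lengths, max_s,
                            kv_head_mapping, softmax_scale):
    # Write this step's key/value pair into the paged KV cache.
    reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)

    attn_output = torch.empty_like(query)
    if cu_seqlen_prefill is not None:
        # Prefill: dense flash attention over the prompt tokens.
        attention(
            query,
            torch.select(kv, dim=1, index=0),
            torch.select(kv, dim=1, index=1),
            attn_output,
            cu_seqlen_prefill,
            max_s,
            softmax_scale,
        )
    else:
        # Decode: paged attention against the cached keys/values.
        paged_attention(
            attn_output,
            query,
            kv_cache[0],
            kv_cache[1],
            kv_head_mapping,
            softmax_scale,
            block_tables,
            input_lengths,
            max_s,
        )
    return attn_output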