Address comments + fix 2nd path in falcon.

Nicolas Patry 2024-05-31 12:43:13 +00:00
parent c67539fbcc
commit d44688b6ac
2 changed files with 38 additions and 42 deletions

View File

@@ -9,7 +9,6 @@ _PARTITION_SIZE = 512
 use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
 ENGINE = "triton" if use_triton else "ck"
 
-from .flash_attn_triton import triton_attention
 
 try:
     from vllm._C import cache_ops
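
The two context lines at the top of this hunk are where the ROCm attention engine is chosen. A minimal, self-contained sketch of that selection (only the ROCM_USE_FLASH_ATTN_V2_TRITON variable and the "triton"/"ck" values come from the diff; the print is illustrative):

import os

# Opt into the Triton flash-attention kernels; any other value keeps the
# Composable Kernel ("ck") path as the default.
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
ENGINE = "triton" if use_triton else "ck"

print(f"ROCm flash-attention engine: {ENGINE}")

Setting ROCM_USE_FLASH_ATTN_V2_TRITON=1 (or "true") in the server environment is therefore what routes the attention code below through the Triton kernels.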
@@ -122,44 +121,42 @@ def paged_attention(
     )
 
 
-try:
-    import flash_attn_2_cuda
-    if ENGINE == "triton":
-        logger.info("ROCm: using Flash Attention 2 Triton implementation.")
-    elif ENGINE == "ck":
-        logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
-except ImportError:
-    try:
-        import flash_attn_cuda
-        ENGINE = "v1"
-        logger.info("ROCm: using Flash Attention 1")
-    except ImportError as e:
-        if major >= 8:
-            architecture_suffix = f"-{SYSTEM}"
-            raise ImportError(
-                "Flash Attention V2 is not installed.\n"
-                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-                f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
-            )
-        elif is_sm75:
-            raise ImportError(
-                "Flash Attention is not installed.\n"
-                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-                "or install flash attention with `cd server && make install install-flash-attention`"
-            ) from e
-        else:
-            for idx in range(torch.cuda.device_count()):
-                name = torch.cuda.get_device_name(idx)
-                if "MI210" not in name and "MI250" not in name:
-                    raise ImportError(
-                        f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
-                    )
-            raise ImportError(
-                f"AMD GPU with Rocm capability {major} {minor} is not supported"
-            ) from e
+if ENGINE != "triton":
+    try:
+        import flash_attn_2_cuda
+        logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
+    except ImportError:
+        try:
+            import flash_attn_cuda
+            ENGINE = "v1"
+            logger.info("ROCm: using Flash Attention 1")
+        except ImportError as e:
+            if major >= 8:
+                architecture_suffix = f"-{SYSTEM}"
+                raise ImportError(
+                    "Flash Attention V2 is not installed.\n"
+                    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+                    f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
+                )
+            elif is_sm75:
+                raise ImportError(
+                    "Flash Attention is not installed.\n"
+                    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+                    "or install flash attention with `cd server && make install install-flash-attention`"
+                ) from e
+            else:
+                for idx in range(torch.cuda.device_count()):
+                    name = torch.cuda.get_device_name(idx)
+                    if "MI210" not in name and "MI250" not in name:
+                        raise ImportError(
+                            f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
+                        )
+                raise ImportError(
+                    f"AMD GPU with ROCm capability {major} {minor} is not supported"
+                ) from e
 
 
 SUPPORTS_WINDOWING = ENGINE != "v1"
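
The reindentation above is the substantive change in this hunk: the compiled-kernel imports now run only when the Triton engine is not selected, so a missing flash_attn_2_cuda or flash_attn_cuda build no longer fails at module import time for Triton users. A minimal sketch of that guarded fallback, reusing the module and variable names visible in the diff (the final error is abbreviated; the real code raises the detailed installation hints shown above):

import logging

logger = logging.getLogger(__name__)

ENGINE = "ck"  # becomes "triton" when ROCM_USE_FLASH_ATTN_V2_TRITON is set

if ENGINE != "triton":
    try:
        import flash_attn_2_cuda  # noqa: F401  (Composable Kernel build)

        logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
    except ImportError:
        try:
            import flash_attn_cuda  # noqa: F401  (v1 fallback)

            ENGINE = "v1"
            logger.info("ROCm: using Flash Attention 1")
        except ImportError as e:
            # Abbreviated; the real code explains how to install the kernels
            # for the detected GPU.
            raise ImportError("Flash Attention is not installed") from e

SUPPORTS_WINDOWING = ENGINE != "v1"  # the v1 kernels cannot do windowed attention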
@@ -180,7 +177,7 @@ if ENGINE == "ck":
             raise ValueError("`window_size_left` must be > 0 or -1")
         if window_size_left != -1:
             raise ValueError(
-                f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
+                f"ROCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
             )
         return flash_attn_2_cuda.varlen_fwd(
             q,
@@ -205,6 +202,7 @@ if ENGINE == "ck":
         )
 
 elif ENGINE == "triton":
+    from .flash_attn_triton import triton_attention
 
     def attention(
         q,
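
Together with the first hunk, this moves from .flash_attn_triton import triton_attention off the module top level and into the elif ENGINE == "triton": branch, so the Triton kernels are imported only when that engine is actually selected. A rough sketch of the resulting dispatch shape (the attention signature and bodies are placeholders, not the real kernel calls):

ENGINE = "ck"  # or "triton", as selected earlier in the module

if ENGINE == "ck":

    def attention(q, k, v, *args, **kwargs):
        # CK path: the real code forwards to flash_attn_2_cuda.varlen_fwd(...).
        raise NotImplementedError("placeholder for the CK kernel call")

elif ENGINE == "triton":
    # Deferred import: evaluated only when the Triton engine is chosen, so a
    # missing Triton build cannot break the default CK path at import time.
    from .flash_attn_triton import triton_attention

    def attention(q, k, v, *args, **kwargs):
        # Triton path: the real code forwards to triton_attention(...).
        raise NotImplementedError("placeholder for the Triton kernel call")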

View File

@@ -198,9 +198,7 @@ class FlashRWAttention(torch.nn.Module):
         # Inplace rotary
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
-        paged_attention.reshape_and_cache(
-            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
-        )
+        reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
 
         # output
         attn_output = torch.empty_like(query)
@@ -208,7 +206,7 @@ class FlashRWAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            flash_attn.attention(
+            attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
@@ -219,7 +217,7 @@ class FlashRWAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention.attention(
+            paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
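
In the falcon model, the three hunks switch the forward pass to calling the shared helpers as plain functions: reshape_and_cache after the rotary embedding, attention on the prefill path, and paged_attention on the decode path; the decode branch is the "2nd path" named in the commit title. A minimal sketch of that call pattern, with the real imports (not part of this diff) replaced by stubs and the helper arguments reduced to those visible above:

import torch

# Stubs standing in for the shared attention helpers; in the real model they
# are imported from the attention layer module, which this diff does not show.
def reshape_and_cache(key, value, key_cache, value_cache, slots): ...
def attention(*args): ...
def paged_attention(*args): ...

def rw_attention_tail(query, kv, kv_cache, slots, cu_seqlen_prefill):
    # Sketch of the end of FlashRWAttention.forward after this commit.
    # Cache the new key/value pair for later decode steps.
    reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)

    attn_output = torch.empty_like(query)

    if cu_seqlen_prefill is not None:
        # Prefill path: flash attention over the whole prompt.
        attention(query, kv[:, 0], kv[:, 1])
    else:
        # Decode path (the "2nd path"): paged attention against the KV cache.
        paged_attention(attn_output, query, kv_cache[0], kv_cache[1])

    return attn_output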