Mirror of https://github.com/huggingface/text-generation-inference.git
Address comments + fix 2nd path in falcon.
parent c67539fbcc
commit d44688b6ac
@@ -9,7 +9,6 @@ _PARTITION_SIZE = 512
 
 use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
 ENGINE = "triton" if use_triton else "ck"
-from .flash_attn_triton import triton_attention
 
 try:
     from vllm._C import cache_ops
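
For orientation, here is a minimal standalone sketch of the engine selection this file performs at import time; the `select_engine` helper is hypothetical, but the env-var parsing mirrors the `use_triton` / `ENGINE` lines in the hunk above.

```python
import os


def select_engine() -> str:
    # Mirror of the module-level selection above: the Triton flash-attention
    # path is opted into via ROCM_USE_FLASH_ATTN_V2_TRITON, otherwise the
    # Composable Kernel ("ck") path is used.
    use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
    return "triton" if use_triton else "ck"


if __name__ == "__main__":
    print(f"flash-attention engine: {select_engine()}")
```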
@@ -122,14 +121,12 @@ def paged_attention(
         )
 
 
-try:
-    import flash_attn_2_cuda
-
-    if ENGINE == "triton":
-        logger.info("ROCm: using Flash Attention 2 Triton implementation.")
-    elif ENGINE == "ck":
-        logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
-except ImportError:
-    try:
-        import flash_attn_cuda
+if ENGINE != "triton":
+    try:
+        import flash_attn_2_cuda
+
+        logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
+    except ImportError:
+        try:
+            import flash_attn_cuda
 
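
A hedged sketch of the guard this hunk introduces: the Composable Kernel (CK) extensions are only probed when the Triton engine was not requested, with a fallback from `flash_attn_2_cuda` to `flash_attn_cuda`. The module names come from the diff; the `HAS_FLASH_ATTN_V2` flag, the logging setup, and the fallback messages are illustrative assumptions.

```python
import logging

logger = logging.getLogger(__name__)

ENGINE = "ck"  # stand-in for the env-var based selection shown earlier

# Only probe the Composable Kernel extensions when Triton was not requested;
# otherwise these modules may not even be installed.
HAS_FLASH_ATTN_V2 = False  # illustrative flag, not necessarily the repo's name
if ENGINE != "triton":
    try:
        import flash_attn_2_cuda  # noqa: F401

        logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
        HAS_FLASH_ATTN_V2 = True
    except ImportError:
        try:
            import flash_attn_cuda  # noqa: F401

            logger.info("ROCm: falling back to Flash Attention v1.")
        except ImportError as e:
            logger.warning("No flash-attention extension available: %s", e)
```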
@@ -158,7 +155,7 @@ except ImportError:
                     f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
                 )
             raise ImportError(
-                f"AMD GPU with Rocm capability {major} {minor} is not supported"
+                f"AMD GPU with ROCm capability {major} {minor} is not supported"
             ) from e
 
 
@@ -180,7 +177,7 @@ if ENGINE == "ck":
             raise ValueError("`window_size_left` must be > 0 or -1")
         if window_size_left != -1:
             raise ValueError(
-                f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
+                f"ROCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
             )
         return flash_attn_2_cuda.varlen_fwd(
             q,
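
The two `raise ValueError` branches above amount to a small precondition check on `window_size_left`; a self-contained sketch of that validation follows (written as a free function here, with the first guard's condition inferred from its error message, so treat it as an approximation rather than the repo's exact code).

```python
def check_window_size_left(window_size_left: int) -> None:
    # Same precondition as the CK attention wrapper above: the only accepted
    # value on ROCm is the sentinel -1, i.e. sliding-window attention disabled.
    if window_size_left <= 0 and window_size_left != -1:
        raise ValueError("`window_size_left` must be > 0 or -1")
    if window_size_left != -1:
        raise ValueError(
            f"ROCm version of Flash Attention v2 does not support window attention "
            f"(window_size_left != -1, got window_size_left={window_size_left})."
        )


check_window_size_left(-1)  # passes
# check_window_size_left(4096) would raise: window attention is unsupported on ROCm
```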
@@ -205,6 +202,7 @@ if ENGINE == "ck":
         )
 
 elif ENGINE == "triton":
+    from .flash_attn_triton import triton_attention
 
     def attention(
         q,
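
These hunks sit in a module-level `if ENGINE == "ck": ... elif ENGINE == "triton": ...` block that defines one `attention` entry point per backend at import time, so callers never branch on the engine themselves. A minimal sketch of that pattern, with placeholder bodies standing in for the real `flash_attn_2_cuda.varlen_fwd` and `triton_attention` calls:

```python
ENGINE = "triton"  # stand-in for the env-var based selection at the top of the file

if ENGINE == "ck":

    def attention(q, k, v, *args, **kwargs):
        # Placeholder: the real CK path calls flash_attn_2_cuda.varlen_fwd(...).
        raise NotImplementedError("CK flash-attention stub")

elif ENGINE == "triton":

    def attention(q, k, v, *args, **kwargs):
        # Placeholder: the real Triton path calls triton_attention(...) from
        # .flash_attn_triton (the import added in the hunk above).
        raise NotImplementedError("Triton flash-attention stub")

else:
    raise RuntimeError(f"Unknown flash-attention engine: {ENGINE}")
```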
@@ -198,9 +198,7 @@ class FlashRWAttention(torch.nn.Module):
         # Inplace rotary
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
-        paged_attention.reshape_and_cache(
-            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
-        )
+        reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
 
         # output
         attn_output = torch.empty_like(query)
@@ -208,7 +206,7 @@ class FlashRWAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            flash_attn.attention(
+            attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
@@ -219,7 +217,7 @@ class FlashRWAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention.attention(
+            paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
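
The three falcon hunks above switch `FlashRWAttention` from namespaced calls (`paged_attention.reshape_and_cache`, `flash_attn.attention`, `paged_attention.attention`) to the flat helpers `reshape_and_cache`, `attention`, and `paged_attention`. Below is a hedged sketch of the resulting prefill/decode dispatch; the helpers are stubs so the example is self-contained, and the argument lists are abbreviated rather than the real signatures.

```python
import torch


# Stubs standing in for the flat helpers the falcon module now imports directly.
def reshape_and_cache(key, value, key_cache, value_cache, slots):
    """Write the new key/value pages into the paged KV cache (stub)."""


def attention(query, key, value, attn_output, *extra):
    """Dense flash attention over the prefill tokens (stub)."""


def paged_attention(attn_output, query, key_cache, value_cache, *extra):
    """Paged attention against the cached keys/values during decode (stub)."""


def rw_attention_dispatch(query, kv, kv_cache, slots, cu_seqlen_prefill, *extra):
    # Always cache the freshly computed K/V so later decode steps can read them.
    reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)

    attn_output = torch.empty_like(query)
    if cu_seqlen_prefill is not None:
        # Prefill: dense flash attention over the whole prompt.
        attention(query, kv[:, 0], kv[:, 1], attn_output, cu_seqlen_prefill, *extra)
    else:
        # Decode: paged attention over the KV cache, one new token per sequence.
        paged_attention(attn_output, query, kv_cache[0], kv_cache[1], *extra)
    return attn_output
```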