fix: improve conditional and error message

drbh 2024-08-01 16:17:29 +00:00
parent cae28dcbf1
commit 5b649d67c4


@@ -171,8 +171,9 @@ def paged_attention(
 try:
-    if major <= 8:
-        raise ImportError("Flash Attention V2 requires CUDA 11.0 or higher")
+    is_ampere_or_newer = major >= 8 and minor >= 0
+    if not is_ampere_or_newer:
+        raise ImportError("FlashAttention only supports Ampere GPUs or newer.")
     import flash_attn_2_cuda
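
For context, a minimal sketch of how the corrected check might be wired up. The diff does not show where major and minor are defined; here it is assumed they come from torch.cuda.get_device_capability(), which is an assumption, not something shown in this commit.

import torch

# Assumption: `major` / `minor` in the diff come from
# torch.cuda.get_device_capability(), which returns the compute capability
# of the current CUDA device as a (major, minor) tuple.
major, minor = torch.cuda.get_device_capability()

# Ampere GPUs (e.g. A100, sm_80) report major == 8, so the corrected check
# accepts compute capability 8.0 and above. The old `major <= 8` test raised
# for Ampere as well, and its message misleadingly blamed the CUDA version.
is_ampere_or_newer = major >= 8 and minor >= 0
if not is_ampere_or_newer:
    raise ImportError("FlashAttention only supports Ampere GPUs or newer.")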