allow ROCM_USE_FLASH_ATTN_V2_TRITON=1

This commit is contained in:
fxmarty 2024-05-03 07:36:33 +00:00
parent ca5ea45181
commit 64e65ba3a1

View File

@@ -35,7 +35,7 @@ if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
is_sm94 = major == 9 and minor == 4
if IS_ROCM_SYSTEM:
-if os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true":
+if os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true" or os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "0") == "1":
ROCM_USE_FLASH_ATTN_V2_TRITON = True
logger.info("ROCm: using Flash Attention 2 Triton implementation.")
else: