From 64e65ba3a1d1a22fae8ba6d7b82175f033deb01e Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 3 May 2024 07:36:33 +0000 Subject: [PATCH] allow ROCM_USE_FLASH_ATTN_V2_TRITON=1 --- server/text_generation_server/utils/flash_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py index a15cbbbc..51f7f5fc 100644 --- a/server/text_generation_server/utils/flash_attn.py +++ b/server/text_generation_server/utils/flash_attn.py @@ -35,7 +35,7 @@ if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM: is_sm94 = major == 9 and minor == 4 if IS_ROCM_SYSTEM: - if os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true": + if os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true" or os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "0") == "1": ROCM_USE_FLASH_ATTN_V2_TRITON = True logger.info("ROCm: using Flash Attention 2 Triton implementation.") else: