diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 6b015a4a..b0f87047 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -33,7 +33,7 @@ try: supported = is_sm75 or is_sm8x or is_sm90 if not supported: raise ImportError(f"GPU with CUDA capability {major} {minor} is not supported") - FLASH_ATTENTION = supported + FLASH_ATTENTION = True else: FLASH_ATTENTION = False except ImportError: