From a88c54bb4ce022ddc03b899b0c8018b6625e3a9e Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Wed, 19 Apr 2023 12:52:37 +0200
Subject: [PATCH] feat(server): check cuda capability when importing flash models (#201)

close #198
---
 server/text_generation_server/models/__init__.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 13c74c91..0a29b3cc 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -24,7 +24,18 @@ try:
         FlashSantacoderSharded,
     )

-    FLASH_ATTENTION = torch.cuda.is_available()
+    if torch.cuda.is_available():
+        major, minor = torch.cuda.get_device_capability()
+        is_sm75 = major == 7 and minor == 5
+        is_sm8x = major == 8 and minor >= 0
+        is_sm90 = major == 9 and minor == 0
+
+        supported = is_sm75 or is_sm8x or is_sm90
+        if not supported:
+            raise ImportError(f"GPU with CUDA capability {major} {minor} is not supported")
+        FLASH_ATTENTION = True
+    else:
+        FLASH_ATTENTION = False
 except ImportError:
     logger.opt(exception=True).warning(
         "Could not import Flash Attention enabled models"
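
Reviewer note (not part of the patch): below is a minimal standalone sketch of the same capability gate, assuming only that PyTorch is installed; the helper name flash_attention_supported is hypothetical and not from the patch. It can be run directly to check whether the local GPU falls in the SM 7.5 / 8.x / 9.0 range that the patch accepts for Flash Attention models.

    import torch

    def flash_attention_supported() -> bool:
        # Hypothetical helper, mirroring the check added in this patch.
        # Without any CUDA device, Flash Attention models cannot be used.
        if not torch.cuda.is_available():
            return False
        # get_device_capability() returns (major, minor), e.g. (8, 0) on an A100.
        major, minor = torch.cuda.get_device_capability()
        # Same ranges as the patch: SM 7.5 (Turing), SM 8.x (Ampere/Ada), SM 9.0 (Hopper).
        return (major, minor) == (7, 5) or major == 8 or (major, minor) == (9, 0)

    if __name__ == "__main__":
        print(f"Flash Attention supported: {flash_attention_supported()}")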