From 7c16352d1ee9d5799dc2513dba5ceff5934bd15b Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Wed, 19 Apr 2023 12:48:04 +0200
Subject: [PATCH] feat(server): check cuda capability when importing flash
 models

---
 server/text_generation_server/models/__init__.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 9c1ea3b0..6b015a4a 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -24,7 +24,18 @@ try:
         FlashSantacoderSharded,
     )
 
-    FLASH_ATTENTION = torch.cuda.is_available()
+    if torch.cuda.is_available():
+        major, minor = torch.cuda.get_device_capability()
+        is_sm75 = major == 7 and minor == 5
+        is_sm8x = major == 8 and minor >= 0
+        is_sm90 = major == 9 and minor == 0
+
+        supported = is_sm75 or is_sm8x or is_sm90
+        if not supported:
+            raise ImportError(f"GPU with CUDA capability {major} {minor} is not supported")
+        FLASH_ATTENTION = supported
+    else:
+        FLASH_ATTENTION = False
 except ImportError:
     logger.opt(exception=True).warning("Could not import Flash Attention enabled models")
     FLASH_ATTENTION = False
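
Note (illustrative, not part of the patch): a minimal standalone sketch of the
same capability gate, assuming only that torch is installed; the helper name
flash_attention_supported is hypothetical and not defined in the repository.

    import torch

    def flash_attention_supported() -> bool:
        # The patch accepts Turing (sm_75), Ampere/Ada (sm_8x) and Hopper
        # (sm_90) GPUs; any other compute capability is treated as unsupported.
        if not torch.cuda.is_available():
            return False
        major, minor = torch.cuda.get_device_capability()
        is_sm75 = major == 7 and minor == 5
        is_sm8x = major == 8 and minor >= 0
        is_sm90 = major == 9 and minor == 0
        return is_sm75 or is_sm8x or is_sm90

    # Example usage: gate the flash model imports on the result instead of
    # raising, mirroring the FLASH_ATTENTION flag set in the patch.
    FLASH_ATTENTION = flash_attention_supported()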