Mirror of https://github.com/huggingface/text-generation-inference.git
feat(server): check cuda capability when importing flash models
commit 7c16352d1e (parent 2475aede61)
@@ -24,7 +24,18 @@ try:
         FlashSantacoderSharded,
     )
 
-    FLASH_ATTENTION = torch.cuda.is_available()
+    if torch.cuda.is_available():
+        major, minor = torch.cuda.get_device_capability()
+        is_sm75 = major == 7 and minor == 5
+        is_sm8x = major == 8 and minor >= 0
+        is_sm90 = major == 9 and minor == 0
+
+        supported = is_sm75 or is_sm8x or is_sm90
+        if not supported:
+            raise ImportError(f"GPU with CUDA capability {major} {minor} is not supported")
+        FLASH_ATTENTION = supported
+    else:
+        FLASH_ATTENTION = False
 except ImportError:
     logger.opt(exception=True).warning("Could not import Flash Attention enabled models")
     FLASH_ATTENTION = False
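To see the added check in isolation, here is a minimal sketch (assuming a PyTorch build with CUDA available) that queries the device's compute capability and applies the same support rule as the patch; the sm_ names in the comments are the conventional shorthand for these capabilities, not identifiers from the patch itself.

import torch

# Minimal sketch: probe the current CUDA device's compute capability and
# apply the same rule as the patch above. Flash Attention models are only
# enabled on sm_75 (Turing), sm_8x (Ampere/Ada), or sm_90 (Hopper).
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    is_sm75 = major == 7 and minor == 5
    is_sm8x = major == 8 and minor >= 0
    is_sm90 = major == 9 and minor == 0
    supported = is_sm75 or is_sm8x or is_sm90
    print(f"Compute capability {major}.{minor}; Flash Attention supported: {supported}")
else:
    print("No CUDA device available; Flash Attention models are disabled.")

Running this on an unsupported GPU (for example a V100, which is sm_70) would report the capability as unsupported, mirroring the ImportError the patched module raises.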