diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index eb5a8de7..2d735227 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -180,7 +180,7 @@ except ImportError as e:
 if MAMBA_AVAILABLE:
     __all__.append(Mamba)
 
-FLASH_TRANSFORMERS_BACKEND = True
+FLASH_TRANSFORMERS_BACKEND = torch.cuda.is_available()
 try:
     from text_generation_server.models.transformers_flash_causal_lm import (
         TransformersFlashCausalLM,
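
For context, a minimal sketch of the module-level logic this hunk produces. The `except` branch below is not part of the hunk and is an assumption: the flag is presumed to be cleared when the flash implementation cannot be imported, matching the `try` block that the hunk's trailing context opens.

```python
import torch

# After this patch the flash transformers backend is only advertised
# when a CUDA device is actually available, instead of unconditionally.
FLASH_TRANSFORMERS_BACKEND = torch.cuda.is_available()

try:
    from text_generation_server.models.transformers_flash_causal_lm import (
        TransformersFlashCausalLM,
    )
except ImportError:
    # Assumed fallback (not shown in the hunk): if the flash
    # implementation is missing, disable the backend so callers
    # fall back to the regular causal LM path.
    FLASH_TRANSFORMERS_BACKEND = False
```

The net effect is that a CPU-only deployment no longer starts with `FLASH_TRANSFORMERS_BACKEND = True` and then fails later; the flag now reflects whether CUDA is present at import time.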