diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 0a211ec3..9b84a125 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -189,11 +189,14 @@ if FLASH_ATTENTION: __all__.append(IdeficsCausalLM) MAMBA_AVAILABLE = True -try: - from text_generation_server.models.mamba import Mamba -except ImportError as e: - log_master(logger.warning, f"Could not import Mamba: {e}") +if SYSTEM == "cpu": MAMBA_AVAILABLE = False +else: + try: + from text_generation_server.models.mamba import Mamba + except ImportError as e: + log_master(logger.warning, f"Could not import Mamba: {e}") + MAMBA_AVAILABLE = False if MAMBA_AVAILABLE: __all__.append(Mamba)