diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 961d66b9..a952f060 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -18,7 +18,6 @@ from text_generation_server.models.galactica import GalacticaSharded
 from text_generation_server.models.santacoder import SantaCoder
 from text_generation_server.models.t5 import T5Sharded
 from text_generation_server.models.gpt_neox import GPTNeoxSharded
-from text_generation_server.models.mamba import Mamba
 from text_generation_server.models.phi import Phi
 
 # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
@@ -77,6 +76,15 @@ if FLASH_ATTENTION:
     __all__.append(FlashMixtral)
     __all__.append(FlashPhi)
 
+MAMBA_AVAILABLE = True
+try:
+    from text_generation_server.models.mamba import Mamba
+except ImportError as e:
+    logger.warning(f"Could not import Mamba: {e}")
+    MAMBA_AVAILABLE = False
+
+if MAMBA_AVAILABLE:
+    __all__.append(Mamba)
 
 def get_model(
     model_id: str,