feat: conditionally include mamba

This commit is contained in:
drbh 2024-02-08 00:34:13 +00:00
parent 2c6ef7c93a
commit b99f784cb3

View File

@ -18,7 +18,6 @@ from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.mamba import Mamba
from text_generation_server.models.phi import Phi
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
@ -77,6 +76,15 @@ if FLASH_ATTENTION:
__all__.append(FlashMixtral)
__all__.append(FlashPhi)
MAMBA_AVAILABLE = True
try:
from text_generation_server.models.mamba import Mamba
except ImportError as e:
logger.warning(f"Could not import Mamba: {e}")
MAMBA_AVAILABLE = False
if MAMBA_AVAILABLE:
__all__.append(Mamba)
def get_model(
model_id: str,