mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: conditionally include mamba
This commit is contained in:
parent
2c6ef7c93a
commit
b99f784cb3
@@ -18,7 +18,6 @@ from text_generation_server.models.galactica import GalacticaSharded
|
|||||||
from text_generation_server.models.santacoder import SantaCoder
|
from text_generation_server.models.santacoder import SantaCoder
|
||||||
from text_generation_server.models.t5 import T5Sharded
|
from text_generation_server.models.t5 import T5Sharded
|
||||||
from text_generation_server.models.gpt_neox import GPTNeoxSharded
|
from text_generation_server.models.gpt_neox import GPTNeoxSharded
|
||||||
from text_generation_server.models.mamba import Mamba
|
|
||||||
from text_generation_server.models.phi import Phi
|
from text_generation_server.models.phi import Phi
|
||||||
|
|
||||||
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
|
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
|
||||||
@@ -77,6 +76,15 @@ if FLASH_ATTENTION:
|
|||||||
__all__.append(FlashMixtral)
|
__all__.append(FlashMixtral)
|
||||||
__all__.append(FlashPhi)
|
__all__.append(FlashPhi)
|
||||||
|
|
||||||
|
MAMBA_AVAILABLE = True
|
||||||
|
try:
|
||||||
|
from text_generation_server.models.mamba import Mamba
|
||||||
|
except ImportError as e:
|
||||||
|
logger.warning(f"Could not import Mamba: {e}")
|
||||||
|
MAMBA_AVAILABLE = False
|
||||||
|
|
||||||
|
if MAMBA_AVAILABLE:
|
||||||
|
__all__.append(Mamba)
|
||||||
|
|
||||||
def get_model(
|
def get_model(
|
||||||
model_id: str,
|
model_id: str,
|
||||||
|
Loading…
Reference in New Issue
Block a user