From cd355d08a908e095e4f4a00ea493e4dc061ba6df Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 25 Sep 2024 03:37:12 +0200 Subject: [PATCH] Fixing mamba by using the transformers version. --- integration-tests/models/test_mamba.py | 2 +- router/src/config.rs | 1 + server/text_generation_server/models/__init__.py | 4 ++-- .../models/custom_modeling/mamba_modeling.py | 10 ++++++++-- 4 files changed, 12 insertions(+), 5 deletions(-) diff --git a/integration-tests/models/test_mamba.py b/integration-tests/models/test_mamba.py index 85ed8fd1..baa19643 100644 --- a/integration-tests/models/test_mamba.py +++ b/integration-tests/models/test_mamba.py @@ -3,7 +3,7 @@ import pytest @pytest.fixture(scope="module") def fused_kernel_mamba_handle(launcher): - with launcher("state-spaces/mamba-130m", num_shard=1) as handle: + with launcher("state-spaces/mamba-130m-hf", num_shard=1) as handle: yield handle diff --git a/router/src/config.rs b/router/src/config.rs index 7139b923..ce066ad0 100644 --- a/router/src/config.rs +++ b/router/src/config.rs @@ -145,6 +145,7 @@ pub enum Config { LlavaNext(LlavaNext), ClipVisionModel(ClipVisionModel), Mistral, + Mamba, Idefics, Mllama, Idefics2(Idefics2), diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index d3015408..6ef4f903 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -226,7 +226,7 @@ class ModelType(enum.Enum): "url": "https://huggingface.co/databricks/dbrx-instruct", } MAMBA = { - "type": "ssm", + "type": "mamba", "name": "Mamba", "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj", } @@ -555,7 +555,7 @@ def get_model( # TODO: fix how we determine model type for Mamba if "ssm_cfg" in config_dict: # *only happens in Mamba case - model_type = "ssm" + model_type = "mamba" else: raise RuntimeError( f"Could not determine model type for {model_id} revision {revision}" diff --git 
a/server/text_generation_server/models/custom_modeling/mamba_modeling.py b/server/text_generation_server/models/custom_modeling/mamba_modeling.py index 293051c2..07284e6a 100644 --- a/server/text_generation_server/models/custom_modeling/mamba_modeling.py +++ b/server/text_generation_server/models/custom_modeling/mamba_modeling.py @@ -196,7 +196,10 @@ class MambaModel(nn.Module): def __init__(self, config, weights): super().__init__() prefix = "backbone" - self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights) + try: + self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embeddings", weights) + except RuntimeError: + self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights) self.blocks = nn.ModuleList( [ ResidualBlock(f"{prefix}.layers.{i}", config, weights, layer_id=i) @@ -206,7 +209,10 @@ class MambaModel(nn.Module): self.norm_f = FastRMSNorm.load( f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon ) - self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights) + try: + self.lm_head = SpeculativeHead.load(config, f"{prefix}.embeddings", weights) + except RuntimeError: + self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights) self.config = config def forward(