diff --git a/integration-tests/models/test_mamba.py b/integration-tests/models/test_mamba.py
index 85ed8fd1f5f..baa1964371f 100644
--- a/integration-tests/models/test_mamba.py
+++ b/integration-tests/models/test_mamba.py
@@ -3,7 +3,7 @@
 
 @pytest.fixture(scope="module")
 def fused_kernel_mamba_handle(launcher):
-    with launcher("state-spaces/mamba-130m", num_shard=1) as handle:
+    with launcher("state-spaces/mamba-130m-hf", num_shard=1) as handle:
         yield handle
 
 
diff --git a/router/src/config.rs b/router/src/config.rs
index 5d0be9c8b32..59ec52c29c2 100644
--- a/router/src/config.rs
+++ b/router/src/config.rs
@@ -145,6 +145,7 @@ pub enum Config {
     LlavaNext(LlavaNext),
     ClipVisionModel(ClipVisionModel),
     Mistral,
+    Mamba,
     Idefics,
     Idefics2(Idefics2),
     Ssm,
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index e5e5aabb2a3..221cbd0b915 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -213,7 +213,7 @@ class ModelType(enum.Enum):
         "url": "https://huggingface.co/databricks/dbrx-instruct",
     }
     MAMBA = {
-        "type": "ssm",
+        "type": "mamba",
         "name": "Mamba",
         "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
     }
@@ -503,7 +503,7 @@ def get_model(
         # TODO: fix how we determine model type for Mamba
         if "ssm_cfg" in config_dict:
             # *only happens in Mamba case
-            model_type = "ssm"
+            model_type = "mamba"
         else:
             raise RuntimeError(
                 f"Could not determine model type for {model_id} revision {revision}"
diff --git a/server/text_generation_server/models/custom_modeling/mamba_modeling.py b/server/text_generation_server/models/custom_modeling/mamba_modeling.py
index 293051c2bf9..07284e6a529 100644
--- a/server/text_generation_server/models/custom_modeling/mamba_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/mamba_modeling.py
@@ -196,7 +196,10 @@ class MambaModel(nn.Module):
     def __init__(self, config, weights):
         super().__init__()
         prefix = "backbone"
-        self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights)
+        try:
+            self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embeddings", weights)
+        except RuntimeError:
+            self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights)
         self.blocks = nn.ModuleList(
             [
                 ResidualBlock(f"{prefix}.layers.{i}", config, weights, layer_id=i)
@@ -206,7 +209,10 @@ def __init__(self, config, weights):
         self.norm_f = FastRMSNorm.load(
             f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon
         )
-        self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights)
+        try:
+            self.lm_head = SpeculativeHead.load(config, f"{prefix}.embeddings", weights)
+        except RuntimeError:
+            self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights)
         self.config = config
 
     def forward(