From 064e040ee30635f6abb25acb2d7587489a80afde Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 18 Dec 2024 14:58:27 +0000 Subject: [PATCH] fix: improve text model loading --- .../text_generation_server/models/custom_modeling/idefics2.py | 2 +- server/text_generation_server/models/custom_modeling/vlm.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py index b1967ec3..6c1d5823 100644 --- a/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/server/text_generation_server/models/custom_modeling/idefics2.py @@ -723,7 +723,7 @@ class Idefics3ForConditionalGeneration(nn.Module): vision_config = config.vision_config self.text_model = load_text_model( - prefix="model" if not prefix else f"{prefix}.model", + prefix=f"{prefix}.model.text_model" if prefix else "model.text_model", config=config.text_config, weights=weights, name="text_model", diff --git a/server/text_generation_server/models/custom_modeling/vlm.py b/server/text_generation_server/models/custom_modeling/vlm.py index 04edd0a4..82e409a6 100644 --- a/server/text_generation_server/models/custom_modeling/vlm.py +++ b/server/text_generation_server/models/custom_modeling/vlm.py @@ -4,7 +4,7 @@ def load_text_model(prefix, config, weights, name=None): FlashLlamaForCausalLM, ) - return FlashLlamaForCausalLM(f"{prefix}.text_model", config, weights) + return FlashLlamaForCausalLM(prefix, config, weights) elif config.model_type == "mistral": from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( FlashMistralForCausalLM,