diff --git a/server/text_generation_server/models/custom_modeling/vlm.py b/server/text_generation_server/models/custom_modeling/vlm.py index e5c44045..82e409a6 100644 --- a/server/text_generation_server/models/custom_modeling/vlm.py +++ b/server/text_generation_server/models/custom_modeling/vlm.py @@ -17,6 +17,12 @@ def load_text_model(prefix, config, weights, name=None): ) return FlashGemmaForCausalLM(prefix, config, weights, causal=False) + elif config.model_type == "gemma2": + from text_generation_server.models.custom_modeling.flash_gemma2_modeling import ( + FlashGemma2ForCausalLM, + ) + + return FlashGemma2ForCausalLM(prefix, config, weights) elif config.model_type == "paligemma": from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( FlashGemmaForCausalLM,