def load_text_model(prefix, config, weights, name=None):
    """Instantiate the text (language-model) backbone selected by ``config.model_type``.

    Each implementation is imported lazily inside its own branch, so only the
    module for the requested architecture is ever loaded.

    Args:
        prefix: Weight-name prefix forwarded unchanged to the model constructor.
        config: Text config object; only its ``model_type`` attribute is used
            for dispatch, and the object itself is passed to the constructor.
        weights: Weight source forwarded unchanged to the model constructor.
        name: Optional name forwarded only to the "llama" and "mistral"
            constructors; the gemma-family branches ignore it.

    Returns:
        The constructed causal-LM module for the matching architecture.

    Raises:
        RuntimeError: if ``config.model_type`` is not one of "llama",
            "mistral", "gemma", "gemma2", "gemma3", "gemma3_text",
            "paligemma".
    """
    model_type = config.model_type

    if model_type == "llama":
        from text_generation_server.models.custom_modeling.flash_llama_modeling import (
            FlashLlamaForCausalLM,
        )

        return FlashLlamaForCausalLM(prefix, config, weights, name=name)

    if model_type == "mistral":
        from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
            FlashMistralForCausalLM,
        )

        return FlashMistralForCausalLM(prefix, config, weights, name=name)

    if model_type == "gemma":
        from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
            FlashGemmaForCausalLM,
        )

        # Plain "gemma" is built with causal=False, whereas the "paligemma"
        # branch below relies on the constructor's default.
        # NOTE(review): confirm this asymmetry is intentional.
        return FlashGemmaForCausalLM(prefix, config, weights, causal=False)

    if model_type == "gemma2":
        from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
            FlashGemma2ForCausalLM,
        )

        return FlashGemma2ForCausalLM(prefix, config, weights)

    if model_type in ("gemma3", "gemma3_text"):
        from text_generation_server.models.custom_modeling.flash_gemma3_modeling import (
            FlashGemma3ForCausalLM,
        )

        return FlashGemma3ForCausalLM(prefix, config, weights)

    if model_type == "paligemma":
        from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
            FlashGemmaForCausalLM,
        )

        return FlashGemmaForCausalLM(prefix, config, weights)

    raise RuntimeError(f"Unsupported model type {model_type}")


def load_vision_model(prefix, config, weights):
    """Instantiate the vision encoder selected by ``config.model_type``.

    Both supported encoders are constructed with the prefix
    ``f"{prefix}.vision_model"``. The implementation module is imported lazily
    inside the matching branch.

    Args:
        prefix: Base weight-name prefix; ".vision_model" is appended to it.
        config: Vision config object; only ``model_type`` is used for dispatch.
        weights: Weight source forwarded unchanged to the constructor.

    Returns:
        A ``CLIPVisionTransformer`` for "clip_vision_model", or a
        ``SiglipVisionTransformer`` for "siglip_vision_model"/"gemma3_vision".

    Raises:
        RuntimeError: for any other ``config.model_type``.
    """
    model_type = config.model_type

    if model_type == "clip_vision_model":
        from text_generation_server.models.custom_modeling.clip import (
            CLIPVisionTransformer,
        )

        return CLIPVisionTransformer(
            prefix=f"{prefix}.vision_model", config=config, weights=weights
        )

    if model_type in ("siglip_vision_model", "gemma3_vision"):
        from text_generation_server.models.custom_modeling.siglip import (
            SiglipVisionTransformer,
        )

        # TODO: this prefix used to be hard-coded as "vision_model.vision_model";
        # ensure deriving it from `prefix` doesn't break existing models that
        # relied on the old value (update those models if necessary).
        return SiglipVisionTransformer(
            prefix=f"{prefix}.vision_model",
            config=config,
            weights=weights,
        )

    raise RuntimeError(f"Unsupported model type {model_type}")