mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 22:02:06 +00:00
55 lines
2.0 KiB
Python
55 lines
2.0 KiB
Python
def load_text_model(prefix, config, weights, name=None):
    """Instantiate the flash text model that matches ``config.model_type``.

    The modeling module for each architecture is imported inside its own
    branch, so only the module for the selected architecture is imported.

    Args:
        prefix: Forwarded unchanged as the first positional argument of the
            selected model class.
        config: Object exposing a ``model_type`` attribute; forwarded to the
            selected model class.
        weights: Forwarded unchanged to the selected model class.
        name: Only forwarded (as ``name=``) when ``model_type`` is
            ``"mistral"``; ignored for every other type.

    Returns:
        An instance of the model class selected by ``config.model_type``.

    Raises:
        RuntimeError: If ``config.model_type`` is not one of ``"llama"``,
            ``"mistral"``, ``"gemma"``, ``"gemma2"`` or ``"paligemma"``.
    """
    model_type = config.model_type

    if model_type == "llama":
        from text_generation_server.models.custom_modeling.flash_llama_modeling import (
            FlashLlamaForCausalLM,
        )

        return FlashLlamaForCausalLM(prefix, config, weights)

    if model_type == "mistral":
        from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
            FlashMistralForCausalLM,
        )

        return FlashMistralForCausalLM(prefix, config, weights, name=name)

    if model_type == "gemma2":
        from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
            FlashGemma2ForCausalLM,
        )

        return FlashGemma2ForCausalLM(prefix, config, weights)

    # "gemma" and "paligemma" share one model class; they differ only in
    # whether ``causal=False`` is passed explicitly.
    if model_type in ("gemma", "paligemma"):
        from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
            FlashGemmaForCausalLM,
        )

        if model_type == "gemma":
            return FlashGemmaForCausalLM(prefix, config, weights, causal=False)
        # paligemma relies on the constructor's own default for ``causal``.
        return FlashGemmaForCausalLM(prefix, config, weights)

    raise RuntimeError(f"Unsupported model type {config.model_type}")
|
|
|
|
|
|
def load_vision_model(prefix, config, weights):
    """Instantiate the vision transformer that matches ``config.model_type``.

    The modeling module for each architecture is imported inside its own
    branch, so only the module for the selected architecture is imported.

    Args:
        prefix: Used to build the CLIP prefix ``f"{prefix}.vision_model"``.
            It is not used by the siglip branch.
        config: Object exposing a ``model_type`` attribute; forwarded as
            ``config=`` to the selected class.
        weights: Forwarded as ``weights=`` to the selected class.

    Returns:
        A ``CLIPVisionTransformer`` for ``"clip_vision_model"`` or a
        ``SiglipVisionTransformer`` for ``"siglip_vision_model"``.

    Raises:
        RuntimeError: If ``config.model_type`` is neither of the two
            supported vision model types.
    """
    vision_type = config.model_type

    if vision_type == "clip_vision_model":
        from text_generation_server.models.custom_modeling.clip import (
            CLIPVisionTransformer,
        )

        clip_prefix = f"{prefix}.vision_model"
        return CLIPVisionTransformer(
            prefix=clip_prefix, config=config, weights=weights
        )

    if vision_type == "siglip_vision_model":
        from text_generation_server.models.custom_modeling.siglip import (
            SiglipVisionTransformer,
        )

        # NOTE(review): the siglip prefix is a fixed string and does not
        # incorporate the caller's ``prefix`` -- confirm this is intended.
        return SiglipVisionTransformer(
            prefix="vision_tower.vision_model", config=config, weights=weights
        )

    raise RuntimeError(f"Unsupported model type {config.model_type}")
|