mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-20 14:22:08 +00:00
* feat: add ruff and resolve issue * fix: update client exports and adjust after rebase * fix: adjust syntax to avoid circular import * fix: adjust client ruff settings * fix: lint and refactor import check and avoid model enum as global names * fix: improve fbgemm_gpu check and lints * fix: update lints * fix: prefer comparing model enum over str * fix: adjust lints and ignore specific rules * fix: avoid unneeded quantize check
49 lines
1.8 KiB
Python
49 lines
1.8 KiB
Python
def load_text_model(prefix, config, weights, name=None):
    """Build the flash causal-LM text model selected by ``config.model_type``.

    Each architecture's modeling module is imported lazily, inside the branch
    that uses it, so only the requested implementation is loaded.

    Args:
        prefix: Passed through as the first positional constructor argument.
        config: Object exposing ``model_type``; also passed to the constructor.
        weights: Passed through to the constructor.
        name: Optional value forwarded as ``name=`` to the Mistral constructor
            only; the other branches do not use it.

    Returns:
        The constructed model instance.

    Raises:
        RuntimeError: If ``config.model_type`` is not ``"llama"``,
            ``"mistral"``, ``"gemma"`` or ``"paligemma"``.
    """
    model_type = config.model_type

    if model_type == "llama":
        from text_generation_server.models.custom_modeling.flash_llama_modeling import (
            FlashLlamaForCausalLM,
        )

        return FlashLlamaForCausalLM(prefix, config, weights)

    if model_type == "mistral":
        from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
            FlashMistralForCausalLM,
        )

        return FlashMistralForCausalLM(prefix, config, weights, name=name)

    if model_type == "gemma" or model_type == "paligemma":
        from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
            FlashGemmaForCausalLM,
        )

        # Both model types share one class; only "gemma" passes causal=False,
        # "paligemma" relies on the constructor's default.
        if model_type == "gemma":
            return FlashGemmaForCausalLM(prefix, config, weights, causal=False)
        return FlashGemmaForCausalLM(prefix, config, weights)

    raise RuntimeError(f"Unsupported model type {config.model_type}")
|
|
|
|
|
|
def load_vision_model(prefix, config, weights):
    """Build the vision transformer selected by ``config.model_type``.

    The modeling module is imported lazily inside the matching branch. The
    branches form a single ``if``/``elif``/``else`` chain so the final
    ``else`` covers every unmatched model type, mirroring ``load_text_model``.

    Args:
        prefix: Weight-name prefix; the CLIP branch passes
            ``f"{prefix}.vision_model"`` to the constructor.
        config: Object exposing ``model_type``; passed to the constructor.
        weights: Passed through to the constructor.

    Returns:
        The constructed vision model instance.

    Raises:
        RuntimeError: If ``config.model_type`` is neither
            ``"clip_vision_model"`` nor ``"siglip_vision_model"``.
    """
    if config.model_type == "clip_vision_model":
        from text_generation_server.models.custom_modeling.clip import (
            CLIPVisionTransformer,
        )

        return CLIPVisionTransformer(
            prefix=f"{prefix}.vision_model", config=config, weights=weights
        )
    elif config.model_type == "siglip_vision_model":
        from text_generation_server.models.custom_modeling.siglip import (
            SiglipVisionTransformer,
        )

        # NOTE(review): this branch does not use the `prefix` argument; it
        # always passes the fixed "vision_tower.vision_model" — confirm with
        # callers that this is intended before changing it.
        return SiglipVisionTransformer(
            prefix="vision_tower.vision_model", config=config, weights=weights
        )
    else:
        raise RuntimeError(f"Unsupported model type {config.model_type}")
|