Mirror of https://github.com/huggingface/text-generation-inference.git
Latest commit:

* feat: expand vlm support and add image token logic and tests
* fix: avoid unused perceiver config
* feat: integrate image tokens into inputs embeds
* feat: add simple idefics3 test
* feat: update docs, image token logic and weight names
* fix: improve image processing
* feat: improve prefix for idefics3
* fix: bump idefics3 tests and snapshots
* fix: improve text model loading
* feat: consolidate changes with existing vlms and add support and test for smolvlm
* fix: create new idefic3 file, simplify logic and adjust llama weight loading
* fix: lint with ruff
* fix: clean up idefics 3 and improve prefix handling
* fix: improve typing
* fix: improve prompt_split_image with ref to original impl
* fix: adjust ruff lints and small refactors
* fix: adjust FlashLlamaModel prefix logic
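Several of the commits above touch the image-token logic (e.g. prompt_split_image). As a rough, hypothetical illustration of what such expansion can look like, the sketch below replaces each image placeholder in a prompt with one token per vision-encoder patch so that text and image embeddings line up position-for-position. The names IMAGE_TOKEN and expand_image_tokens are illustrative, not TGI's actual API.

# Hypothetical sketch of image-token expansion; not TGI's real implementation.
IMAGE_TOKEN = "<image>"


def expand_image_tokens(prompt: str, num_image_tokens: int) -> str:
    # Replace each placeholder with `num_image_tokens` copies, one per patch
    # embedding the vision encoder will produce for that image.
    return prompt.replace(IMAGE_TOKEN, IMAGE_TOKEN * num_image_tokens)


assert (
    expand_image_tokens("Describe <image> briefly.", 3)
    == "Describe <image><image><image> briefly."
)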
def load_text_model(prefix, config, weights, name=None):
    # Dispatch on the text architecture declared in the config and return the
    # matching flash-attention causal LM, with its weights loaded under `prefix`.
    if config.model_type == "llama":
        from text_generation_server.models.custom_modeling.flash_llama_modeling import (
            FlashLlamaForCausalLM,
        )

        return FlashLlamaForCausalLM(prefix, config, weights, name=name)
    elif config.model_type == "mistral":
        from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
            FlashMistralForCausalLM,
        )

        return FlashMistralForCausalLM(prefix, config, weights, name=name)
    elif config.model_type == "gemma":
        from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
            FlashGemmaForCausalLM,
        )

        return FlashGemmaForCausalLM(prefix, config, weights, causal=False)
    elif config.model_type == "gemma2":
        from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
            FlashGemma2ForCausalLM,
        )

        return FlashGemma2ForCausalLM(prefix, config, weights)
    elif config.model_type == "paligemma":
        # PaliGemma's text backbone is the Gemma model.
        from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
            FlashGemmaForCausalLM,
        )

        return FlashGemmaForCausalLM(prefix, config, weights)
    else:
        raise RuntimeError(f"Unsupported model type {config.model_type}")
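A quick way to see the dispatcher's behavior without pulling in any of the flash-attention modules is to exercise the error branch; the SimpleNamespace config below is a stand-in, not a real TGI config object.

from types import SimpleNamespace

# Unknown architectures fall through to the RuntimeError branch; the happy
# paths are not exercised here because they import TGI's CUDA-backed modules.
try:
    load_text_model("language_model", SimpleNamespace(model_type="gpt2"), weights=None)
except RuntimeError as err:
    print(err)  # Unsupported model type gpt2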
def load_vision_model(prefix, config, weights):
    # Dispatch on the vision-encoder architecture declared in the config.
    if config.model_type == "clip_vision_model":
        from text_generation_server.models.custom_modeling.clip import (
            CLIPVisionTransformer,
        )

        return CLIPVisionTransformer(
            prefix=f"{prefix}.vision_model", config=config, weights=weights
        )
    elif config.model_type == "siglip_vision_model":
        from text_generation_server.models.custom_modeling.siglip import (
            SiglipVisionTransformer,
        )

        # Note: the SigLIP tower is loaded under a fixed prefix rather than
        # the caller-supplied `prefix`.
        return SiglipVisionTransformer(
            prefix="vision_tower.vision_model", config=config, weights=weights
        )
    else:
        raise RuntimeError(f"Unsupported model type {config.model_type}")
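For context, a VLM wrapper typically instantiates both loaders side by side and later merges the vision embeddings into the text model's inputs_embeds, as the commit messages describe. The class and prefixes below are a hypothetical sketch, not the actual Idefics3/SmolVLM implementation.

class ToyVlmForConditionalGeneration:
    # Hypothetical composition: real TGI models also wire up a multimodal
    # projector plus the logic that scatters image embeddings into the
    # positions occupied by image placeholder tokens.
    def __init__(self, prefix, config, weights):
        self.vision_model = load_vision_model(
            prefix=f"{prefix}.vision_model",  # illustrative prefix
            config=config.vision_config,
            weights=weights,
        )
        self.text_model = load_text_model(
            prefix=f"{prefix}.text_model",  # illustrative prefix
            config=config.text_config,
            weights=weights,
            name="text_model",
        )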