text-generation-inference/server/text_generation_server/models/custom_modeling/vlm.py
drbh da5ab46705
Improve vlm support (add idefics3 support) (#2437)
* feat: expand vlm support and add image token logic and tests

* fix: avoid unused perceiver config

* feat: integrate image tokens into inputs embeds

* feat: add simple idefics3 test

* feat: update docs, image token logic and weight names

* fix: improve image processing

* feat: improve prefix for idefics3

* fix: bump idefics3 tests and snapshots

* fix: improve text model loading

* feat: consolidate changes with existing vlms and add support and test for smolvlm

* fix: create new idefic3 file, simplify logic and adjust llama weight loading

* fix: lint with ruff

* fix: clean up idefics 3 and improve prefix handling

* fix: improve typing

* fix: improve prompt_split_image with ref to original impl

* fix: adjust ruff lints and small refactors

* fix: adjust FlashLlamaModel prefix logic
2025-01-09 10:35:32 -05:00

55 lines
2.1 KiB
Python

def load_text_model(prefix, config, weights, name=None):
if config.model_type == "llama":
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
return FlashLlamaForCausalLM(prefix, config, weights, name=name)
elif config.model_type == "mistral":
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
)
return FlashMistralForCausalLM(prefix, config, weights, name=name)
elif config.model_type == "gemma":
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
FlashGemmaForCausalLM,
)
return FlashGemmaForCausalLM(prefix, config, weights, causal=False)
elif config.model_type == "gemma2":
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
FlashGemma2ForCausalLM,
)
return FlashGemma2ForCausalLM(prefix, config, weights)
elif config.model_type == "paligemma":
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
FlashGemmaForCausalLM,
)
return FlashGemmaForCausalLM(prefix, config, weights)
else:
raise RuntimeError(f"Unsupported model type {config.model_type}")
def load_vision_model(prefix, config, weights):
if config.model_type == "clip_vision_model":
from text_generation_server.models.custom_modeling.clip import (
CLIPVisionTransformer,
)
return CLIPVisionTransformer(
prefix=f"{prefix}.vision_model", config=config, weights=weights
)
if config.model_type == "siglip_vision_model":
from text_generation_server.models.custom_modeling.siglip import (
SiglipVisionTransformer,
)
return SiglipVisionTransformer(
prefix="vision_tower.vision_model", config=config, weights=weights
)
else:
raise RuntimeError(f"Unsupported model type {config.model_type}")