feat: support loading gemma2 as vlm text model

This commit is contained in:
drbh 2024-12-06 10:46:49 -05:00
parent 5df8059037
commit 0f0fe9a998

View File

@ -17,6 +17,12 @@ def load_text_model(prefix, config, weights, name=None):
) )
return FlashGemmaForCausalLM(prefix, config, weights, causal=False) return FlashGemmaForCausalLM(prefix, config, weights, causal=False)
elif config.model_type == "gemma2":
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
FlashGemma2ForCausalLM,
)
return FlashGemma2ForCausalLM(prefix, config, weights)
elif config.model_type == "paligemma": elif config.model_type == "paligemma":
from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
FlashGemmaForCausalLM, FlashGemmaForCausalLM,