Add comments for support of models

Mohit Sharma 2025-03-21 14:11:15 +00:00
parent 50ffe00a1a
commit b5bac0dd2d
2 changed files with 49 additions and 37 deletions


@@ -208,13 +208,13 @@ try:
)
from text_generation_server.models.transformers_flash_vlm import (
TransformersFlashVlmCausalLM,
TransformersQwen2VlmCausalLM,
TransformersGemma3VlmCausalLM,
)
except ImportError as e:
log_master(logger.warning, f"Could not import Flash Transformers Backend: {e}")
FLASH_TRANSFORMERS_BACKEND = False
# TODO: remove this, it's a temporary for testing the FLASH_TRANSFORMERS_BACKEND
FLASH_ATTENTION = False
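
For context, the try/except in this hunk is an optional-import feature flag: if the transformers-backed VLM classes cannot be imported, FLASH_TRANSFORMERS_BACKEND is cleared and get_model later skips that path, while the temporary FLASH_ATTENTION = False noted in the TODO forces the transformers fallback path during testing. A minimal, hypothetical sketch of that flag pattern (the module name "some_optional_backend" is a placeholder, not a real dependency):

# Hypothetical sketch of the optional-import flag pattern used above.
# "some_optional_backend" is a placeholder module name.
import importlib
import logging

logger = logging.getLogger(__name__)

OPTIONAL_BACKEND = True
try:
    importlib.import_module("some_optional_backend")
except ImportError as e:
    logger.warning("Could not import optional backend: %s", e)
    OPTIONAL_BACKEND = False

print(f"optional backend available: {OPTIONAL_BACKEND}")
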
@@ -1506,18 +1506,21 @@ def get_model(
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
- elif FLASH_TRANSFORMERS_BACKEND:
- from transformers import Qwen2VLForConditionalGeneration as Qwen2VLModel
+ # TODO: Uncomment when transformers is refactored
+ # elif FLASH_TRANSFORMERS_BACKEND:
+ # from transformers import Qwen2VLForConditionalGeneration as Qwen2VLModel
- return TransformersQwen2VlmCausalLM.fallback(
- model_id,
- Qwen2VLModel,
- revision,
- quantize=quantize,
- speculator=speculator,
- dtype=torch.bfloat16,
- trust_remote_code=trust_remote_code,
- )
+ # return TransformersQwen2VlmCausalLM.fallback(
+ # model_id,
+ # Qwen2VLModel,
+ # revision,
+ # quantize=quantize,
+ # speculator=speculator,
+ # dtype=torch.bfloat16,
+ # trust_remote_code=trust_remote_code,
+ # )
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Qwen2_VL"))
if model_type == QWEN2_5_VL:
if FLASH_ATTENTION:
return VlmCausalLM(
@@ -1534,19 +1537,21 @@ def get_model(
config_class=Qwen2_5_VLConfig,
processor_class=Qwen2_5_VLProcessor,
)
- elif FLASH_TRANSFORMERS_BACKEND:
- return TransformersQwen2VlmCausalLM.fallback(
- model_id,
- Qwen2VLModel,
- revision,
- quantize=quantize,
- speculator=speculator,
- dtype=torch.bfloat16,
- trust_remote_code=trust_remote_code,
- config_class=Qwen2_5_VLConfig,
- processor_class=Qwen2_5_VLProcessor,
- )
+ # TODO: Uncomment when transformers is refactored
+ # elif FLASH_TRANSFORMERS_BACKEND:
+ # return TransformersQwen2VlmCausalLM.fallback(
+ # model_id,
+ # Qwen2VLModel,
+ # revision,
+ # quantize=quantize,
+ # speculator=speculator,
+ # dtype=torch.bfloat16,
+ # trust_remote_code=trust_remote_code,
+ # config_class=Qwen2_5_VLConfig,
+ # processor_class=Qwen2_5_VLProcessor,
+ # )
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Qwen2_5_VL"))
if model_type == MLLAMA:
if FLASH_ATTENTION:
return MllamaCausalLM(
@@ -1561,19 +1566,20 @@ def get_model(
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
- elif FLASH_TRANSFORMERS_BACKEND:
- from transformers import MllamaForConditionalGeneration as MllamaModel
+ # TODO: Uncomment when transformers is refactored and cross attn is added
+ # elif FLASH_TRANSFORMERS_BACKEND:
+ # from transformers import MllamaForConditionalGeneration as MllamaModel
- return TransformersFlashVlmCausalLM.fallback(
- model_id,
- MllamaModel,
- revision,
- quantize=quantize,
- speculator=speculator,
- dtype=torch.bfloat16,
- trust_remote_code=trust_remote_code,
- batch_class=MllamaCausalLMBatch,
- )
+ # return TransformersFlashVlmCausalLM.fallback(
+ # model_id,
+ # MllamaModel,
+ # revision,
+ # quantize=quantize,
+ # speculator=speculator,
+ # dtype=torch.bfloat16,
+ # trust_remote_code=trust_remote_code,
+ # batch_class=MllamaCausalLMBatch,
+ # )
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))
if model_type == IDEFICS2:
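
The three hunks above all edit the same dispatch shape in get_model: try the native flash implementation first, optionally fall back to the transformers-backed class, and otherwise raise. With the fallbacks commented out, only the flash path or the error remains until the transformers refactor lands. A minimal, hypothetical sketch of that pattern, with placeholder classes standing in for the real TGI model classes (only the flag names, the .fallback(...) call shape, and the flash -> fallback -> error ordering are taken from the diff):

# Hypothetical sketch of the dispatch pattern edited above; the classes here
# are placeholders, not the real TGI implementations.
FLASH_ATTENTION = True
FLASH_TRANSFORMERS_BACKEND = False  # fallback disabled, as in this commit


class FlashVlmModel:
    def __init__(self, model_id: str):
        self.model_id = model_id


class TransformersFallbackModel:
    @classmethod
    def fallback(cls, model_id: str):
        return cls()


def get_vlm_model(model_id: str):
    if FLASH_ATTENTION:
        return FlashVlmModel(model_id)
    elif FLASH_TRANSFORMERS_BACKEND:
        # To be re-enabled once the transformers refactor lands (see the TODOs above).
        return TransformersFallbackModel.fallback(model_id)
    raise NotImplementedError(f"Flash attention is required for {model_id}")


print(type(get_vlm_model("some/vlm-model")).__name__)  # FlashVlmModel
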


@@ -27,6 +27,12 @@ REPLICATED_ATTENTION_MODELS = [
]
+ # # Qwen2VL
+ # transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[
+ # "tgi"
+ # ] = transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[
+ # "eager"
+ # ]
def tgi_flash_attention_forward(
module,
query_states: torch.Tensor,
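
The commented-out block at the top of this second file follows a registry pattern: add a "tgi" key to an attention-classes mapping (here transformers' QWEN2_VL_VISION_ATTENTION_CLASSES) that aliases the existing "eager" entry, so model code can select the backend by name; tgi_flash_attention_forward below is presumably the kernel meant to be exposed that way once enabled. A minimal, self-contained sketch of the registry idea, using a toy dict rather than the real transformers mapping:

# Toy stand-in for an attention registry such as QWEN2_VL_VISION_ATTENTION_CLASSES:
# registering a "tgi" key lets callers pick a custom kernel by name.
from typing import Callable, Dict

import torch

ATTENTION_FUNCTIONS: Dict[str, Callable[..., torch.Tensor]] = {}


def eager_attention_forward(query, key, value):
    # Plain scaled dot-product attention, the baseline the "tgi" key aliases here.
    scale = query.shape[-1] ** -0.5
    weights = torch.softmax((query @ key.transpose(-2, -1)) * scale, dim=-1)
    return weights @ value


ATTENTION_FUNCTIONS["eager"] = eager_attention_forward
# Mirrors the commented-out registration above: "tgi" reuses the "eager" entry
# until a dedicated flash implementation is wired in.
ATTENTION_FUNCTIONS["tgi"] = ATTENTION_FUNCTIONS["eager"]

q = k = v = torch.randn(1, 4, 8, 16)  # (batch, heads, seq_len, head_dim)
print(ATTENTION_FUNCTIONS["tgi"](q, k, v).shape)  # torch.Size([1, 4, 8, 16])
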