Add comments for support of models

Mohit Sharma 2025-03-21 14:11:15 +00:00
parent 50ffe00a1a
commit b5bac0dd2d
2 changed files with 49 additions and 37 deletions

View File

@@ -208,13 +208,13 @@ try:
         )
         from text_generation_server.models.transformers_flash_vlm import (
             TransformersFlashVlmCausalLM,
-            TransformersQwen2VlmCausalLM,
             TransformersGemma3VlmCausalLM,
         )
 except ImportError as e:
     log_master(logger.warning, f"Could not import Flash Transformers Backend: {e}")
     FLASH_TRANSFORMERS_BACKEND = False
+# TODO: remove this, it's a temporary for testing the FLASH_TRANSFORMERS_BACKEND
 FLASH_ATTENTION = False
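
Context: the two flags touched in this hunk drive backend selection in get_model later in this file. Forcing FLASH_ATTENTION = False pushes every model type through its elif FLASH_TRANSFORMERS_BACKEND branch, which is why the commit flags it as a temporary testing override. A minimal sketch of that dispatch shape (dispatch_example, its string results and the placeholder error message are illustrative, not code from this repository):

def dispatch_example(
    model_type: str, flash_attention: bool, flash_transformers_backend: bool
) -> str:
    # Preferred path: TGI's native flash-attention implementation of the model.
    if flash_attention:
        return f"native flash implementation for {model_type}"
    # Fallback path: load the model through the Transformers flash backend,
    # e.g. TransformersFlashVlmCausalLM.fallback(...) in the hunks below.
    elif flash_transformers_backend:
        return f"transformers-backend fallback for {model_type}"
    # Neither backend is available: the model type is unsupported.
    raise NotImplementedError(f"{model_type} requires flash attention")  # placeholder message

# With FLASH_ATTENTION forced to False, the first branch never fires, so every
# supported model exercises the transformers fallback:
print(dispatch_example("Qwen2_VL", flash_attention=False, flash_transformers_backend=True))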
@@ -1506,18 +1506,21 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
-        elif FLASH_TRANSFORMERS_BACKEND:
-            from transformers import Qwen2VLForConditionalGeneration as Qwen2VLModel
-            return TransformersQwen2VlmCausalLM.fallback(
-                model_id,
-                Qwen2VLModel,
-                revision,
-                quantize=quantize,
-                speculator=speculator,
-                dtype=torch.bfloat16,
-                trust_remote_code=trust_remote_code,
-            )
+        # TODO: Uncomment when transformers is refactored
+        # elif FLASH_TRANSFORMERS_BACKEND:
+        #     from transformers import Qwen2VLForConditionalGeneration as Qwen2VLModel
+        #     return TransformersQwen2VlmCausalLM.fallback(
+        #         model_id,
+        #         Qwen2VLModel,
+        #         revision,
+        #         quantize=quantize,
+        #         speculator=speculator,
+        #         dtype=torch.bfloat16,
+        #         trust_remote_code=trust_remote_code,
+        #     )
+        else:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Qwen2_VL"))
     if model_type == QWEN2_5_VL:
         if FLASH_ATTENTION:
             return VlmCausalLM(
@@ -1534,19 +1537,21 @@ def get_model(
                 config_class=Qwen2_5_VLConfig,
                 processor_class=Qwen2_5_VLProcessor,
             )
-        elif FLASH_TRANSFORMERS_BACKEND:
-            return TransformersQwen2VlmCausalLM.fallback(
-                model_id,
-                Qwen2VLModel,
-                revision,
-                quantize=quantize,
-                speculator=speculator,
-                dtype=torch.bfloat16,
-                trust_remote_code=trust_remote_code,
-                config_class=Qwen2_5_VLConfig,
-                processor_class=Qwen2_5_VLProcessor,
-            )
+        # TODO: Uncomment when transformers is refactored
+        # elif FLASH_TRANSFORMERS_BACKEND:
+        #     return TransformersQwen2VlmCausalLM.fallback(
+        #         model_id,
+        #         Qwen2VLModel,
+        #         revision,
+        #         quantize=quantize,
+        #         speculator=speculator,
+        #         dtype=torch.bfloat16,
+        #         trust_remote_code=trust_remote_code,
+        #         config_class=Qwen2_5_VLConfig,
+        #         processor_class=Qwen2_5_VLProcessor,
+        #     )
+        else:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Qwen2_5_VL"))
     if model_type == MLLAMA:
         if FLASH_ATTENTION:
             return MllamaCausalLM(
@@ -1561,19 +1566,20 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
-        elif FLASH_TRANSFORMERS_BACKEND:
-            from transformers import MllamaForConditionalGeneration as MllamaModel
-            return TransformersFlashVlmCausalLM.fallback(
-                model_id,
-                MllamaModel,
-                revision,
-                quantize=quantize,
-                speculator=speculator,
-                dtype=torch.bfloat16,
-                trust_remote_code=trust_remote_code,
-                batch_class=MllamaCausalLMBatch,
-            )
+        # TODO: Uncomment when transformers is refactored and cross attn is added
+        # elif FLASH_TRANSFORMERS_BACKEND:
+        #     from transformers import MllamaForConditionalGeneration as MllamaModel
+        #     return TransformersFlashVlmCausalLM.fallback(
+        #         model_id,
+        #         MllamaModel,
+        #         revision,
+        #         quantize=quantize,
+        #         speculator=speculator,
+        #         dtype=torch.bfloat16,
+        #         trust_remote_code=trust_remote_code,
+        #         batch_class=MllamaCausalLMBatch,
+        #     )
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))
     if model_type == IDEFICS2:
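
For reference, re-enabling the Qwen2-VL path once the upstream refactor lands would essentially restore the block commented out in the hunks above. The sketch below is a hedged restatement of that commented-out call, wrapped in a hypothetical helper (build_qwen2_vl_fallback is not a function in this repository; the keyword arguments mirror the call sites in this diff):

import torch
from text_generation_server.models.transformers_flash_vlm import (
    TransformersQwen2VlmCausalLM,
)

def build_qwen2_vl_fallback(model_id, revision, quantize, speculator, trust_remote_code):
    # Import the upstream model class lazily, as the original branch did.
    from transformers import Qwen2VLForConditionalGeneration as Qwen2VLModel

    return TransformersQwen2VlmCausalLM.fallback(
        model_id,
        Qwen2VLModel,
        revision,
        quantize=quantize,
        speculator=speculator,
        dtype=torch.bfloat16,
        trust_remote_code=trust_remote_code,
    )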

View File

@@ -27,6 +27,12 @@ REPLICATED_ATTENTION_MODELS = [
 ]
+# # Qwen2VL
+# transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[
+#     "tgi"
+# ] = transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[
+#     "eager"
+# ]
 
 
 def tgi_flash_attention_forward(
     module,
     query_states: torch.Tensor,
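
The commented-out block added above relates to how this file wires attention implementations: the text model's attention is registered under the name "tgi" and routed to tgi_flash_attention_forward, while the Qwen2-VL vision tower would alias "tgi" back to the stock eager attention class so that only the language model runs through the TGI kernel. Below is a minimal sketch of that registration pattern; it assumes the installed transformers version still exposes ALL_ATTENTION_FUNCTIONS and the QWEN2_VL_VISION_ATTENTION_CLASSES dict (the TODO comments in this commit suggest the latter is pending an upstream refactor), and tgi_flash_attention_forward_stub merely stands in for the real function defined in this file:

import transformers
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


def tgi_flash_attention_forward_stub(module, query_states, key_states, value_states, *args, **kwargs):
    # Stand-in for the real tgi_flash_attention_forward defined earlier in this file.
    raise NotImplementedError


# Text models: route the "tgi" implementation name to TGI's flash kernel, so a
# model loaded with attn_implementation="tgi" calls into it.
ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward_stub

# Qwen2-VL vision tower: keep eager attention by aliasing "tgi" to the existing
# eager class, which is what the commented-out lines would do once re-enabled.
qwen2_vl = transformers.models.qwen2_vl.modeling_qwen2_vl
qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES["tgi"] = (
    qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES["eager"]
)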