diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 49281f0a..6e50b7f8 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -208,13 +208,13 @@ try:
     )
     from text_generation_server.models.transformers_flash_vlm import (
         TransformersFlashVlmCausalLM,
-        TransformersQwen2VlmCausalLM,
         TransformersGemma3VlmCausalLM,
     )
 except ImportError as e:
     log_master(logger.warning, f"Could not import Flash Transformers Backend: {e}")
     FLASH_TRANSFORMERS_BACKEND = False
 
 
+# TODO: remove this, it's a temporary change for testing the FLASH_TRANSFORMERS_BACKEND
 FLASH_ATTENTION = False
@@ -1506,18 +1506,21 @@ def get_model(
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,
         )
-        elif FLASH_TRANSFORMERS_BACKEND:
-            from transformers import Qwen2VLForConditionalGeneration as Qwen2VLModel
+        # TODO: Uncomment when transformers is refactored
+        # elif FLASH_TRANSFORMERS_BACKEND:
+        #     from transformers import Qwen2VLForConditionalGeneration as Qwen2VLModel
 
-            return TransformersQwen2VlmCausalLM.fallback(
-                model_id,
-                Qwen2VLModel,
-                revision,
-                quantize=quantize,
-                speculator=speculator,
-                dtype=torch.bfloat16,
-                trust_remote_code=trust_remote_code,
-            )
+        #     return TransformersQwen2VlmCausalLM.fallback(
+        #         model_id,
+        #         Qwen2VLModel,
+        #         revision,
+        #         quantize=quantize,
+        #         speculator=speculator,
+        #         dtype=torch.bfloat16,
+        #         trust_remote_code=trust_remote_code,
+        #     )
+        else:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Qwen2_VL"))
     if model_type == QWEN2_5_VL:
         if FLASH_ATTENTION:
             return VlmCausalLM(
@@ -1534,19 +1537,21 @@ def get_model(
                 config_class=Qwen2_5_VLConfig,
                 processor_class=Qwen2_5_VLProcessor,
             )
-        elif FLASH_TRANSFORMERS_BACKEND:
-
-            return TransformersQwen2VlmCausalLM.fallback(
-                model_id,
-                Qwen2VLModel,
-                revision,
-                quantize=quantize,
-                speculator=speculator,
-                dtype=torch.bfloat16,
-                trust_remote_code=trust_remote_code,
-                config_class=Qwen2_5_VLConfig,
-                processor_class=Qwen2_5_VLProcessor,
-            )
+        # TODO: Uncomment when transformers is refactored
+        # elif FLASH_TRANSFORMERS_BACKEND:
+        #     return TransformersQwen2VlmCausalLM.fallback(
+        #         model_id,
+        #         Qwen2VLModel,
+        #         revision,
+        #         quantize=quantize,
+        #         speculator=speculator,
+        #         dtype=torch.bfloat16,
+        #         trust_remote_code=trust_remote_code,
+        #         config_class=Qwen2_5_VLConfig,
+        #         processor_class=Qwen2_5_VLProcessor,
+        #     )
+        else:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Qwen2_5_VL"))
     if model_type == MLLAMA:
         if FLASH_ATTENTION:
             return MllamaCausalLM(
@@ -1561,19 +1566,20 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
-        elif FLASH_TRANSFORMERS_BACKEND:
-            from transformers import MllamaForConditionalGeneration as MllamaModel
+        # TODO: Uncomment when transformers is refactored and cross attn is added
+        # elif FLASH_TRANSFORMERS_BACKEND:
+        #     from transformers import MllamaForConditionalGeneration as MllamaModel
 
-            return TransformersFlashVlmCausalLM.fallback(
-                model_id,
-                MllamaModel,
-                revision,
-                quantize=quantize,
-                speculator=speculator,
-                dtype=torch.bfloat16,
-                trust_remote_code=trust_remote_code,
-                batch_class=MllamaCausalLMBatch,
-            )
+        #     return TransformersFlashVlmCausalLM.fallback(
+        #         model_id,
+        #         MllamaModel,
+        #         revision,
+        #         quantize=quantize,
+        #         speculator=speculator,
+        #         dtype=torch.bfloat16,
+        #         trust_remote_code=trust_remote_code,
+        #         batch_class=MllamaCausalLMBatch,
+        #     )
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))
     if model_type == IDEFICS2:
diff --git a/server/text_generation_server/models/transformers_flash_vlm.py b/server/text_generation_server/models/transformers_flash_vlm.py
index 5056a88e..ccb2cb76 100644
--- a/server/text_generation_server/models/transformers_flash_vlm.py
+++ b/server/text_generation_server/models/transformers_flash_vlm.py
@@ -27,6 +27,12 @@ REPLICATED_ATTENTION_MODELS = [
 ]
 
 
+# # Qwen2VL
+# transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[
+#     "tgi"
+# ] = transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[
+#     "eager"
+# ]
 def tgi_flash_attention_forward(
     module,
     query_states: torch.Tensor,