From 50ffe00a1a88d778438eafa5cc29c71de142a32f Mon Sep 17 00:00:00 2001
From: Mohit Sharma
Date: Fri, 21 Mar 2025 13:44:40 +0000
Subject: [PATCH] Improve attn_implementation

---
 .../models/transformers_flash_vlm.py          | 54 +++----------------
 1 file changed, 6 insertions(+), 48 deletions(-)

diff --git a/server/text_generation_server/models/transformers_flash_vlm.py b/server/text_generation_server/models/transformers_flash_vlm.py
index aea2b8a81..5056a88ea 100644
--- a/server/text_generation_server/models/transformers_flash_vlm.py
+++ b/server/text_generation_server/models/transformers_flash_vlm.py
@@ -58,8 +58,7 @@ def tgi_flash_attention_forward(
     _, num_heads, head_dim = query_states.shape
     softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
     sliding_window = -1 if sliding_window is None else sliding_window
-    # if module.layer_idx == 0:
-    #     from pdb import set_trace; set_trace()
+
     if cu_seqlen_prefill is not None:
         if not use_sdpa:
             attn_output = attention(
@@ -136,51 +135,6 @@ def tgi_flash_attention_forward(
 
 
 transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward
-# Siglip
-transformers.models.siglip.modeling_siglip.SIGLIP_ATTENTION_CLASSES["tgi"] = (
-    transformers.models.siglip.modeling_siglip.SIGLIP_ATTENTION_CLASSES["sdpa"]
-)
-
-# Qwen2VL
-transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[
-    "tgi"
-] = transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[
-    "eager"
-]
-# This needs to be patched in transformers to use ALL_ATTENTION_FUNCTIONS
-# transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_ATTENTION_CLASSES["tgi"] = tgi_flash_attention_forward
-
-# Idefics2
-transformers.models.idefics2.modeling_idefics2.IDEFICS_VISION_ATTENTION_CLASSES[
-    "tgi"
-] = transformers.models.idefics2.modeling_idefics2.IDEFICS_VISION_ATTENTION_CLASSES[
-    "eager"
-]
-transformers.models.idefics2.modeling_idefics2.IDEFICS2_PERCEIVER_ATTENTION_CLASSES[
-    "tgi"
-] = transformers.models.idefics2.modeling_idefics2.IDEFICS2_PERCEIVER_ATTENTION_CLASSES[
-    "eager"
-]
-
-# Idefics3
-transformers.models.idefics3.modeling_idefics3.IDEFICS_VISION_ATTENTION_CLASSES[
-    "tgi"
-] = transformers.models.idefics3.modeling_idefics3.IDEFICS_VISION_ATTENTION_CLASSES[
-    "eager"
-]
-
-# Clip
-transformers.models.clip.modeling_clip.CLIP_ATTENTION_CLASSES["tgi"] = (
-    transformers.models.clip.modeling_clip.CLIP_ATTENTION_CLASSES["sdpa"]
-)
-
-# Mllama
-transformers.models.mllama.modeling_mllama.MLLAMA_VISION_ATTENTION_CLASSES["tgi"] = (
-    transformers.models.mllama.modeling_mllama.MLLAMA_VISION_ATTENTION_CLASSES["eager"]
-)
-# This needs to be patched in transformers to use ALL_ATTENTION_FUNCTIONS
-# transformers.models.mllama.modeling_mllama.MLLAMA_TEXT_ATTENTION_CLASSES["tgi"] = tgi_flash_attention_forward
-# transformers.models.mllama.modeling_mllama.MLLAMA_CROSS_ATTENTION_CLASSES["tgi"] = tgi_cross_attention_forward
 
 # TODO: implement
 # tgi_cross_attention_forward
@@ -240,13 +194,17 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
             **processor_kwargs,
         )
 
+        attn_implementation = {
+            "text_config": "tgi",
+        }
+
         model = model_class.from_pretrained(
             model_id,
             revision=revision,
             torch_dtype=dtype,
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
-            attn_implementation="tgi",
+            attn_implementation=attn_implementation,
             device_map=device if world_size == 1 else None,
             tp_plan="auto" if world_size > 1 else None,
         )
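
Note (not part of the patch): the change relies on transformers accepting attn_implementation as a dict keyed by
sub-config name, so the custom "tgi" kernel registered in ALL_ATTENTION_FUNCTIONS is applied to the text tower only
while vision towers keep their stock implementation; that is what makes the per-model *_ATTENTION_CLASSES
monkey-patching above removable. Below is a minimal sketch of that mechanism, assuming a transformers version that
supports the dict form; the attention function and the checkpoint id are placeholders for illustration, not TGI code.

# sketch.py -- illustration only; toy_attention_forward and the model id are placeholders.
import torch
import transformers
from transformers import AutoModelForImageTextToText


def toy_attention_forward(module, query, key, value, attention_mask=None, **kwargs):
    # Stand-in kernel: defer to PyTorch SDPA. TGI's real "tgi" entry
    # (tgi_flash_attention_forward) dispatches to its flash-attention kernels instead.
    out = torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask
    )
    # Return (attn_output, attn_weights) with output as (batch, seq, heads, head_dim).
    return out.transpose(1, 2).contiguous(), None


# Register the kernel under the name later used as an attn_implementation value,
# exactly as the patched module does for tgi_flash_attention_forward.
transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = toy_attention_forward

# Dict form: only the text sub-config uses "tgi"; vision components keep their
# default attention, so no per-model ATTENTION_CLASSES patching is required.
model = AutoModelForImageTextToText.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",  # example checkpoint, assumed to expose a text_config
    torch_dtype=torch.bfloat16,
    attn_implementation={"text_config": "tgi"},
)

In TGI itself the registered function additionally expects the flash-attention metadata (cu_seqlen_prefill and
friends) that its model classes inject, which the placeholder above does not model.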