Improve attn_implementation

Mohit Sharma 2025-03-21 13:44:40 +00:00
parent b41faae318
commit 50ffe00a1a


@@ -58,8 +58,7 @@ def tgi_flash_attention_forward(
     _, num_heads, head_dim = query_states.shape
     softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
     sliding_window = -1 if sliding_window is None else sliding_window
-    # if module.layer_idx == 0:
-    #     from pdb import set_trace; set_trace()
     if cu_seqlen_prefill is not None:
         if not use_sdpa:
             attn_output = attention(
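
For reference, the defaulting logic kept by this hunk can be read in isolation. The helper below is a hypothetical sketch (not part of the commit) that reproduces the same softmax-scale and sliding-window defaults used by tgi_flash_attention_forward:

import math
from typing import Optional, Tuple

def resolve_attention_defaults(
    head_dim: int,
    softmax_scale: Optional[float] = None,
    sliding_window: Optional[int] = None,
) -> Tuple[float, int]:
    # Same defaults as above: scale attention scores by 1/sqrt(head_dim)
    # unless a scale is supplied, and use -1 to signal "no sliding window".
    softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
    sliding_window = -1 if sliding_window is None else sliding_window
    return softmax_scale, sliding_window

# e.g. resolve_attention_defaults(128) -> (0.088388..., -1)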
@@ -136,51 +135,6 @@ def tgi_flash_attention_forward(
 transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward
-# Siglip
-transformers.models.siglip.modeling_siglip.SIGLIP_ATTENTION_CLASSES["tgi"] = (
-    transformers.models.siglip.modeling_siglip.SIGLIP_ATTENTION_CLASSES["sdpa"]
-)
-# Qwen2VL
-transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[
-    "tgi"
-] = transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[
-    "eager"
-]
-# This needs to be patched in transformers to use ALL_ATTENTION_FUNCTIONS
-# transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_ATTENTION_CLASSES["tgi"] = tgi_flash_attention_forward
-# Idefics2
-transformers.models.idefics2.modeling_idefics2.IDEFICS_VISION_ATTENTION_CLASSES[
-    "tgi"
-] = transformers.models.idefics2.modeling_idefics2.IDEFICS_VISION_ATTENTION_CLASSES[
-    "eager"
-]
-transformers.models.idefics2.modeling_idefics2.IDEFICS2_PERCEIVER_ATTENTION_CLASSES[
-    "tgi"
-] = transformers.models.idefics2.modeling_idefics2.IDEFICS2_PERCEIVER_ATTENTION_CLASSES[
-    "eager"
-]
-# Idefics3
-transformers.models.idefics3.modeling_idefics3.IDEFICS_VISION_ATTENTION_CLASSES[
-    "tgi"
-] = transformers.models.idefics3.modeling_idefics3.IDEFICS_VISION_ATTENTION_CLASSES[
-    "eager"
-]
-# Clip
-transformers.models.clip.modeling_clip.CLIP_ATTENTION_CLASSES["tgi"] = (
-    transformers.models.clip.modeling_clip.CLIP_ATTENTION_CLASSES["sdpa"]
-)
-# Mllama
-transformers.models.mllama.modeling_mllama.MLLAMA_VISION_ATTENTION_CLASSES["tgi"] = (
-    transformers.models.mllama.modeling_mllama.MLLAMA_VISION_ATTENTION_CLASSES["eager"]
-)
-# This needs to be patched in transformers to use ALL_ATTENTION_FUNCTIONS
-# transformers.models.mllama.modeling_mllama.MLLAMA_TEXT_ATTENTION_CLASSES["tgi"] = tgi_flash_attention_forward
-# transformers.models.mllama.modeling_mllama.MLLAMA_CROSS_ATTENTION_CLASSES["tgi"] = tgi_cross_attention_forward
 # TODO: implement
 # tgi_cross_attention_forward
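
The removed per-model ATTENTION_CLASSES patches are superseded by the single ALL_ATTENTION_FUNCTIONS registration kept above. As a rough sketch of how such a registration is written and consumed (the function name and the eager math below are illustrative, and the exact keyword arguments each model passes are assumed, not taken from this commit):

import torch
import transformers

def my_attention_forward(module, query, key, value, attention_mask=None, **kwargs):
    # Illustrative eager attention; TGI's real backend dispatches to
    # flash-attention kernels instead. Shapes assumed to be
    # (batch, heads, seq, head_dim), with no grouped-query handling.
    scaling = kwargs.get("scaling") or query.shape[-1] ** -0.5
    scores = torch.matmul(query, key.transpose(-1, -2)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask
    weights = torch.softmax(scores, dim=-1)
    output = torch.matmul(weights, value).transpose(1, 2).contiguous()
    return output, weights

# Registering under a custom name makes it selectable via attn_implementation.
transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["my_backend"] = my_attention_forward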
@@ -240,13 +194,17 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
             **processor_kwargs,
         )
 
+        attn_implementation = {
+            "text_config": "tgi",
+        }
+
         model = model_class.from_pretrained(
             model_id,
             revision=revision,
             torch_dtype=dtype,
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
-            attn_implementation="tgi",
+            attn_implementation=attn_implementation,
             device_map=device if world_size == 1 else None,
             tp_plan="auto" if world_size > 1 else None,
         )
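
The dict form of attn_implementation introduced here restricts the custom backend to the text sub-model; vision sub-configs keep their default attention. A minimal usage sketch, assuming the "tgi" backend has already been registered as in the earlier hunk (model class and checkpoint are illustrative, not from this commit):

import torch
from transformers import Qwen2VLForConditionalGeneration

attn_implementation = {
    "text_config": "tgi",  # backend registered in ALL_ATTENTION_FUNCTIONS above
    # sub-configs not listed here (e.g. the vision tower) keep their default
}

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    torch_dtype=torch.float16,
    attn_implementation=attn_implementation,
    device_map="cuda:0",
)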