Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-21 06:42:10 +00:00)
Improve attn_implementation
parent b41faae318
commit 50ffe00a1a
@@ -58,8 +58,7 @@ def tgi_flash_attention_forward(
     _, num_heads, head_dim = query_states.shape
     softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
     sliding_window = -1 if sliding_window is None else sliding_window
-    # if module.layer_idx == 0:
-    #     from pdb import set_trace; set_trace()
+
     if cu_seqlen_prefill is not None:
         if not use_sdpa:
             attn_output = attention(
@@ -136,51 +135,6 @@ def tgi_flash_attention_forward(
 
 transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward
 
-# Siglip
-transformers.models.siglip.modeling_siglip.SIGLIP_ATTENTION_CLASSES["tgi"] = (
-    transformers.models.siglip.modeling_siglip.SIGLIP_ATTENTION_CLASSES["sdpa"]
-)
-
-# Qwen2VL
-transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[
-    "tgi"
-] = transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[
-    "eager"
-]
-# This needs to be patched in transformers to use ALL_ATTENTION_FUNCTIONS
-# transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_ATTENTION_CLASSES["tgi"] = tgi_flash_attention_forward
-
-# Idefics2
-transformers.models.idefics2.modeling_idefics2.IDEFICS_VISION_ATTENTION_CLASSES[
-    "tgi"
-] = transformers.models.idefics2.modeling_idefics2.IDEFICS_VISION_ATTENTION_CLASSES[
-    "eager"
-]
-transformers.models.idefics2.modeling_idefics2.IDEFICS2_PERCEIVER_ATTENTION_CLASSES[
-    "tgi"
-] = transformers.models.idefics2.modeling_idefics2.IDEFICS2_PERCEIVER_ATTENTION_CLASSES[
-    "eager"
-]
-
-# Idefics3
-transformers.models.idefics3.modeling_idefics3.IDEFICS_VISION_ATTENTION_CLASSES[
-    "tgi"
-] = transformers.models.idefics3.modeling_idefics3.IDEFICS_VISION_ATTENTION_CLASSES[
-    "eager"
-]
-
-# Clip
-transformers.models.clip.modeling_clip.CLIP_ATTENTION_CLASSES["tgi"] = (
-    transformers.models.clip.modeling_clip.CLIP_ATTENTION_CLASSES["sdpa"]
-)
-
-# Mllama
-transformers.models.mllama.modeling_mllama.MLLAMA_VISION_ATTENTION_CLASSES["tgi"] = (
-    transformers.models.mllama.modeling_mllama.MLLAMA_VISION_ATTENTION_CLASSES["eager"]
-)
-# This needs to be patched in transformers to use ALL_ATTENTION_FUNCTIONS
-# transformers.models.mllama.modeling_mllama.MLLAMA_TEXT_ATTENTION_CLASSES["tgi"] = tgi_flash_attention_forward
-# transformers.models.mllama.modeling_mllama.MLLAMA_CROSS_ATTENTION_CLASSES["tgi"] = tgi_cross_attention_forward
-
 # TODO: implement
 # tgi_cross_attention_forward
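The block removed above patched each vision model's per-class attention registry (Siglip, Qwen2VL, Idefics2/3, Clip, Mllama) by hand; after this change only the single registration of tgi_flash_attention_forward in transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS remains, and the implementation is selected by name at load time. The following is a minimal sketch of how such a registration works in recent transformers versions; my_eager_attention, its keyword arguments, and the "my_eager" key are illustrative assumptions, not TGI's actual kernel.

import math
from typing import Optional

import torch
import transformers


def my_eager_attention(
    module,                                   # the attention module calling into this backend
    query: torch.Tensor,                      # [batch, num_heads, q_len, head_dim]
    key: torch.Tensor,                        # [batch, num_heads, kv_len, head_dim] (GQA repeat omitted)
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],   # additive mask, or None
    scaling: Optional[float] = None,
    **kwargs,                                 # absorbs any extra args the modeling code passes
):
    # Default scaling mirrors the 1/sqrt(head_dim) fallback used in tgi_flash_attention_forward above.
    if scaling is None:
        scaling = 1 / math.sqrt(query.shape[-1])
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = attn_weights.softmax(dim=-1)
    attn_output = torch.matmul(attn_weights, value)
    # transformers' attention interface expects (output, weights), with the output
    # laid out as [batch, q_len, num_heads, head_dim].
    return attn_output.transpose(1, 2).contiguous(), attn_weights


# Register under a custom key; a model loaded with attn_implementation="my_eager"
# would then route attention through the function above, with no per-model
# *_ATTENTION_CLASSES patching needed.
transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["my_eager"] = my_eager_attention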
@@ -240,13 +194,17 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
             **processor_kwargs,
         )
 
+        attn_implementation = {
+            "text_config": "tgi",
+        }
+
         model = model_class.from_pretrained(
             model_id,
             revision=revision,
             torch_dtype=dtype,
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
-            attn_implementation="tgi",
+            attn_implementation=attn_implementation,
             device_map=device if world_size == 1 else None,
             tp_plan="auto" if world_size > 1 else None,
         )
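Passing the plain string "tgi" forced every sub-model (text backbone and vision towers alike) through the same attention function, which is what the removed per-model patching compensated for. The new code passes a dict keyed by sub-config name, so only the text stack is routed through the registered "tgi" backend. A short usage sketch, assuming transformers accepts a per-sub-config dict here (as this change relies on); the model class and checkpoint id are hypothetical placeholders:

from transformers import AutoModelForImageTextToText

# Only the language-model sub-config uses the registered "tgi" attention function;
# sub-configs not listed would keep the library default for that tower (assumption).
attn_implementation = {
    "text_config": "tgi",
}

model = AutoModelForImageTextToText.from_pretrained(
    "org/some-vlm-checkpoint",          # hypothetical checkpoint id
    attn_implementation=attn_implementation,
)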