Improve attn_implementation
parent b41faae318
commit 50ffe00a1a
@@ -58,8 +58,7 @@ def tgi_flash_attention_forward(
     _, num_heads, head_dim = query_states.shape
     softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
     sliding_window = -1 if sliding_window is None else sliding_window
-    # if module.layer_idx == 0:
-    # from pdb import set_trace; set_trace()
     if cu_seqlen_prefill is not None:
         if not use_sdpa:
             attn_output = attention(
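For readers skimming the hunk above, the defaulting logic can be exercised on its own. Below is a minimal, hypothetical sketch (the function name, tensor shapes, and the plain softmax math are illustrative only; TGI's real forward dispatches to flash-attention kernels rather than computing attention this way):

import math
import torch

def toy_attention(query, key, value, softmax_scale=None, sliding_window=None):
    # query/key/value: (num_tokens, num_heads, head_dim), matching the shape
    # unpacked in the hunk above.
    _, num_heads, head_dim = query.shape
    # Default the scale to 1/sqrt(head_dim) when the caller does not pass one.
    softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
    # Kernels conventionally take -1 to mean "no sliding window".
    sliding_window = -1 if sliding_window is None else sliding_window
    scores = torch.einsum("qhd,khd->hqk", query, key) * softmax_scale
    probs = scores.softmax(dim=-1)
    return torch.einsum("hqk,khd->qhd", probs, value)

out = toy_attention(torch.randn(4, 2, 8), torch.randn(4, 2, 8), torch.randn(4, 2, 8))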
@@ -136,51 +135,6 @@ def tgi_flash_attention_forward(
 
 transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward
 
-# Siglip
-transformers.models.siglip.modeling_siglip.SIGLIP_ATTENTION_CLASSES["tgi"] = (
-    transformers.models.siglip.modeling_siglip.SIGLIP_ATTENTION_CLASSES["sdpa"]
-)
-
-# Qwen2VL
-transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[
-    "tgi"
-] = transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[
-    "eager"
-]
-# This needs to be patched in transformers to use ALL_ATTENTION_FUNCTIONS
-# transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_ATTENTION_CLASSES["tgi"] = tgi_flash_attention_forward
-
-# Idefics2
-transformers.models.idefics2.modeling_idefics2.IDEFICS_VISION_ATTENTION_CLASSES[
-    "tgi"
-] = transformers.models.idefics2.modeling_idefics2.IDEFICS_VISION_ATTENTION_CLASSES[
-    "eager"
-]
-transformers.models.idefics2.modeling_idefics2.IDEFICS2_PERCEIVER_ATTENTION_CLASSES[
-    "tgi"
-] = transformers.models.idefics2.modeling_idefics2.IDEFICS2_PERCEIVER_ATTENTION_CLASSES[
-    "eager"
-]
-
-# Idefics3
-transformers.models.idefics3.modeling_idefics3.IDEFICS_VISION_ATTENTION_CLASSES[
-    "tgi"
-] = transformers.models.idefics3.modeling_idefics3.IDEFICS_VISION_ATTENTION_CLASSES[
-    "eager"
-]
-
-# Clip
-transformers.models.clip.modeling_clip.CLIP_ATTENTION_CLASSES["tgi"] = (
-    transformers.models.clip.modeling_clip.CLIP_ATTENTION_CLASSES["sdpa"]
-)
-
-# Mllama
-transformers.models.mllama.modeling_mllama.MLLAMA_VISION_ATTENTION_CLASSES["tgi"] = (
-    transformers.models.mllama.modeling_mllama.MLLAMA_VISION_ATTENTION_CLASSES["eager"]
-)
-# This needs to be patched in transformers to use ALL_ATTENTION_FUNCTIONS
-# transformers.models.mllama.modeling_mllama.MLLAMA_TEXT_ATTENTION_CLASSES["tgi"] = tgi_flash_attention_forward
-# transformers.models.mllama.modeling_mllama.MLLAMA_CROSS_ATTENTION_CLASSES["tgi"] = tgi_cross_attention_forward
-
 # TODO: implement
 # tgi_cross_attention_forward
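The hunk above removes the per-model *_ATTENTION_CLASSES monkeypatches and keeps only the single ALL_ATTENTION_FUNCTIONS registration at the top. As a hedged sketch of that registration pattern (the function body and the "my_attn" key are illustrative, not part of this commit):

import transformers

def my_attention_forward(module, query, key, value, attention_mask=None, **kwargs):
    # Custom kernel call would go here; left unimplemented in this sketch.
    raise NotImplementedError

# One global registration replaces the per-model *_ATTENTION_CLASSES patches
# removed above: models routed through transformers' shared attention
# interface can then be loaded with attn_implementation="my_attn".
transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["my_attn"] = my_attention_forward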
@@ -240,13 +194,17 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
             **processor_kwargs,
         )
 
+        attn_implementation = {
+            "text_config": "tgi",
+        }
+
         model = model_class.from_pretrained(
             model_id,
             revision=revision,
             torch_dtype=dtype,
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
-            attn_implementation="tgi",
+            attn_implementation=attn_implementation,
             device_map=device if world_size == 1 else None,
             tp_plan="auto" if world_size > 1 else None,
         )
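The change above relies on transformers accepting a dict for attn_implementation, keyed by sub-config name, so composite VLMs can use the registered "tgi" function for the language tower while other submodules keep a stock backend. A hypothetical usage sketch (the checkpoint name, auto class, and the "vision_config" entry are illustrative, not taken from this commit, and "tgi" assumes the registration shown earlier):

from transformers import AutoModelForImageTextToText

attn_implementation = {
    "text_config": "tgi",     # language tower uses the registered custom kernel
    "vision_config": "sdpa",  # vision tower stays on a stock backend
}
model = AutoModelForImageTextToText.from_pretrained(
    "some/vlm-checkpoint",
    attn_implementation=attn_implementation,
)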