Multi modality fix (#3283)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi 2025-08-19 15:36:36 +08:00 committed by GitHub
parent 6a2fa83540
commit 5284b5c654
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 107 additions and 61 deletions

View File

@ -710,34 +710,41 @@ class MllamaTextCrossAttention(nn.Module):
# )
if SYSTEM == "ipex":
attn_output = torch.empty_like(query_states)
ipex.llm.functional.varlen_attention(
(
query_states.contiguous()
if query_states.device.type == "xpu"
else query_states
),
(
key_states.contiguous()
if key_states.device.type == "xpu"
else key_states
),
(
value_states.contiguous()
if value_states.device.type == "xpu"
else value_states
),
attn_output,
cu_seqlen_q,
cu_seqlen_k,
max_q,
max_k,
0.0,
self.softmax_scale,
False,
causal,
False,
None,
)
if query_states.device.type == "xpu":
ipex.llm.functional.varlen_attention(
query_states.contiguous(),
key_states.contiguous(),
value_states.contiguous(),
attn_output,
cu_seqlen_q,
cu_seqlen_k,
None,
max_q,
max_k,
0.0,
self.softmax_scale,
False,
causal,
False,
None,
)
else:
ipex.llm.functional.varlen_attention(
query_states,
key_states,
value_states,
attn_output,
cu_seqlen_q,
cu_seqlen_k,
max_q,
max_k,
0.0,
self.softmax_scale,
False,
causal,
False,
None,
)
else:
attn_output = flash_attn_2_cuda.varlen_fwd(
query_states,

View File

@ -460,22 +460,41 @@ class Qwen2_5VLAttention(nn.Module):
# execute flash attention
if SYSTEM == "ipex":
attn_output = torch.empty_like(query)
ipex.llm.functional.varlen_attention(
(query.contiguous() if query.device.type == "xpu" else query),
(key.contiguous() if key.device.type == "xpu" else key),
(value.contiguous() if value.device.type == "xpu" else value),
attn_output,
cu_seqlens,
cu_seqlens,
max_seqlen,
max_seqlen,
0.0,
self.softmax_scale,
False,
causal,
False,
None,
)
if query.device.type == "xpu":
ipex.llm.functional.varlen_attention(
query.contiguous(),
key.contiguous(),
value.contiguous(),
attn_output,
cu_seqlens,
cu_seqlens,
None,
max_seqlen,
max_seqlen,
0.0,
self.softmax_scale,
False,
causal,
False,
None,
)
else:
ipex.llm.functional.varlen_attention(
query,
key,
value,
attn_output,
cu_seqlens,
cu_seqlens,
max_seqlen,
max_seqlen,
0.0,
self.softmax_scale,
False,
causal,
False,
None,
)
else:
attn_output = flash_attn_2_cuda.varlen_fwd(
query,

View File

@ -130,22 +130,41 @@ class Qwen2VLAttention(nn.Module):
# execute flash attention
if SYSTEM == "ipex":
attn_output = torch.empty_like(query)
ipex.llm.functional.varlen_attention(
(query.contiguous() if query.device.type == "xpu" else query),
(key.contiguous() if key.device.type == "xpu" else key),
(value.contiguous() if value.device.type == "xpu" else value),
attn_output,
cu_seqlens,
cu_seqlens,
max_seqlen,
max_seqlen,
0.0,
self.softmax_scale,
False,
causal,
False,
None,
)
if query.device.type == "xpu":
ipex.llm.functional.varlen_attention(
query.contiguous(),
key.contiguous(),
value.contiguous(),
attn_output,
cu_seqlens,
cu_seqlens,
None,
max_seqlen,
max_seqlen,
0.0,
self.softmax_scale,
False,
causal,
False,
None,
)
else:
ipex.llm.functional.varlen_attention(
query,
key,
value,
attn_output,
cu_seqlens,
cu_seqlens,
max_seqlen,
max_seqlen,
0.0,
self.softmax_scale,
False,
causal,
False,
None,
)
else:
attn_output = flash_attn_2_cuda.varlen_fwd(
query,

View File

@ -59,7 +59,7 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]):
assert self.image_indices is not None
batch = super().filter(request_ids)
batch = super(VlmCausalLMBatch, self).filter(request_ids)
assert self.image_indices is not None
indices = []
for i, request_id in enumerate(request_ids):
@ -85,6 +85,7 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
]
else:
batch.cross_attention_states = None
batch.pixel_values = None
return batch
@classmethod