Multi modality fix (#3283)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi 2025-08-19 15:36:36 +08:00 committed by GitHub
parent 6a2fa83540
commit 5284b5c654
4 changed files with 107 additions and 61 deletions


@@ -710,34 +710,41 @@ class MllamaTextCrossAttention(nn.Module):
         # )
         if SYSTEM == "ipex":
             attn_output = torch.empty_like(query_states)
-            ipex.llm.functional.varlen_attention(
-                (
-                    query_states.contiguous()
-                    if query_states.device.type == "xpu"
-                    else query_states
-                ),
-                (
-                    key_states.contiguous()
-                    if key_states.device.type == "xpu"
-                    else key_states
-                ),
-                (
-                    value_states.contiguous()
-                    if value_states.device.type == "xpu"
-                    else value_states
-                ),
-                attn_output,
-                cu_seqlen_q,
-                cu_seqlen_k,
-                max_q,
-                max_k,
-                0.0,
-                self.softmax_scale,
-                False,
-                causal,
-                False,
-                None,
-            )
+            if query_states.device.type == "xpu":
+                ipex.llm.functional.varlen_attention(
+                    query_states.contiguous(),
+                    key_states.contiguous(),
+                    value_states.contiguous(),
+                    attn_output,
+                    cu_seqlen_q,
+                    cu_seqlen_k,
+                    None,
+                    max_q,
+                    max_k,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
+            else:
+                ipex.llm.functional.varlen_attention(
+                    query_states,
+                    key_states,
+                    value_states,
+                    attn_output,
+                    cu_seqlen_q,
+                    cu_seqlen_k,
+                    max_q,
+                    max_k,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
         else:
             attn_output = flash_attn_2_cuda.varlen_fwd(
                 query_states,
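The same restructuring is repeated in the Qwen2.5-VL and Qwen2-VL vision attention modules below: instead of wrapping each tensor in an inline `.contiguous()` conditional, the call site now branches on `device.type == "xpu"`, and the XPU branch passes contiguous tensors plus one extra positional `None` argument ahead of the max sequence lengths. A minimal sketch of that dispatch pattern, factored into a hypothetical `ipex_varlen_attention` helper; the argument order is read off this diff rather than IPEX documentation, so treat it as an assumption:

    import intel_extension_for_pytorch as ipex  # assumed available when SYSTEM == "ipex"

    def ipex_varlen_attention(
        query, key, value, attn_output,
        cu_seqlen_q, cu_seqlen_k, max_q, max_k,
        softmax_scale, causal,
    ):
        # Hypothetical helper mirroring the branch added in this commit:
        # the XPU path wants contiguous inputs and takes one extra
        # positional argument (None) before max_q / max_k.
        if query.device.type == "xpu":
            ipex.llm.functional.varlen_attention(
                query.contiguous(), key.contiguous(), value.contiguous(),
                attn_output, cu_seqlen_q, cu_seqlen_k, None,
                max_q, max_k, 0.0, softmax_scale,
                False, causal, False, None,
            )
        else:
            ipex.llm.functional.varlen_attention(
                query, key, value,
                attn_output, cu_seqlen_q, cu_seqlen_k,
                max_q, max_k, 0.0, softmax_scale,
                False, causal, False, None,
            )
        return attn_output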


@@ -460,22 +460,41 @@ class Qwen2_5VLAttention(nn.Module):
         # execute flash attention
         if SYSTEM == "ipex":
             attn_output = torch.empty_like(query)
-            ipex.llm.functional.varlen_attention(
-                (query.contiguous() if query.device.type == "xpu" else query),
-                (key.contiguous() if key.device.type == "xpu" else key),
-                (value.contiguous() if value.device.type == "xpu" else value),
-                attn_output,
-                cu_seqlens,
-                cu_seqlens,
-                max_seqlen,
-                max_seqlen,
-                0.0,
-                self.softmax_scale,
-                False,
-                causal,
-                False,
-                None,
-            )
+            if query.device.type == "xpu":
+                ipex.llm.functional.varlen_attention(
+                    query.contiguous(),
+                    key.contiguous(),
+                    value.contiguous(),
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    None,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
+            else:
+                ipex.llm.functional.varlen_attention(
+                    query,
+                    key,
+                    value,
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
         else:
             attn_output = flash_attn_2_cuda.varlen_fwd(
                 query,


@@ -130,22 +130,41 @@ class Qwen2VLAttention(nn.Module):
         # execute flash attention
         if SYSTEM == "ipex":
             attn_output = torch.empty_like(query)
-            ipex.llm.functional.varlen_attention(
-                (query.contiguous() if query.device.type == "xpu" else query),
-                (key.contiguous() if key.device.type == "xpu" else key),
-                (value.contiguous() if value.device.type == "xpu" else value),
-                attn_output,
-                cu_seqlens,
-                cu_seqlens,
-                max_seqlen,
-                max_seqlen,
-                0.0,
-                self.softmax_scale,
-                False,
-                causal,
-                False,
-                None,
-            )
+            if query.device.type == "xpu":
+                ipex.llm.functional.varlen_attention(
+                    query.contiguous(),
+                    key.contiguous(),
+                    value.contiguous(),
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    None,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
+            else:
+                ipex.llm.functional.varlen_attention(
+                    query,
+                    key,
+                    value,
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
         else:
             attn_output = flash_attn_2_cuda.varlen_fwd(
                 query,
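One detail worth noting about the unconditional `.contiguous()` calls in the XPU branches above: PyTorch's `Tensor.contiguous()` returns the tensor itself when it is already contiguous, so the rewrite only copies data for inputs that are genuinely non-contiguous (e.g. transposed views). A quick plain-PyTorch check:

    import torch

    x = torch.randn(4, 8)
    assert x.contiguous() is x          # already contiguous: same object, no copy

    y = x.t()                           # a transposed view is not contiguous
    assert not y.is_contiguous()
    z = y.contiguous()                  # here a real copy is made
    assert z.is_contiguous() and z is not y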


@@ -59,7 +59,7 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]):
         assert self.image_indices is not None
-        batch = super().filter(request_ids)
+        batch = super(VlmCausalLMBatch, self).filter(request_ids)
         assert self.image_indices is not None
         indices = []
         for i, request_id in enumerate(request_ids):
@@ -85,6 +85,7 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
             ]
         else:
             batch.cross_attention_states = None
+        batch.pixel_values = None
         return batch
 
     @classmethod
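On the `filter` change above: `super(VlmCausalLMBatch, self)` starts method resolution after `VlmCausalLMBatch`, so the VLM-specific `filter` override is skipped and the next implementation in the MRO runs instead. A self-contained sketch of that mechanism, using stripped-down stand-in classes rather than the real TGI batch hierarchy:

    class CausalLMBatch:
        def filter(self, request_ids):
            return f"CausalLMBatch.filter({request_ids})"

    class VlmCausalLMBatch(CausalLMBatch):
        def filter(self, request_ids):
            return "VlmCausalLMBatch.filter (vlm-specific handling)"

    class MllamaCausalLMBatch(VlmCausalLMBatch):
        def filter(self, request_ids):
            # super() would dispatch to VlmCausalLMBatch.filter;
            # super(VlmCausalLMBatch, self) looks up the method *after*
            # VlmCausalLMBatch in the MRO, so CausalLMBatch.filter runs.
            return super(VlmCausalLMBatch, self).filter(request_ids)

    print(MllamaCausalLMBatch().filter([1, 2]))  # -> CausalLMBatch.filter([1, 2])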