xpu adds an alibi_scope input to varlen_attention in ipex 2.7 while cpu does not, so split the cases.

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi A 2025-06-29 22:15:12 -07:00
parent 9f38d93051
commit d05a5c3f0a
3 changed files with 105 additions and 60 deletions
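Background for the split, as a minimal sketch rather than part of this commit: on XPU, ipex.llm.functional.varlen_attention in IPEX 2.7 takes an extra positional argument (the alibi-related input the commit title calls alibi_scope, passed as None throughout this diff) between cu_seqlen_k and max_q, and the diff also keeps the .contiguous() calls on the XPU path, while the CPU build keeps the shorter signature. Because the same device check is repeated in three models below, one could factor it into a helper; the helper name ipex_varlen_attention is hypothetical, and its signature is assumed from the calls in this diff, not taken from IPEX documentation.

# Minimal sketch, assuming the argument order shown in this diff.
# `ipex_varlen_attention` is a hypothetical helper, not an upstream API.
import intel_extension_for_pytorch as ipex


def ipex_varlen_attention(
    query, key, value, out, cu_seqlen_q, cu_seqlen_k, max_q, max_k,
    softmax_scale, causal,
):
    if query.device.type == "xpu":
        # XPU path (IPEX 2.7): contiguous tensors plus the extra alibi slot
        # (None here) right after cu_seqlen_k.
        ipex.llm.functional.varlen_attention(
            query.contiguous(), key.contiguous(), value.contiguous(), out,
            cu_seqlen_q, cu_seqlen_k, None, max_q, max_k,
            0.0, softmax_scale, False, causal, False, None,
        )
    else:
        # CPU path: original signature without the alibi argument.
        ipex.llm.functional.varlen_attention(
            query, key, value, out, cu_seqlen_q, cu_seqlen_k, max_q, max_k,
            0.0, softmax_scale, False, causal, False, None,
        )
    return out

Each call site in this commit could then reduce to a single call such as ipex_varlen_attention(query_states, key_states, value_states, attn_output, cu_seqlen_q, cu_seqlen_k, max_q, max_k, self.softmax_scale, causal), leaving the CUDA flash_attn_2_cuda.varlen_fwd branch untouched.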

@@ -710,34 +710,41 @@ class MllamaTextCrossAttention(nn.Module):
         # )
         if SYSTEM == "ipex":
             attn_output = torch.empty_like(query_states)
-            ipex.llm.functional.varlen_attention(
-                (
-                    query_states.contiguous()
-                    if query_states.device.type == "xpu"
-                    else query_states
-                ),
-                (
-                    key_states.contiguous()
-                    if key_states.device.type == "xpu"
-                    else key_states
-                ),
-                (
-                    value_states.contiguous()
-                    if value_states.device.type == "xpu"
-                    else value_states
-                ),
-                attn_output,
-                cu_seqlen_q,
-                cu_seqlen_k,
-                max_q,
-                max_k,
-                0.0,
-                self.softmax_scale,
-                False,
-                causal,
-                False,
-                None,
-            )
+            if query_states.device.type == "xpu":
+                ipex.llm.functional.varlen_attention(
+                    query_states.contiguous(),
+                    key_states.contiguous(),
+                    value_states.contiguous(),
+                    attn_output,
+                    cu_seqlen_q,
+                    cu_seqlen_k,
+                    None,
+                    max_q,
+                    max_k,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
+            else:
+                ipex.llm.functional.varlen_attention(
+                    query_states,
+                    key_states,
+                    value_states,
+                    attn_output,
+                    cu_seqlen_q,
+                    cu_seqlen_k,
+                    max_q,
+                    max_k,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
         else:
             attn_output = flash_attn_2_cuda.varlen_fwd(
                 query_states,

@@ -460,22 +460,41 @@ class Qwen2_5VLAttention(nn.Module):
         # execute flash attention
         if SYSTEM == "ipex":
             attn_output = torch.empty_like(query)
-            ipex.llm.functional.varlen_attention(
-                (query.contiguous() if query.device.type == "xpu" else query),
-                (key.contiguous() if key.device.type == "xpu" else key),
-                (value.contiguous() if value.device.type == "xpu" else value),
-                attn_output,
-                cu_seqlens,
-                cu_seqlens,
-                max_seqlen,
-                max_seqlen,
-                0.0,
-                self.softmax_scale,
-                False,
-                causal,
-                False,
-                None,
-            )
+            if query.device.type == "xpu":
+                ipex.llm.functional.varlen_attention(
+                    query.contiguous(),
+                    key.contiguous(),
+                    value.contiguous(),
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    None,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
+            else:
+                ipex.llm.functional.varlen_attention(
+                    query,
+                    key,
+                    value,
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
         else:
             attn_output = flash_attn_2_cuda.varlen_fwd(
                 query,

@@ -130,22 +130,41 @@ class Qwen2VLAttention(nn.Module):
         # execute flash attention
         if SYSTEM == "ipex":
             attn_output = torch.empty_like(query)
-            ipex.llm.functional.varlen_attention(
-                (query.contiguous() if query.device.type == "xpu" else query),
-                (key.contiguous() if key.device.type == "xpu" else key),
-                (value.contiguous() if value.device.type == "xpu" else value),
-                attn_output,
-                cu_seqlens,
-                cu_seqlens,
-                max_seqlen,
-                max_seqlen,
-                0.0,
-                self.softmax_scale,
-                False,
-                causal,
-                False,
-                None,
-            )
+            if query.device.type == "xpu":
+                ipex.llm.functional.varlen_attention(
+                    query.contiguous(),
+                    key.contiguous(),
+                    value.contiguous(),
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    None,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
+            else:
+                ipex.llm.functional.varlen_attention(
+                    query,
+                    key,
+                    value,
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
         else:
             attn_output = flash_attn_2_cuda.varlen_fwd(
                 query,