Mirror of https://github.com/huggingface/text-generation-inference.git
XPU adds an alibi_scope input to varlen_attention in IPEX 2.7 while CPU does not, so split the two cases.
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent 9f38d93051
commit d05a5c3f0a
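In short: in IPEX 2.7 the XPU build of ipex.llm.functional.varlen_attention takes an extra alibi_scope positional argument (passed as None here) after the cu_seqlen inputs, while the CPU build keeps the older signature, so each call site now dispatches on the device type. Below is a minimal sketch of that dispatch; the argument order follows the diff, but the helper name and the comments labelling each positional argument are assumptions for illustration, not part of this commit.

import intel_extension_for_pytorch as ipex
import torch


def ipex_varlen_attention(query, key, value, attn_output,
                          cu_seqlen_q, cu_seqlen_k, max_q, max_k,
                          softmax_scale, causal):
    # Hypothetical helper illustrating the split; TGI inlines this per model.
    if query.device.type == "xpu":
        # XPU path (IPEX 2.7): extra alibi_scope slot after the cu_seqlen inputs.
        ipex.llm.functional.varlen_attention(
            query.contiguous(),
            key.contiguous(),
            value.contiguous(),
            attn_output,
            cu_seqlen_q,
            cu_seqlen_k,
            None,              # alibi_scope, new in IPEX 2.7 on XPU
            max_q,
            max_k,
            0.0,               # dropout probability (assumed meaning)
            softmax_scale,
            False,             # zero_tensors (assumed meaning)
            causal,
            False,             # return_softmax (assumed meaning)
            None,              # generator (assumed meaning)
        )
    else:
        # CPU path: no alibi_scope argument.
        ipex.llm.functional.varlen_attention(
            query,
            key,
            value,
            attn_output,
            cu_seqlen_q,
            cu_seqlen_k,
            max_q,
            max_k,
            0.0,
            softmax_scale,
            False,
            causal,
            False,
            None,
        )
    return attn_output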
@@ -710,22 +710,29 @@ class MllamaTextCrossAttention(nn.Module):
         # )
         if SYSTEM == "ipex":
             attn_output = torch.empty_like(query_states)
-            ipex.llm.functional.varlen_attention(
-                (
-                    query_states.contiguous()
-                    if query_states.device.type == "xpu"
-                    else query_states
-                ),
-                (
-                    key_states.contiguous()
-                    if key_states.device.type == "xpu"
-                    else key_states
-                ),
-                (
-                    value_states.contiguous()
-                    if value_states.device.type == "xpu"
-                    else value_states
-                ),
+            if query_states.device.type == "xpu":
+                ipex.llm.functional.varlen_attention(
+                    query_states.contiguous(),
+                    key_states.contiguous(),
+                    value_states.contiguous(),
+                    attn_output,
+                    cu_seqlen_q,
+                    cu_seqlen_k,
+                    None,
+                    max_q,
+                    max_k,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
+            else:
+                ipex.llm.functional.varlen_attention(
+                    query_states,
+                    key_states,
+                    value_states,
                 attn_output,
                 cu_seqlen_q,
                 cu_seqlen_k,
@@ -460,10 +460,29 @@ class Qwen2_5VLAttention(nn.Module):
         # execute flash attention
         if SYSTEM == "ipex":
             attn_output = torch.empty_like(query)
-            ipex.llm.functional.varlen_attention(
-                (query.contiguous() if query.device.type == "xpu" else query),
-                (key.contiguous() if key.device.type == "xpu" else key),
-                (value.contiguous() if value.device.type == "xpu" else value),
+            if query.device.type == "xpu":
+                ipex.llm.functional.varlen_attention(
+                    query.contiguous(),
+                    key.contiguous(),
+                    value.contiguous(),
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    None,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
+            else:
+                ipex.llm.functional.varlen_attention(
+                    query,
+                    key,
+                    value,
                 attn_output,
                 cu_seqlens,
                 cu_seqlens,
@@ -130,10 +130,29 @@ class Qwen2VLAttention(nn.Module):
         # execute flash attention
         if SYSTEM == "ipex":
             attn_output = torch.empty_like(query)
-            ipex.llm.functional.varlen_attention(
-                (query.contiguous() if query.device.type == "xpu" else query),
-                (key.contiguous() if key.device.type == "xpu" else key),
-                (value.contiguous() if value.device.type == "xpu" else value),
+            if query.device.type == "xpu":
+                ipex.llm.functional.varlen_attention(
+                    query.contiguous(),
+                    key.contiguous(),
+                    value.contiguous(),
+                    attn_output,
+                    cu_seqlens,
+                    cu_seqlens,
+                    None,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    self.softmax_scale,
+                    False,
+                    causal,
+                    False,
+                    None,
+                )
+            else:
+                ipex.llm.functional.varlen_attention(
+                    query,
+                    key,
+                    value,
                 attn_output,
                 cu_seqlens,
                 cu_seqlens,