fix review comments

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi A 2024-01-30 00:32:28 -08:00 committed by Nicolas Patry
parent 49cd0ce943
commit bc069db165
2 changed files with 9 additions and 2 deletions

View File

@ -86,6 +86,10 @@ def attention(
raise ValueError("`window_size_left` must be > 0 or -1")
if IS_XPU_SYSTEM:
if window_size_left != -1:
raise ValueError(
f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
)
return torch.xpu.varlen_fwd(
q,
k,

View File

@ -999,8 +999,8 @@ try:
# Inplace operation, updating query and key.
pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
elif IS_XPU_SYSTEM:
sin = sin.repeat(1, 1, 2).expand(query.shape)
cos = cos.repeat(1, 1, 2).expand(query.shape)
sin = sin.expand(query.shape)
cos = cos.expand(query.shape)
torch.ops.torch_ipex.apply_rotary_embedding_half_qk(query, key, sin, cos, query, key)
else:
raise ValueError(
@ -1122,6 +1122,9 @@ try:
cos = torch.index_select(self._cos_cached, 0, position_ids)
sin = torch.index_select(self._sin_cached, 0, position_ids)
if IS_XPU_SYSTEM:
return cos.unsqueeze(1).repeat(1, 1, 2), sin.unsqueeze(1).repeat(1, 1, 2)
# Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
return cos.unsqueeze(1), sin.unsqueeze(1)