mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
fix review comments
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
49cd0ce943
commit
bc069db165
@ -86,6 +86,10 @@ def attention(
|
||||
raise ValueError("`window_size_left` must be > 0 or -1")
|
||||
|
||||
if IS_XPU_SYSTEM:
|
||||
if window_size_left != -1:
|
||||
raise ValueError(
|
||||
f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
|
||||
)
|
||||
return torch.xpu.varlen_fwd(
|
||||
q,
|
||||
k,
|
||||
|
@ -999,8 +999,8 @@ try:
|
||||
# Inplace operation, updating query and key.
|
||||
pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
|
||||
elif IS_XPU_SYSTEM:
|
||||
sin = sin.repeat(1, 1, 2).expand(query.shape)
|
||||
cos = cos.repeat(1, 1, 2).expand(query.shape)
|
||||
sin = sin.expand(query.shape)
|
||||
cos = cos.expand(query.shape)
|
||||
torch.ops.torch_ipex.apply_rotary_embedding_half_qk(query, key, sin, cos, query, key)
|
||||
else:
|
||||
raise ValueError(
|
||||
@ -1122,6 +1122,9 @@ try:
|
||||
|
||||
cos = torch.index_select(self._cos_cached, 0, position_ids)
|
||||
sin = torch.index_select(self._sin_cached, 0, position_ids)
|
||||
|
||||
if IS_XPU_SYSTEM:
|
||||
return cos.unsqueeze(1).repeat(1, 1, 2), sin.unsqueeze(1).repeat(1, 1, 2)
|
||||
# Note: this unsqueeze is not necessary with the ROCm + vLLM RoPE implementation, but we leave it as is to avoid yet another control-flow branch.
|
||||
return cos.unsqueeze(1), sin.unsqueeze(1)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user