diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py
index df192cc1..a8c8f75e 100644
--- a/server/text_generation_server/utils/flash_attn.py
+++ b/server/text_generation_server/utils/flash_attn.py
@@ -86,6 +86,10 @@ def attention(
         raise ValueError("`window_size_left` must be > 0 or -1")
 
     if IS_XPU_SYSTEM:
+        if window_size_left != -1:
+            raise ValueError(
+                f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
+            )
         return torch.xpu.varlen_fwd(
             q,
             k,
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 04ae7724..c66a8d2c 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -999,8 +999,8 @@ try:
                 # Inplace operation, updating query and key.
                 pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
             elif IS_XPU_SYSTEM:
-                sin = sin.repeat(1, 1, 2).expand(query.shape)
-                cos = cos.repeat(1, 1, 2).expand(query.shape)
+                sin = sin.expand(query.shape)
+                cos = cos.expand(query.shape)
                 torch.ops.torch_ipex.apply_rotary_embedding_half_qk(query, key, sin, cos, query, key)
             else:
                 raise ValueError(
@@ -1122,6 +1122,9 @@ try:
 
             cos = torch.index_select(self._cos_cached, 0, position_ids)
             sin = torch.index_select(self._sin_cached, 0, position_ids)
+
+            if IS_XPU_SYSTEM:
+                return cos.unsqueeze(1).repeat(1, 1, 2), sin.unsqueeze(1).repeat(1, 1, 2)
             # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
             return cos.unsqueeze(1), sin.unsqueeze(1)
 
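For reference, a minimal standalone sketch (not part of the patch) of why the `layers.py` change is behavior-preserving on XPU: the `repeat(1, 1, 2)` that previously ran in `forward` now runs once in `get_cos_sin`, so the cos/sin tensors handed to `torch.ops.torch_ipex.apply_rotary_embedding_half_qk` keep the same shape and values. The shapes below (`seq_len`, `num_heads`, `head_size`) are illustrative assumptions, not values from the repository.

```python
import torch

# Illustrative shapes (assumptions, not taken from the repository).
seq_len, num_heads, head_size = 4, 8, 64

# The cos/sin caches hold half the head size (rotary acts on pairs of dims).
cos = torch.randn(seq_len, head_size // 2)
sin = torch.randn(seq_len, head_size // 2)
query_shape = (seq_len, num_heads, head_size)

# Old XPU path: get_cos_sin returned sin.unsqueeze(1); forward then did
# repeat(1, 1, 2) followed by expand to the query shape.
old = sin.unsqueeze(1)                          # get_cos_sin (before)
old = old.repeat(1, 1, 2).expand(query_shape)   # forward (before)

# New XPU path: get_cos_sin returns sin.unsqueeze(1).repeat(1, 1, 2);
# forward only expands.
new = sin.unsqueeze(1).repeat(1, 1, 2)          # get_cos_sin (after, XPU branch)
new = new.expand(query_shape)                   # forward (after)

assert torch.equal(old, new)
```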