diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index f38f6859..eb0bdd4f 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -577,14 +577,20 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding): cos: torch.Tensor, sin: torch.Tensor, ): - # rotate half the sequence length - rot = cos.shape[-1] // 2 - q2 = torch.cat([-query[..., rot:], query[..., :rot]], dim=-1) - k2 = torch.cat([-key[..., rot:], key[..., :rot]], dim=-1) - # apply the rotation - rotary_emb.apply_rotary(query, q2, cos, sin, query, q2, True) - rotary_emb.apply_rotary(key, k2, cos, sin, key, k2, True) + if SYSTEM == "ipex": + ipex.llm.functional.rotary_embedding( + query, key, sin, cos, query.size(-1), True + ) + else: + # rotate half the sequence length + rot = cos.shape[-1] // 2 + q2 = torch.cat([-query[..., rot:], query[..., :rot]], dim=-1) + k2 = torch.cat([-key[..., rot:], key[..., :rot]], dim=-1) + + # apply the rotation + rotary_emb.apply_rotary(query, q2, cos, sin, query, q2, True) + rotary_emb.apply_rotary(key, k2, cos, sin, key, k2, True) def _update_cos_sin_cache( self, dtype: torch.dtype, device: torch.device, seqlen: int