Revert "fix: expand logic for different hardware"

This reverts commit 7c09eae0a0.
This commit is contained in:
drbh 2025-02-11 17:14:02 +01:00
parent 7c09eae0a0
commit b7250f0473

View File

@ -570,47 +570,6 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
.to(inv_freq.device)
)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
):
# The new rotation expects the input to be arranged as two halves.
# Here we assume that query and key have shape (..., 2 * rot) where
rot = cos.shape[-1] // 2
if SYSTEM == "cuda":
# For CUDA we construct the companion tensor q2 from the input.
# This q2 is defined as the concatenation of -second_half and first_half.
q2 = torch.cat([-query[..., rot:], query[..., :rot]], dim=-1)
k2 = torch.cat([-key[..., rot:], key[..., :rot]], dim=-1)
# Now apply the rotary embedding in-place on both query and key.
rotary_emb.apply_rotary(query, q2, cos, sin, query, q2, True)
rotary_emb.apply_rotary(key, k2, cos, sin, key, k2, True)
elif SYSTEM == "rocm":
# rotate the query and key before applying the rotary embedding
query = torch.cat([-query[..., rot:], query[..., :rot]], dim=-1)
key = torch.cat([-key[..., rot:], key[..., :rot]], dim=-1)
head_size = query.shape[-1]
ops.rotary_embedding(query, key, head_size, cos, sin, True)
elif SYSTEM == "ipex":
# rotate the query and key before applying the rotary embedding
query = torch.cat([-query[..., rot:], query[..., :rot]], dim=-1)
key = torch.cat([-key[..., rot:], key[..., :rot]], dim=-1)
ipex.llm.functional.rotary_embedding(
query, key, sin, cos, query.size(-1), True
)
else:
raise ValueError(
"Your system seems not to be supported. Please check your install or open an issue "
"at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
)
def _update_cos_sin_cache(
self, dtype: torch.dtype, device: torch.device, seqlen: int
):
@ -639,7 +598,4 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
cos = self._cos_cached[position_ids].gather(1, self._sections[:slen])
sin = self._sin_cached[position_ids].gather(1, self._sections[:slen])
cos = torch.cat([cos, cos], dim=-1)
sin = torch.cat([sin, sin], dim=-1)
return cos, sin