fix: remove get_cos_sin_hack dev function

This commit is contained in:
David Holtz 2024-10-28 02:20:00 +00:00 committed by drbh
parent 22fdf9344f
commit ec933282b2

View File

@ -277,32 +277,6 @@ class PositionRotaryEmbedding(nn.Module):
# Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow. # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
return cos.unsqueeze(1), sin.unsqueeze(1) return cos.unsqueeze(1), sin.unsqueeze(1)
def get_cos_sin_hack(
self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype
):
# TODO: avoid always computing, use the cache and update it if necessary
inv_freq_expanded = (
self.inv_freq[None, None, :, None]
.float()
.expand(3, position_ids.shape[1], -1, 1)
)
position_ids_expanded = position_ids[
:, :, None, :
].float() # shape (3, bs, 1, positions)
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(
2, 3
)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos().to(dtype)
sin = emb.sin().to(dtype)
# Update cached values
self._cos_cached = cos
self._sin_cached = sin
return cos, sin
class SuRotaryEmbedding(PositionRotaryEmbedding): class SuRotaryEmbedding(PositionRotaryEmbedding):
def __init__( def __init__(