Move the _update_cos_sin_cache into get_cos_sin

If the call is made in the class's __init__, it is invoked once per layer,
i.e. as many times as there are layers. Moving it into get_cos_sin builds
the cos/sin cache lazily, on first use.

Signed-off-by: yuanwu <yuan.wu@intel.com>
Author: yuanwu
Date:   2025-06-04 03:00:23 +00:00
Parent: 70217ac345
Commit: 8b9a503f8a
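For context, here is a minimal, self-contained sketch of the pattern this commit moves to: the cos/sin tables are built lazily inside get_cos_sin instead of eagerly in __init__, so constructing one rotary embedding per decoder layer does not rebuild the cache once per layer before the first forward pass. The class and method names follow the diff below; everything else (the inv_freq handling, deriving seqlen from position_ids.max() + 1 rather than position_ids.shape[-1], the return shapes, the example at the bottom) is a simplification for illustration, not the actual text-generation-inference implementation.

import torch
from torch import nn


class PositionRotaryEmbedding(nn.Module):
    def __init__(self, inv_freq: torch.Tensor, scaling_factor: float = 1.0):
        super().__init__()
        self.inv_freq = inv_freq
        self.scaling_factor = scaling_factor
        # No eager _update_cos_sin_cache call here: with one embedding per layer,
        # an eager build in __init__ would run num_layers times at model construction.
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        # Rebuild only when the requested length grows or dtype/device changed.
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            t = t / self.scaling_factor
            freqs = torch.outer(t, self.inv_freq.to(device))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)

    def get_cos_sin(self, position_ids: torch.Tensor):
        # Lazy path: make sure the cache covers the largest requested position.
        # (The commit itself passes seqlen=position_ids.shape[-1]; max() + 1 is used
        # here only to keep this standalone sketch safe for arbitrary position ids.)
        self._update_cos_sin_cache(
            torch.float32, position_ids.device, seqlen=int(position_ids.max()) + 1
        )
        cos = torch.index_select(self._cos_cached, 0, position_ids)
        sin = torch.index_select(self._sin_cached, 0, position_ids)
        return cos, sin


# Example: 32 layers share the lazy behaviour; nothing is computed until first use.
inv_freq = 1.0 / (10000 ** (torch.arange(0, 64, 2).float() / 64))
layers = [PositionRotaryEmbedding(inv_freq) for _ in range(32)]
cos, sin = layers[0].get_cos_sin(torch.arange(128))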

@@ -36,9 +36,6 @@ class PositionRotaryEmbedding(nn.Module):
         self._sin_k_cached = None
         self.scaling_factor = scaling_factor
         self.dynamic_args = None
-        self._update_cos_sin_cache(
-            torch.float32, inv_freq.device, max_position_embeddings
-        )
 
     def forward(
         self,
@@ -270,7 +267,9 @@ class PositionRotaryEmbedding(nn.Module):
         self._sin_cached = torch.sin(freqs).to(dtype)
 
     def get_cos_sin(self, position_ids: torch.Tensor):
+        self._update_cos_sin_cache(
+            torch.float32, position_ids.device, seqlen=position_ids.shape[-1]
+        )
         cos = torch.index_select(self._cos_cached, 0, position_ids)
         sin = torch.index_select(self._sin_cached, 0, position_ids)
@@ -298,9 +297,6 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
         self._cos_k_cached = None
         self._sin_k_cached = None
         self.dynamic_args = None
-        self._update_cos_sin_cache(
-            torch.float32, short_inv_freq.device, max_position_embeddings
-        )
 
     def _update_cos_sin_cache(self, dtype, device, seqlen):
         # Reset the tables if the sequence length has changed,
@@ -354,9 +350,6 @@ class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
         self._cos_k_cached = None
         self._sin_k_cached = None
         self.dynamic_args = None
-        self._update_cos_sin_cache(
-            torch.float32, short_inv_freq.device, max_position_embeddings
-        )
 
     def _update_cos_sin_cache(self, dtype, device, seqlen):
         if (