[gaudi] Move the _update_cos_sin_cache into get_cos_sin (#3254)

Signed-off-by: yuanwu <yuan.wu@intel.com>
Author: Yuan Wu (committed via GitHub)
Date:   2025-06-13 04:31:11 +08:00
Commit: 25fdc5f03c
Parent: 613b8dd647

@@ -36,9 +36,7 @@ class PositionRotaryEmbedding(nn.Module):
         self._sin_k_cached = None
         self.scaling_factor = scaling_factor
         self.dynamic_args = None
-        self._update_cos_sin_cache(
-            torch.float32, inv_freq.device, max_position_embeddings
-        )
+        self.max_position_embeddings = max_position_embeddings

     def forward(
         self,
@@ -270,7 +268,9 @@ class PositionRotaryEmbedding(nn.Module):
         self._sin_cached = torch.sin(freqs).to(dtype)

     def get_cos_sin(self, position_ids: torch.Tensor):
+        self._update_cos_sin_cache(
+            torch.float32, position_ids.device, seqlen=self.max_position_embeddings
+        )
         cos = torch.index_select(self._cos_cached, 0, position_ids)
         sin = torch.index_select(self._sin_cached, 0, position_ids)
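
Together with the first hunk, this moves the base class from eager to lazy cos/sin table construction: __init__ now only records max_position_embeddings, and the tables are built on the first get_cos_sin call, on the device the position_ids actually live on. A minimal self-contained sketch of the resulting pattern (the class name, the guard condition, and the table construction are simplified assumptions for illustration, not the exact TGI code):

    import torch
    from torch import nn


    class LazyRotaryCacheSketch(nn.Module):
        """Hypothetical, simplified stand-in for PositionRotaryEmbedding
        after this commit: no cache work happens in __init__."""

        def __init__(self, inv_freq: torch.Tensor, max_position_embeddings: int):
            super().__init__()
            self.inv_freq = inv_freq
            self.max_position_embeddings = max_position_embeddings
            self._seq_len_cached = 0
            self._cos_cached = None
            self._sin_cached = None

        def _update_cos_sin_cache(self, dtype, device, seqlen):
            # Rebuild only if the table is missing, too short, or on the
            # wrong device/dtype; otherwise this call is effectively free,
            # which is what makes invoking it on every get_cos_sin cheap.
            if (
                self._cos_cached is None
                or seqlen > self._seq_len_cached
                or self._cos_cached.device != device
                or self._cos_cached.dtype != dtype
            ):
                self._seq_len_cached = seqlen
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                freqs = torch.outer(t, self.inv_freq.to(device))
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)

        def get_cos_sin(self, position_ids: torch.Tensor):
            # Lazy step added by this commit: make sure the tables exist
            # before indexing into them.
            self._update_cos_sin_cache(
                torch.float32, position_ids.device, seqlen=self.max_position_embeddings
            )
            cos = torch.index_select(self._cos_cached, 0, position_ids)
            sin = torch.index_select(self._sin_cached, 0, position_ids)
            return cos, sin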
@@ -298,9 +298,6 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
         self._cos_k_cached = None
         self._sin_k_cached = None
         self.dynamic_args = None
-        self._update_cos_sin_cache(
-            torch.float32, short_inv_freq.device, max_position_embeddings
-        )

     def _update_cos_sin_cache(self, dtype, device, seqlen):
         # Reset the tables if the sequence length has changed,
@@ -354,9 +351,6 @@ class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
         self._cos_k_cached = None
         self._sin_k_cached = None
         self.dynamic_args = None
-        self._update_cos_sin_cache(
-            torch.float32, short_inv_freq.device, max_position_embeddings
-        )

     def _update_cos_sin_cache(self, dtype, device, seqlen):
         if (
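
The two removals above rely on inheritance: SuRotaryEmbedding and Phi3LongRoPEScaledRotaryEmbedding each keep their own _update_cos_sin_cache, and since get_cos_sin comes from the base class, its new lazy call dispatches to those overrides on first use. A sketch of that dispatch, reusing the hypothetical LazyRotaryCacheSketch above (the short/long frequency blending of the real subclasses is omitted):

    class SuRotarySketch(LazyRotaryCacheSketch):
        # Hypothetical stand-in for SuRotaryEmbedding: it overrides only the
        # cache builder; the inherited get_cos_sin calls this override lazily.
        def _update_cos_sin_cache(self, dtype, device, seqlen):
            if self._cos_cached is None or seqlen > self._seq_len_cached:
                self._seq_len_cached = seqlen
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                # The real class blends short_inv_freq/long_inv_freq here;
                # a single inv_freq keeps the sketch short.
                freqs = torch.outer(t, self.inv_freq.to(device))
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)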
@@ -598,6 +592,9 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
         position_ids: torch.Tensor,
     ):
         slen = position_ids.shape[0]
+        self._update_cos_sin_cache(
+            torch.float32, position_ids.device, seqlen=self.max_position_embeddings
+        )
         cos = self._cos_cached[position_ids].gather(1, self._sections[:slen])
         sin = self._sin_cached[position_ids].gather(1, self._sections[:slen])
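
End to end, the observable behavior under this sketch: construction no longer touches any device, the first get_cos_sin call materializes the tables where position_ids lives (an HPU tensor in the Gaudi backend), and later calls hit the guard and return immediately.

    # Usage sketch with the hypothetical classes above.
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 64, 2, dtype=torch.float32) / 64))
    rope = LazyRotaryCacheSketch(inv_freq, max_position_embeddings=4096)

    position_ids = torch.arange(8, dtype=torch.int64)  # would be an HPU tensor on Gaudi
    cos, sin = rope.get_cos_sin(position_ids)  # first call: builds the (4096, 32) tables
    cos, sin = rope.get_cos_sin(position_ids)  # second call: guard short-circuits
    print(cos.shape, sin.shape)                # torch.Size([8, 32]) torch.Size([8, 32])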