diff --git a/backends/gaudi/server/text_generation_server/layers/rotary.py b/backends/gaudi/server/text_generation_server/layers/rotary.py
index 7e740e5f..84671c44 100644
--- a/backends/gaudi/server/text_generation_server/layers/rotary.py
+++ b/backends/gaudi/server/text_generation_server/layers/rotary.py
@@ -36,9 +36,6 @@ class PositionRotaryEmbedding(nn.Module):
         self._sin_k_cached = None
         self.scaling_factor = scaling_factor
         self.dynamic_args = None
-        self._update_cos_sin_cache(
-            torch.float32, inv_freq.device, max_position_embeddings
-        )

     def forward(
         self,
@@ -270,7 +267,9 @@ class PositionRotaryEmbedding(nn.Module):
         self._sin_cached = torch.sin(freqs).to(dtype)

     def get_cos_sin(self, position_ids: torch.Tensor):
-
+        self._update_cos_sin_cache(
+            torch.float32, position_ids.device, seqlen=position_ids.shape[-1]
+        )
         cos = torch.index_select(self._cos_cached, 0, position_ids)
         sin = torch.index_select(self._sin_cached, 0, position_ids)

@@ -298,9 +297,6 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
         self._cos_k_cached = None
         self._sin_k_cached = None
         self.dynamic_args = None
-        self._update_cos_sin_cache(
-            torch.float32, short_inv_freq.device, max_position_embeddings
-        )

     def _update_cos_sin_cache(self, dtype, device, seqlen):
         # Reset the tables if the sequence length has changed,
@@ -354,9 +350,6 @@ class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
         self._cos_k_cached = None
         self._sin_k_cached = None
         self.dynamic_args = None
-        self._update_cos_sin_cache(
-            torch.float32, short_inv_freq.device, max_position_embeddings
-        )

     def _update_cos_sin_cache(self, dtype, device, seqlen):
         if (
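
For context, the behavioral effect of this diff: the cos/sin tables are no longer precomputed in `__init__` from `max_position_embeddings`; instead `get_cos_sin` refreshes the cache on demand from the incoming `position_ids`. Below is a minimal, self-contained sketch of that lazy-cache pattern. `LazyRotaryCache`, its simplified table computation, and the demo `inv_freq` are illustrative stand-ins, not the actual `PositionRotaryEmbedding` code; only the update-then-`index_select` shape of `get_cos_sin` is taken from the diff.

```python
import torch
from torch import nn


class LazyRotaryCache(nn.Module):
    """Illustrative stand-in for the lazy cos/sin caching in the diff above."""

    def __init__(self, inv_freq: torch.Tensor):
        super().__init__()
        self.inv_freq = inv_freq
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        # Rebuild the tables only when they are missing, too short, or on
        # the wrong device/dtype; otherwise this is a no-op.
        if (
            self._cos_cached is None
            or seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq.to(device))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)

    def get_cos_sin(self, position_ids: torch.Tensor):
        # As in the diff: size the cache from the incoming positions instead
        # of precomputing max_position_embeddings rows at construction time.
        self._update_cos_sin_cache(
            torch.float32, position_ids.device, seqlen=position_ids.shape[-1]
        )
        cos = torch.index_select(self._cos_cached, 0, position_ids)
        sin = torch.index_select(self._sin_cached, 0, position_ids)
        return cos, sin


if __name__ == "__main__":
    # Hypothetical inv_freq for head_dim=64, base=10000 (standard RoPE form).
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 64, 2, dtype=torch.float32) / 64))
    rope = LazyRotaryCache(inv_freq)
    cos, sin = rope.get_cos_sin(torch.arange(16))
    print(cos.shape, sin.shape)  # torch.Size([16, 32]) torch.Size([16, 32])
```

In steady state the guard makes `_update_cos_sin_cache` a cheap no-op, so each call costs two `index_select`s; the trade-off is that constructors no longer pay for (or need a target device for) a full `max_position_embeddings`-row table that some model paths never use.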