diff --git a/backends/gaudi/server/text_generation_server/layers/rotary.py b/backends/gaudi/server/text_generation_server/layers/rotary.py index 7e740e5f..d381d4c6 100644 --- a/backends/gaudi/server/text_generation_server/layers/rotary.py +++ b/backends/gaudi/server/text_generation_server/layers/rotary.py @@ -36,9 +36,7 @@ class PositionRotaryEmbedding(nn.Module): self._sin_k_cached = None self.scaling_factor = scaling_factor self.dynamic_args = None - self._update_cos_sin_cache( - torch.float32, inv_freq.device, max_position_embeddings - ) + self.max_position_embeddings = max_position_embeddings def forward( self, @@ -270,7 +268,9 @@ class PositionRotaryEmbedding(nn.Module): self._sin_cached = torch.sin(freqs).to(dtype) def get_cos_sin(self, position_ids: torch.Tensor): - + self._update_cos_sin_cache( + torch.float32, position_ids.device, seqlen=self.max_position_embeddings + ) cos = torch.index_select(self._cos_cached, 0, position_ids) sin = torch.index_select(self._sin_cached, 0, position_ids) @@ -298,9 +298,6 @@ class SuRotaryEmbedding(PositionRotaryEmbedding): self._cos_k_cached = None self._sin_k_cached = None self.dynamic_args = None - self._update_cos_sin_cache( - torch.float32, short_inv_freq.device, max_position_embeddings - ) def _update_cos_sin_cache(self, dtype, device, seqlen): # Reset the tables if the sequence length has changed, @@ -354,9 +351,6 @@ class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding): self._cos_k_cached = None self._sin_k_cached = None self.dynamic_args = None - self._update_cos_sin_cache( - torch.float32, short_inv_freq.device, max_position_embeddings - ) def _update_cos_sin_cache(self, dtype, device, seqlen): if ( @@ -598,6 +592,9 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding): position_ids: torch.Tensor, ): slen = position_ids.shape[0] + self._update_cos_sin_cache( + torch.float32, position_ids.device, seqlen=self.max_position_embeddings + ) cos = self._cos_cached[position_ids].gather(1, self._sections[:slen]) sin = self._sin_cached[position_ids].gather(1, self._sections[:slen])