From 073f79397629e0f49fb449e463ade2829072e85c Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Tue, 18 Mar 2025 23:11:01 -0700 Subject: [PATCH] fix phimoe issue Signed-off-by: Wang, Yi A --- .../gaudi/server/text_generation_server/layers/rotary.py | 8 ++++++++ .../server/text_generation_server/models/__init__.py | 3 +++ 2 files changed, 11 insertions(+) diff --git a/backends/gaudi/server/text_generation_server/layers/rotary.py b/backends/gaudi/server/text_generation_server/layers/rotary.py index 5b6cad5c..1f8a6bd1 100644 --- a/backends/gaudi/server/text_generation_server/layers/rotary.py +++ b/backends/gaudi/server/text_generation_server/layers/rotary.py @@ -188,6 +188,7 @@ class PositionRotaryEmbedding(nn.Module): long_inv_freq=long_inv_freq, scaling_factor=scaling_factor, original_max_position_embeddings=original_max_position_embeddings, + max_position_embeddings=config.max_position_embeddings, ) else: raise NotImplementedError( @@ -276,6 +277,7 @@ class SuRotaryEmbedding(PositionRotaryEmbedding): long_inv_freq, scaling_factor, original_max_position_embeddings, + max_position_embeddings, ): super(PositionRotaryEmbedding, self).__init__() self.short_inv_freq = short_inv_freq @@ -288,6 +290,9 @@ class SuRotaryEmbedding(PositionRotaryEmbedding): self._cos_k_cached = None self._sin_k_cached = None self.dynamic_args = None + self._update_cos_sin_cache( + torch.float32, short_inv_freq.device, max_position_embeddings + ) def _update_cos_sin_cache(self, dtype, device, seqlen): # Reset the tables if the sequence length has changed, @@ -341,6 +346,9 @@ class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding): self._cos_k_cached = None self._sin_k_cached = None self.dynamic_args = None + self._update_cos_sin_cache( + torch.float32, short_inv_freq.device, max_position_embeddings + ) def _update_cos_sin_cache(self, dtype, device, seqlen): if ( diff --git a/backends/gaudi/server/text_generation_server/models/__init__.py b/backends/gaudi/server/text_generation_server/models/__init__.py index 926fb57a..7144542f 100644 --- a/backends/gaudi/server/text_generation_server/models/__init__.py +++ b/backends/gaudi/server/text_generation_server/models/__init__.py @@ -25,6 +25,9 @@ from text_generation_server.models.vlm_causal_lm import VlmCausalLM from text_generation_server.models.custom_modeling.llava_next import ( LlavaNextForConditionalGeneration, ) +from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import ( + PhiMoEConfig, +) # from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch # from text_generation_server.models.custom_modeling.mllama import (