fix phimoe issue

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi A 2025-03-18 23:11:01 -07:00
parent 5cd1c93cad
commit 073f793976
2 changed files with 11 additions and 0 deletions

View File

@ -188,6 +188,7 @@ class PositionRotaryEmbedding(nn.Module):
long_inv_freq=long_inv_freq, long_inv_freq=long_inv_freq,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
original_max_position_embeddings=original_max_position_embeddings, original_max_position_embeddings=original_max_position_embeddings,
max_position_embeddings=config.max_position_embeddings,
) )
else: else:
raise NotImplementedError( raise NotImplementedError(
@ -276,6 +277,7 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
long_inv_freq, long_inv_freq,
scaling_factor, scaling_factor,
original_max_position_embeddings, original_max_position_embeddings,
max_position_embeddings,
): ):
super(PositionRotaryEmbedding, self).__init__() super(PositionRotaryEmbedding, self).__init__()
self.short_inv_freq = short_inv_freq self.short_inv_freq = short_inv_freq
@ -288,6 +290,9 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
self._cos_k_cached = None self._cos_k_cached = None
self._sin_k_cached = None self._sin_k_cached = None
self.dynamic_args = None self.dynamic_args = None
self._update_cos_sin_cache(
torch.float32, short_inv_freq.device, max_position_embeddings
)
def _update_cos_sin_cache(self, dtype, device, seqlen): def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed, # Reset the tables if the sequence length has changed,
@ -341,6 +346,9 @@ class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
self._cos_k_cached = None self._cos_k_cached = None
self._sin_k_cached = None self._sin_k_cached = None
self.dynamic_args = None self.dynamic_args = None
self._update_cos_sin_cache(
torch.float32, short_inv_freq.device, max_position_embeddings
)
def _update_cos_sin_cache(self, dtype, device, seqlen): def _update_cos_sin_cache(self, dtype, device, seqlen):
if ( if (

View File

@ -25,6 +25,9 @@ from text_generation_server.models.vlm_causal_lm import VlmCausalLM
from text_generation_server.models.custom_modeling.llava_next import ( from text_generation_server.models.custom_modeling.llava_next import (
LlavaNextForConditionalGeneration, LlavaNextForConditionalGeneration,
) )
from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
PhiMoEConfig,
)
# from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch # from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
# from text_generation_server.models.custom_modeling.mllama import ( # from text_generation_server.models.custom_modeling.mllama import (