mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-20 14:22:08 +00:00
fix phimoe issue
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
5cd1c93cad
commit
073f793976
@ -188,6 +188,7 @@ class PositionRotaryEmbedding(nn.Module):
|
|||||||
long_inv_freq=long_inv_freq,
|
long_inv_freq=long_inv_freq,
|
||||||
scaling_factor=scaling_factor,
|
scaling_factor=scaling_factor,
|
||||||
original_max_position_embeddings=original_max_position_embeddings,
|
original_max_position_embeddings=original_max_position_embeddings,
|
||||||
|
max_position_embeddings=config.max_position_embeddings,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
@ -276,6 +277,7 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
|
|||||||
long_inv_freq,
|
long_inv_freq,
|
||||||
scaling_factor,
|
scaling_factor,
|
||||||
original_max_position_embeddings,
|
original_max_position_embeddings,
|
||||||
|
max_position_embeddings,
|
||||||
):
|
):
|
||||||
super(PositionRotaryEmbedding, self).__init__()
|
super(PositionRotaryEmbedding, self).__init__()
|
||||||
self.short_inv_freq = short_inv_freq
|
self.short_inv_freq = short_inv_freq
|
||||||
@ -288,6 +290,9 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
|
|||||||
self._cos_k_cached = None
|
self._cos_k_cached = None
|
||||||
self._sin_k_cached = None
|
self._sin_k_cached = None
|
||||||
self.dynamic_args = None
|
self.dynamic_args = None
|
||||||
|
self._update_cos_sin_cache(
|
||||||
|
torch.float32, short_inv_freq.device, max_position_embeddings
|
||||||
|
)
|
||||||
|
|
||||||
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
||||||
# Reset the tables if the sequence length has changed,
|
# Reset the tables if the sequence length has changed,
|
||||||
@ -341,6 +346,9 @@ class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
|
|||||||
self._cos_k_cached = None
|
self._cos_k_cached = None
|
||||||
self._sin_k_cached = None
|
self._sin_k_cached = None
|
||||||
self.dynamic_args = None
|
self.dynamic_args = None
|
||||||
|
self._update_cos_sin_cache(
|
||||||
|
torch.float32, short_inv_freq.device, max_position_embeddings
|
||||||
|
)
|
||||||
|
|
||||||
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
||||||
if (
|
if (
|
||||||
|
@ -25,6 +25,9 @@ from text_generation_server.models.vlm_causal_lm import VlmCausalLM
|
|||||||
from text_generation_server.models.custom_modeling.llava_next import (
|
from text_generation_server.models.custom_modeling.llava_next import (
|
||||||
LlavaNextForConditionalGeneration,
|
LlavaNextForConditionalGeneration,
|
||||||
)
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
|
||||||
|
PhiMoEConfig,
|
||||||
|
)
|
||||||
|
|
||||||
# from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
|
# from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
|
||||||
# from text_generation_server.models.custom_modeling.mllama import (
|
# from text_generation_server.models.custom_modeling.mllama import (
|
||||||
|
Loading…
Reference in New Issue
Block a user