diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index d0185ede..f0e1236d 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -64,6 +64,7 @@ class LlamaConfig(PretrainedConfig):
         pretraining_tp=1,
         tie_word_embeddings=False,
         rope_scaling=None,
+        rope_theta=10000.0,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -84,6 +85,7 @@ class LlamaConfig(PretrainedConfig):
         self.pretraining_tp = pretraining_tp
         self.use_cache = use_cache
         self.rope_scaling = rope_scaling
+        self.rope_theta = rope_theta
 
         super().__init__(
             pad_token_id=pad_token_id,
@@ -189,7 +191,7 @@ class FlashLlamaAttention(torch.nn.Module):
         #     config=config, prefix=f"{prefix}.rotary_emb", weights=weights
         # )
         self.rotary_emb = PositionRotaryEmbedding.static(
-            config=config, dim=self.head_size, base=10000.0, device=weights.device
+            config=config, dim=self.head_size, base=config.rope_theta, device=weights.device
         )
 
         self.softmax_scale = self.head_size**-0.5
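
For context, here is a minimal sketch of how the `base` argument, now read from `config.rope_theta` instead of the hard-coded 10000.0, enters the standard RoPE inverse-frequency computation. This is the textbook formula, not TGI's actual `PositionRotaryEmbedding.static` internals; the helper name below is illustrative only.

    import torch

    def rope_inv_freq(dim: int, base: float = 10000.0) -> torch.Tensor:
        # Standard RoPE inverse frequencies: inv_freq[i] = base ** (-2i / dim).
        # A larger base (rope_theta) stretches the rotation periods, which is
        # how long-context Llama variants (e.g. CodeLlama, rope_theta=1e6)
        # extend the usable position range.
        return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

    # With the patch above, the value flows from the config rather than a constant:
    #   inv_freq = rope_inv_freq(head_size, base=config.rope_theta)
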