From 1b1bfa49b04448ab1afac2d9bf790fff7bd1871b Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Fri, 15 Dec 2023 14:56:17 +0100
Subject: [PATCH] fix: fix logic if sliding window key is not present in
 config (#1352)

---
 server/text_generation_server/models/__init__.py | 14 ++++++++------
 .../custom_modeling/flash_mistral_modeling.py    |  2 +-
 .../custom_modeling/flash_mixtral_modeling.py    |  2 +-
 server/text_generation_server/models/model.py    |  2 +-
 4 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 1bbff16a..39d1d58e 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -281,9 +281,10 @@ def get_model(
         )
 
     if model_type == "mistral":
-        if (config_dict["sliding_window"] is None and FLASH_ATTENTION) or (
-            config_dict["sliding_window"] > 0 and HAS_FLASH_ATTN_V2_CUDA
-        ):
+        sliding_window = config_dict.get("sliding_window", -1)
+        if (
+            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
+        ) or HAS_FLASH_ATTN_V2_CUDA:
             return FlashMistral(
                 model_id,
                 revision,
@@ -293,9 +294,10 @@ def get_model(
         )
 
     if model_type == "mixtral":
-        if (config_dict["sliding_window"] is None and FLASH_ATTENTION) or (
-            config_dict["sliding_window"] > 0 and HAS_FLASH_ATTN_V2_CUDA
-        ):
+        sliding_window = config_dict.get("sliding_window", -1)
+        if (
+            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
+        ) or HAS_FLASH_ATTN_V2_CUDA:
             return FlashMixtral(
                 model_id,
                 revision,
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index c85624f3..0fc4e1b3 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -60,7 +60,7 @@ class MistralConfig(PretrainedConfig):
         pretraining_tp=1,
         tie_word_embeddings=False,
         rope_theta=10000.0,
-        sliding_window=4096,
+        sliding_window=None,
         **kwargs,
     ):
         self.vocab_size = vocab_size
diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
index b468d09b..61488ec4 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -72,7 +72,7 @@ class MixtralConfig(PretrainedConfig):
         pretraining_tp=1,
         tie_word_embeddings=False,
         rope_theta=10000.0,
-        sliding_window=4096,
+        sliding_window=None,
         num_experts_per_tok=2,
         num_local_experts=8,
         **kwargs,
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index cb358672..cec9eafa 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -33,7 +33,7 @@ class Model(ABC):
         self.device = device
         self.rank = rank
         self.world_size = world_size
-        self.sliding_window = sliding_window
+        self.sliding_window = sliding_window if sliding_window != -1 else None
 
         if speculate is None:
             speculate = get_speculate()
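
For context, here is a minimal, self-contained sketch (not part of the patch) of how the new dispatch condition behaves. `FLASH_ATTENTION` and `HAS_FLASH_ATTN_V2_CUDA` are stubbed as plain booleans, and `use_flash_mistral` is a hypothetical helper name mirroring the condition added in `models/__init__.py`; the real code lives inside `get_model`.

```python
# Illustrative sketch only -- not part of the patch.
# FLASH_ATTENTION / HAS_FLASH_ATTN_V2_CUDA are stubbed as booleans here;
# use_flash_mistral is a hypothetical helper mirroring the patched condition.
FLASH_ATTENTION = True
HAS_FLASH_ATTN_V2_CUDA = False


def use_flash_mistral(config_dict: dict) -> bool:
    # A missing "sliding_window" key now falls back to -1 instead of
    # raising KeyError, and is treated the same as an explicit None.
    sliding_window = config_dict.get("sliding_window", -1)
    return (
        (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
    ) or HAS_FLASH_ATTN_V2_CUDA


# A config with no "sliding_window" key (the case this patch fixes) no
# longer crashes and takes the flash path when flash attention is available.
print(use_flash_mistral({}))                        # True
print(use_flash_mistral({"sliding_window": 4096}))  # False without flash-attn v2 CUDA
```

The `-1` sentinel only exists at dispatch time; `Model.__init__` normalizes it back to `None` so downstream code keeps seeing either a positive window size or no window at all.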