diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 1bbff16a4..39d1d58ec 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -281,9 +281,10 @@ def get_model(
         )
 
     if model_type == "mistral":
-        if (config_dict["sliding_window"] is None and FLASH_ATTENTION) or (
-            config_dict["sliding_window"] > 0 and HAS_FLASH_ATTN_V2_CUDA
-        ):
+        sliding_window = config_dict.get("sliding_window", -1)
+        if (
+            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
+        ) or HAS_FLASH_ATTN_V2_CUDA:
             return FlashMistral(
                 model_id,
                 revision,
@@ -293,9 +294,10 @@ def get_model(
         )
 
     if model_type == "mixtral":
-        if (config_dict["sliding_window"] is None and FLASH_ATTENTION) or (
-            config_dict["sliding_window"] > 0 and HAS_FLASH_ATTN_V2_CUDA
-        ):
+        sliding_window = config_dict.get("sliding_window", -1)
+        if (
+            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
+        ) or HAS_FLASH_ATTN_V2_CUDA:
             return FlashMixtral(
                 model_id,
                 revision,
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index c85624f3a..0fc4e1b39 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -60,7 +60,7 @@ class MistralConfig(PretrainedConfig):
         pretraining_tp=1,
         tie_word_embeddings=False,
         rope_theta=10000.0,
-        sliding_window=4096,
+        sliding_window=None,
         **kwargs,
     ):
         self.vocab_size = vocab_size
diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
index b468d09b2..61488ec4f 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -72,7 +72,7 @@ class MixtralConfig(PretrainedConfig):
         pretraining_tp=1,
         tie_word_embeddings=False,
         rope_theta=10000.0,
-        sliding_window=4096,
+        sliding_window=None,
         num_experts_per_tok=2,
         num_local_experts=8,
         **kwargs,
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index cb358672e..cec9eafa8 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -33,7 +33,7 @@ class Model(ABC):
         self.device = device
         self.rank = rank
         self.world_size = world_size
-        self.sliding_window = sliding_window
+        self.sliding_window = sliding_window if sliding_window != -1 else None
 
         if speculate is None:
             speculate = get_speculate()