fix: fix logic if sliding window key is not present in config (#1352)

This commit is contained in:
OlivierDehaene 2023-12-15 14:56:17 +01:00 committed by GitHub
parent 9b56d3fbf5
commit 1b1bfa49b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 11 additions and 9 deletions

View File

@@ -281,9 +281,10 @@ def get_model(
         )
     if model_type == "mistral":
-        if (config_dict["sliding_window"] is None and FLASH_ATTENTION) or (
-            config_dict["sliding_window"] > 0 and HAS_FLASH_ATTN_V2_CUDA
-        ):
+        sliding_window = config_dict.get("sliding_window", -1)
+        if (
+            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
+        ) or HAS_FLASH_ATTN_V2_CUDA:
             return FlashMistral(
                 model_id,
                 revision,
@@ -293,9 +294,10 @@ def get_model(
         )
     if model_type == "mixtral":
-        if (config_dict["sliding_window"] is None and FLASH_ATTENTION) or (
-            config_dict["sliding_window"] > 0 and HAS_FLASH_ATTN_V2_CUDA
-        ):
+        sliding_window = config_dict.get("sliding_window", -1)
+        if (
+            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
+        ) or HAS_FLASH_ATTN_V2_CUDA:
             return FlashMixtral(
                 model_id,
                 revision,

View File

@@ -60,7 +60,7 @@ class MistralConfig(PretrainedConfig):
         pretraining_tp=1,
         tie_word_embeddings=False,
         rope_theta=10000.0,
-        sliding_window=4096,
+        sliding_window=None,
         **kwargs,
     ):
         self.vocab_size = vocab_size

View File

@@ -72,7 +72,7 @@ class MixtralConfig(PretrainedConfig):
         pretraining_tp=1,
         tie_word_embeddings=False,
         rope_theta=10000.0,
-        sliding_window=4096,
+        sliding_window=None,
         num_experts_per_tok=2,
         num_local_experts=8,
         **kwargs,

View File

@@ -33,7 +33,7 @@ class Model(ABC):
         self.device = device
         self.rank = rank
         self.world_size = world_size
-        self.sliding_window = sliding_window
+        self.sliding_window = sliding_window if sliding_window != -1 else None
         if speculate is None:
             speculate = get_speculate()