mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 06:42:10 +00:00
fix: fix logic if sliding window key is not present in config (#1352)
This commit is contained in:
parent
9b56d3fbf5
commit
1b1bfa49b0
@ -281,9 +281,10 @@ def get_model(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if model_type == "mistral":
|
if model_type == "mistral":
|
||||||
if (config_dict["sliding_window"] is None and FLASH_ATTENTION) or (
|
sliding_window = config_dict.get("sliding_window", -1)
|
||||||
config_dict["sliding_window"] > 0 and HAS_FLASH_ATTN_V2_CUDA
|
if (
|
||||||
):
|
(sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
|
||||||
|
) or HAS_FLASH_ATTN_V2_CUDA:
|
||||||
return FlashMistral(
|
return FlashMistral(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
@ -293,9 +294,10 @@ def get_model(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if model_type == "mixtral":
|
if model_type == "mixtral":
|
||||||
if (config_dict["sliding_window"] is None and FLASH_ATTENTION) or (
|
sliding_window = config_dict.get("sliding_window", -1)
|
||||||
config_dict["sliding_window"] > 0 and HAS_FLASH_ATTN_V2_CUDA
|
if (
|
||||||
):
|
(sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
|
||||||
|
) or HAS_FLASH_ATTN_V2_CUDA:
|
||||||
return FlashMixtral(
|
return FlashMixtral(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
|
@ -60,7 +60,7 @@ class MistralConfig(PretrainedConfig):
|
|||||||
pretraining_tp=1,
|
pretraining_tp=1,
|
||||||
tie_word_embeddings=False,
|
tie_word_embeddings=False,
|
||||||
rope_theta=10000.0,
|
rope_theta=10000.0,
|
||||||
sliding_window=4096,
|
sliding_window=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
|
@ -72,7 +72,7 @@ class MixtralConfig(PretrainedConfig):
|
|||||||
pretraining_tp=1,
|
pretraining_tp=1,
|
||||||
tie_word_embeddings=False,
|
tie_word_embeddings=False,
|
||||||
rope_theta=10000.0,
|
rope_theta=10000.0,
|
||||||
sliding_window=4096,
|
sliding_window=None,
|
||||||
num_experts_per_tok=2,
|
num_experts_per_tok=2,
|
||||||
num_local_experts=8,
|
num_local_experts=8,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
@ -33,7 +33,7 @@ class Model(ABC):
|
|||||||
self.device = device
|
self.device = device
|
||||||
self.rank = rank
|
self.rank = rank
|
||||||
self.world_size = world_size
|
self.world_size = world_size
|
||||||
self.sliding_window = sliding_window
|
self.sliding_window = sliding_window if sliding_window != -1 else None
|
||||||
|
|
||||||
if speculate is None:
|
if speculate is None:
|
||||||
speculate = get_speculate()
|
speculate = get_speculate()
|
||||||
|
Loading…
Reference in New Issue
Block a user