From a05f3849e46b5d95fa5669174318288a319c8190 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Mon, 23 Sep 2024 20:48:43 -0700 Subject: [PATCH] do not set sliding_window if SUPPORTS_WINDOWING is false Signed-off-by: Wang, Yi A --- .../models/flash_causal_lm.py | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index a2834962..1a7bf09c 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -47,6 +47,7 @@ from text_generation_server.models.globals import ( get_adapter_to_index, ) from text_generation_server.layers.attention import Seqlen +from text_generation_server.layers.attention import SUPPORTS_WINDOWING from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils.dist import MEMORY_FRACTION from text_generation_server.utils.quantization import get_loader @@ -992,6 +993,21 @@ class FlashCausalLM(Model): ) prefix = "" + if getattr(config, "sliding_window", None) is not None and SUPPORTS_WINDOWING: + set_sliding_window(config.sliding_window) + else: + config.sliding_window = None + + text_config = getattr(config, "text_config", None) + if text_config: + if ( + getattr(text_config, "sliding_window", None) is not None + and SUPPORTS_WINDOWING + ): + set_sliding_window(text_config.sliding_window) + else: + text_config.sliding_window = None + model = model_class(prefix, config, weights) torch.distributed.barrier(group=self.process_group) @@ -1000,11 +1016,6 @@ class FlashCausalLM(Model): if text_config is not None: config = text_config - if getattr(config, "sliding_window", None) is not None: - set_sliding_window(config.sliding_window) - else: - config.sliding_window = None - self.num_layers = config.num_hidden_layers self.num_heads = config.num_attention_heads // self.process_group.size() # Validation is done in the model itself