From b392362e9e3e6cb8afbd60cb37ef8b6569d6b3d1 Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Thu, 10 Oct 2024 23:02:56 -0700
Subject: [PATCH] direct return in clamp like rocm

Signed-off-by: Wang, Yi A
---
 .../layers/attention/common.py                 |  2 +-
 .../models/flash_causal_lm.py                  | 22 +++++--------------
 2 files changed, 7 insertions(+), 17 deletions(-)

diff --git a/server/text_generation_server/layers/attention/common.py b/server/text_generation_server/layers/attention/common.py
index d6e512c0..e4b1a781 100644
--- a/server/text_generation_server/layers/attention/common.py
+++ b/server/text_generation_server/layers/attention/common.py
@@ -66,7 +66,7 @@ else:
         max_k: int
 
     def clamp(self, max):
-        if SYSTEM == "rocm":
+        if SYSTEM == "rocm" or SYSTEM == "ipex":
             return self
         raise NotImplementedError("Not implemented seqlen for paged")
         return Seqlen(torch.clamp(self.input_lengths, max=max))
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 1a573b00..33fe30a8 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -46,7 +46,7 @@ from text_generation_server.models.globals import (
     TGI_WIGGLE_ROOM,
     get_adapter_to_index,
 )
-from text_generation_server.layers.attention import KVCache, Seqlen, SUPPORTS_WINDOWING
+from text_generation_server.layers.attention import KVCache, Seqlen
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
 from text_generation_server.utils.quantization import get_loader
@@ -993,21 +993,6 @@ class FlashCausalLM(Model):
         )
         prefix = ""
 
-        if getattr(config, "sliding_window", None) is not None and SUPPORTS_WINDOWING:
-            set_sliding_window(config.sliding_window)
-        else:
-            config.sliding_window = None
-
-        text_config = getattr(config, "text_config", None)
-        if text_config:
-            if (
-                getattr(text_config, "sliding_window", None) is not None
-                and SUPPORTS_WINDOWING
-            ):
-                set_sliding_window(text_config.sliding_window)
-            else:
-                text_config.sliding_window = None
-
         model = model_class(prefix, config, weights)
         torch.distributed.barrier(group=self.process_group)
 
@@ -1016,6 +1001,11 @@ class FlashCausalLM(Model):
         if text_config is not None:
             config = text_config
 
+        if getattr(config, "sliding_window", None) is not None:
+            set_sliding_window(config.sliding_window)
+        else:
+            config.sliding_window = None
+
         self.num_layers = config.num_hidden_layers
         self.num_heads = config.num_attention_heads // self.process_group.size()
         # Validation is done in the model itself
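
Note on the first hunk: Seqlen.clamp now short-circuits on ipex exactly as it
already did on rocm. Below is a minimal standalone sketch of the resulting
behavior, not the real class: the actual Seqlen dataclass carries more fields,
and SYSTEM here is a hard-coded stand-in for TGI's detected-backend constant.

    from dataclasses import dataclass

    import torch

    SYSTEM = "ipex"  # stand-in for TGI's detected-backend constant


    @dataclass
    class Seqlen:
        input_lengths: torch.Tensor

        def clamp(self, max):
            # rocm and (after this patch) ipex paged-attention kernels consume
            # the raw per-request lengths, so clamping is a no-op: return self.
            if SYSTEM == "rocm" or SYSTEM == "ipex":
                return self
            # The upstream method raises here for the paged path; the clamped
            # copy below is unreachable and kept only for reference.
            raise NotImplementedError("Not implemented seqlen for paged")
            return Seqlen(torch.clamp(self.input_lengths, max=max))


    seqlen = Seqlen(input_lengths=torch.tensor([3, 9, 17]))
    assert seqlen.clamp(8) is seqlen  # direct return, no new tensor allocated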
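
Note on the second file: the two sliding-window branches collapse into one
because config is first replaced by text_config when present (VLM configs nest
the text-model settings there), and the SUPPORTS_WINDOWING guard plus its
import go away. A condensed sketch of the post-patch flow, assuming a
transformers-style config object; set_sliding_window is reduced to a
hypothetical module-level setter for illustration.

    from types import SimpleNamespace

    _SLIDING_WINDOW = None  # stand-in for the global state TGI keeps


    def set_sliding_window(window: int):
        global _SLIDING_WINDOW
        _SLIDING_WINDOW = window


    def resolve_sliding_window(config):
        # Resolve text_config first so the sliding-window check runs exactly
        # once, on the config that actually drives the language model.
        text_config = getattr(config, "text_config", None)
        if text_config is not None:
            config = text_config

        if getattr(config, "sliding_window", None) is not None:
            set_sliding_window(config.sliding_window)
        else:
            config.sliding_window = None
        return config


    vlm = SimpleNamespace(text_config=SimpleNamespace(sliding_window=4096))
    resolve_sliding_window(vlm)
    assert _SLIDING_WINDOW == 4096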