From afc0fb5adf97b614eac80712fbe4bf353d8e4a14 Mon Sep 17 00:00:00 2001
From: drbh
Date: Fri, 2 Aug 2024 19:01:58 +0000
Subject: [PATCH] fix: simplify changes and revert model changes

---
 server/text_generation_server/models/__init__.py         |  2 +-
 .../models/custom_modeling/flash_mistral_modeling.py     | 11 ++++++-----
 .../models/custom_modeling/flash_mixtral_modeling.py     | 11 ++++++-----
 .../models/custom_modeling/flash_qwen2_modeling.py       | 11 ++++++-----
 .../custom_modeling/flash_starcoder2_modeling.py         | 11 ++++++-----
 5 files changed, 25 insertions(+), 21 deletions(-)

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 62687432..ae791ef8 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -484,7 +484,7 @@ def get_model(
         )
     sliding_window = config_dict.get("sliding_window", -1)
 
-    if max_input_tokens <= sliding_window:
+    if max_input_tokens is not None and max_input_tokens <= sliding_window:
         sliding_window = -1
 
     if (
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index c24328c2..3d8f6bf4 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -31,7 +31,6 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
-    SUPPORTS_WINDOWING,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -111,7 +110,9 @@ class MistralConfig(PretrainedConfig):
 class MistralAttention(torch.nn.Module):
     def __init__(self, prefix: str, config, weights, layer_id):
         super().__init__()
-        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
+        self.max_past = (
+            config.sliding_window if config.sliding_window is not None else -1
+        )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         if hasattr(config, "head_dim"):
@@ -486,10 +487,10 @@ class FlashMistralForCausalLM(torch.nn.Module):
             ),
             weights=weights,
         )
-        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
+        self.max_past = config.sliding_window
         self.max_past_tensor = (
-            torch.tensor(self.max_past, device=weights.device)
-            if self.max_past
+            torch.tensor(config.sliding_window, device=weights.device)
+            if self.max_past is not None
             else None
         )
 
diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
index 75f68d93..7cdca553 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -35,7 +35,6 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
-    SUPPORTS_WINDOWING,
 )
 from text_generation_server.layers import (
     FastLinear,
@@ -196,7 +195,9 @@ class MixtralAttention(torch.nn.Module):
         weights,
     ):
         super().__init__()
-        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
+        self.max_past = (
+            config.sliding_window if config.sliding_window is not None else -1
+        )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
@@ -614,10 +615,10 @@ class FlashMixtralForCausalLM(torch.nn.Module):
             prefix="lm_head" if not prefix else f"{prefix}.lm_head",
             weights=weights,
         )
-        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
+        self.max_past = config.sliding_window
         self.max_past_tensor = (
-            torch.tensor(self.max_past, device=weights.device)
-            if self.max_past
+            torch.tensor(config.sliding_window, device=weights.device)
+            if self.max_past is not None
             else None
         )
 
diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
index a7962f73..e357a287 100644
--- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -9,7 +9,6 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
-    SUPPORTS_WINDOWING,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -57,7 +56,9 @@ class Qwen2Attention(torch.nn.Module):
         weights,
     ):
         super().__init__()
-        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
+        self.max_past = (
+            config.sliding_window if config.sliding_window is not None else -1
+        )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
@@ -348,10 +349,10 @@ class Qwen2ForCausalLM(torch.nn.Module):
             prefix=f"{prefix}.{suffix}" if prefix else suffix,
             weights=weights,
         )
-        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
+        self.max_past = config.sliding_window
         self.max_past_tensor = (
-            torch.tensor(self.max_past, device=weights.device)
-            if self.max_past
+            torch.tensor(config.sliding_window, device=weights.device)
+            if self.max_past is not None
             else None
         )
 
diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
index 709bb6b9..cfa891d4 100644
--- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
@@ -30,7 +30,6 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
-    SUPPORTS_WINDOWING,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -162,7 +161,9 @@ class Starcoder2Attention(torch.nn.Module):
         weights,
     ):
         super().__init__()
-        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
+        self.max_past = (
+            config.sliding_window if config.sliding_window is not None else -1
+        )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
@@ -508,10 +509,10 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
             weights=weights,
         )
 
-        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
+        self.max_past = config.sliding_window
        self.max_past_tensor = (
-            torch.tensor(self.max_past, device=weights.device)
-            if self.max_past
+            torch.tensor(config.sliding_window, device=weights.device)
+            if self.max_past is not None
             else None
         )
 