From cae28dcbf172fd8e8a6b5d8a06c3b96094736150 Mon Sep 17 00:00:00 2001
From: drbh
Date: Thu, 1 Aug 2024 15:06:09 +0000
Subject: [PATCH] fix: prefer version check over test op and avoid
 window_size_left if not flash attn2

---
 .../layers/attention/cuda.py                  | 49 +++----------
 .../text_generation_server/models/__init__.py |  6 ++-
 .../custom_modeling/flash_mistral_modeling.py | 11 ++---
 .../custom_modeling/flash_mixtral_modeling.py | 11 ++---
 .../custom_modeling/flash_qwen2_modeling.py   | 11 ++---
 .../flash_starcoder2_modeling.py              | 11 ++---
 6 files changed, 31 insertions(+), 68 deletions(-)

diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py
index dc99641b..f86ce1f4 100644
--- a/server/text_generation_server/layers/attention/cuda.py
+++ b/server/text_generation_server/layers/attention/cuda.py
@@ -3,7 +3,6 @@ from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
 from text_generation_server.layers.attention import Seqlen
 from typing import Optional
-import warnings
 
 major, minor = torch.cuda.get_device_capability()
 is_sm75 = major == 7 and minor == 5
@@ -172,43 +171,11 @@ def paged_attention(
 
 
 try:
+    if major < 8:
+        raise ImportError("Flash Attention V2 requires Ampere GPUs or newer")
+
     import flash_attn_2_cuda
 
-    # try forwarding to see if it works with all dummy inputs
-    batch_size = 1
-    num_heads = 1
-    head_dim = 1
-    seqlen = 1
-
-    try:
-        flash_attn_2_cuda.varlen_fwd(
-            torch.zeros(batch_size, num_heads, seqlen, head_dim),  # q
-            torch.zeros(batch_size, num_heads, seqlen, head_dim),  # k
-            torch.zeros(batch_size, num_heads, seqlen, head_dim),  # v
-            None,  # out (optional)
-            torch.zeros(batch_size + 1, dtype=torch.int32),  # cu_seqlens_q
-            torch.zeros(batch_size + 1, dtype=torch.int32),  # cu_seqlens_k
-            None,  # alibi_slopes (optional)
-            None,  # q_padded (optional)
-            None,  # k_padded (optional)
-            None,  # v_padded (optional)
-            seqlen,  # max_seqlen_q
-            seqlen,  # max_seqlen_k
-            1.0,  # softmax_scale
-            0.0,  # softmax_lse (default value)
-            False,  # is_causal
-            True,  # return_softmax
-            -1,  # window_size_left
-            -1,  # window_size_right
-            0.0,  # softmax_softcap
-            False,  # deterministic
-            None,  # rng_state (optional)
-        )
-    except RuntimeError as e:
-        raise ImportError(
-            "Flash Attention V2 is not supported on this machine. " f"Error: {e}"
-        ) from e
-
     V2 = True
 except ImportError:
     try:
@@ -289,10 +256,9 @@ else:
         window_size_left=-1,
         softcap=None,
     ):
-        if window_size_left != -1:
-            warnings.warn(
-                "window_size_left is only available with flash attn v2. It will be ignored.",
-                UserWarning,
+        if window_size_left is not None and window_size_left != -1:
+            raise NotImplementedError(
+                "window_size_left is only available with flash attn v2"
             )
 
         if softcap is not None:
@@ -341,6 +307,3 @@ else:
             0,
             None,
         )
-
-
-SUPPORTS_WINDOWING = True
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 3dc24159..40d3c07b 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -484,10 +484,14 @@ def get_model(
         )
     sliding_window = config_dict.get("sliding_window", -1)
 
+    is_max_input_within_sliding_window = (
+        max_input_tokens <= sliding_window if max_input_tokens is not None else False
+    )
+
     if (
         (sliding_window is not None and sliding_window != -1)
         and not SUPPORTS_WINDOWING
-        and max_input_tokens > sliding_window
+        and not is_max_input_within_sliding_window
     ):
         raise ValueError(
             f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index 3d8f6bf4..c24328c2 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -31,6 +31,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    SUPPORTS_WINDOWING,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -110,9 +111,7 @@ class MistralConfig(PretrainedConfig):
 class MistralAttention(torch.nn.Module):
     def __init__(self, prefix: str, config, weights, layer_id):
         super().__init__()
-        self.max_past = (
-            config.sliding_window if config.sliding_window is not None else -1
-        )
+        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         if hasattr(config, "head_dim"):
@@ -487,10 +486,10 @@ class FlashMistralForCausalLM(torch.nn.Module):
             ),
             weights=weights,
         )
-        self.max_past = config.sliding_window
+        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
         self.max_past_tensor = (
-            torch.tensor(config.sliding_window, device=weights.device)
-            if self.max_past is not None
+            torch.tensor(self.max_past, device=weights.device)
+            if self.max_past
             else None
         )
 
diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
index 7cdca553..75f68d93 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -35,6 +35,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    SUPPORTS_WINDOWING,
 )
 from text_generation_server.layers import (
     FastLinear,
@@ -195,9 +196,7 @@ class MixtralAttention(torch.nn.Module):
     def __init__(
         self,
         prefix: str,
         config,
         weights,
     ):
         super().__init__()
-        self.max_past = (
-            config.sliding_window if config.sliding_window is not None else -1
-        )
+        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
@@ -615,10 +614,10 @@ class FlashMixtralForCausalLM(torch.nn.Module):
             prefix="lm_head" if not prefix else f"{prefix}.lm_head",
             weights=weights,
         )
-        self.max_past = config.sliding_window
+        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
         self.max_past_tensor = (
-            torch.tensor(config.sliding_window, device=weights.device)
-            if self.max_past is not None
+            torch.tensor(self.max_past, device=weights.device)
+            if self.max_past
             else None
         )
 
diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
index e357a287..a7962f73 100644
--- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -9,6 +9,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    SUPPORTS_WINDOWING,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -56,9 +57,7 @@ class Qwen2Attention(torch.nn.Module):
     def __init__(
         self,
         prefix: str,
         config,
         weights,
     ):
         super().__init__()
-        self.max_past = (
-            config.sliding_window if config.sliding_window is not None else -1
-        )
+        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
@@ -349,10 +348,10 @@ class Qwen2ForCausalLM(torch.nn.Module):
             prefix=f"{prefix}.{suffix}" if prefix else suffix,
             weights=weights,
         )
-        self.max_past = config.sliding_window
+        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
         self.max_past_tensor = (
-            torch.tensor(config.sliding_window, device=weights.device)
-            if self.max_past is not None
+            torch.tensor(self.max_past, device=weights.device)
+            if self.max_past
             else None
         )
 
diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
index cfa891d4..709bb6b9 100644
--- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
@@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    SUPPORTS_WINDOWING,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -161,9 +162,7 @@ class Starcoder2Attention(torch.nn.Module):
     def __init__(
         self,
         prefix: str,
         config,
         weights,
     ):
         super().__init__()
-        self.max_past = (
-            config.sliding_window if config.sliding_window is not None else -1
-        )
+        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
@@ -509,10 +508,10 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
             weights=weights,
         )
 
-        self.max_past = config.sliding_window
+        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
         self.max_past_tensor = (
-            torch.tensor(config.sliding_window, device=weights.device)
-            if self.max_past is not None
+            torch.tensor(self.max_past, device=weights.device)
+            if self.max_past
             else None
         )
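
For reviewers, a minimal standalone sketch of the detection pattern this patch moves to: gate the flash_attn_2_cuda import on GPU compute capability instead of probing it with a dummy varlen_fwd call, and expose the result as a SUPPORTS_WINDOWING flag so sliding-window requests can be rejected when only Flash Attention v1 is available. It assumes a CUDA-capable torch build; check_window and the flag wiring below are illustrative, not TGI's actual module layout.

import torch

# Compute capability, not CUDA toolkit version: Flash Attention V2 kernels
# target Ampere (sm80) and newer GPUs.
major, minor = torch.cuda.get_device_capability()

try:
    if major < 8:
        raise ImportError("Flash Attention V2 requires Ampere GPUs or newer")
    import flash_attn_2_cuda  # noqa: F401  # only present when flash-attn v2 is installed

    V2 = True
except ImportError:
    V2 = False

# Sliding-window attention is only wired up for the v2 kernels, so report it
# as a capability instead of unconditionally claiming support.
SUPPORTS_WINDOWING = V2


def check_window(window_size_left) -> None:
    # Hypothetical helper: reject sliding-window requests on the v1 path,
    # mirroring the NotImplementedError added in cuda.py.
    if not SUPPORTS_WINDOWING and window_size_left not in (None, -1):
        raise NotImplementedError(
            "window_size_left is only available with flash attn v2"
        )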