From 4220423d7614d89de6434d11c7e6add9bfb5bdee Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Fri, 7 Jun 2024 08:55:01 +0000
Subject: [PATCH] fix sliding window

---
 launcher/src/main.rs                           |  8 ++
 server/text_generation_server/cli.py           |  2 +
 .../layers/attention/rocm.py                   |  4 +-
 .../layers/attention/xpu.py                    |  2 +-
 .../text_generation_server/models/__init__.py  | 73 +++++++++++++++++--
 server/text_generation_server/server.py        |  2 +
 6 files changed, 82 insertions(+), 9 deletions(-)

diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 3d8a7ed6..0acd4ac4 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -476,6 +476,7 @@ fn shard_manager(
     rope_factor: Option<f32>,
     max_total_tokens: usize,
     max_batch_size: Option<usize>,
+    max_batch_prefill_tokens: u32,
     otlp_endpoint: Option<String>,
     log_level: LevelFilter,
     status_sender: mpsc::Sender<ShardStatus>,
@@ -548,6 +549,10 @@ fn shard_manager(
         shard_args.push(otlp_endpoint);
     }

+    // When a sliding window is used, some backends may ignore it in flash attention depending on this parameter.
+    shard_args.push("--max-batch-prefill-tokens".to_string());
+    shard_args.push(max_batch_prefill_tokens.to_string());
+
     // Copy current process env
     let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

@@ -1004,6 +1009,7 @@ fn spawn_shards(
     args: &Args,
     cuda_graphs: Vec<usize>,
     max_total_tokens: usize,
+    max_batch_prefill_tokens: u32,
     max_log_level: LevelFilter,
     shutdown: Arc<AtomicBool>,
     shutdown_receiver: &mpsc::Receiver<()>,
@@ -1061,6 +1067,7 @@ fn spawn_shards(
                 rope_factor,
                 max_total_tokens,
                 max_batch_size,
+                max_batch_prefill_tokens,
                 otlp_endpoint,
                 max_log_level,
                 status_sender,
@@ -1535,6 +1542,7 @@ fn main() -> Result<(), LauncherError> {
         &args,
         cuda_graphs,
         max_total_tokens,
+        max_batch_prefill_tokens,
         max_log_level,
         shutdown.clone(),
         &shutdown_receiver,
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index 16375ecd..b62b57c0 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -41,6 +41,7 @@ def serve(
     logger_level: str = "INFO",
     json_output: bool = False,
     otlp_endpoint: Optional[str] = None,
+    max_batch_prefill_tokens: Optional[int] = None,
 ):
     if sharded:
         assert (
@@ -97,6 +98,7 @@ def serve(
         dtype,
         trust_remote_code,
         uds_path,
+        max_batch_prefill_tokens,
     )


diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py
index 535810aa..2bee4b51 100644
--- a/server/text_generation_server/layers/attention/rocm.py
+++ b/server/text_generation_server/layers/attention/rocm.py
@@ -169,7 +169,7 @@ if ENGINE == "ck":
     ):
         if window_size_left <= 0 and window_size_left != -1:
             raise ValueError("`window_size_left` must be > 0 or -1")
-        if window_size_left != -1:
+        if window_size_left != -1 and q.shape[0] > window_size_left:
             raise ValueError(
                 f"ROCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
             )
@@ -204,7 +204,7 @@ elif ENGINE == "triton":
         window_size_left=-1,
         causal=True,
     ):
-        if window_size_left != -1:
+        if window_size_left != -1 and q.shape[0] > window_size_left:
            raise ValueError(
                f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
            )
diff --git a/server/text_generation_server/layers/attention/xpu.py b/server/text_generation_server/layers/attention/xpu.py
index d9a096f9..99c4699b 100644
--- a/server/text_generation_server/layers/attention/xpu.py
+++ b/server/text_generation_server/layers/attention/xpu.py
@@ -14,7 +14,7 @@ def attention(
     softmax_scale,
     window_size_left=-1,
 ):
-    if window_size_left != -1:
+    if window_size_left != -1 and q.shape[0] > window_size_left:
         raise ValueError(
             f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
         )
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index d9b59de4..3f488206 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -24,6 +24,8 @@ from text_generation_server.models.t5 import T5Sharded
 from text_generation_server.models.gpt_neox import GPTNeoxSharded
 from text_generation_server.models.phi import Phi

+from text_generation_server.utils.import_utils import SYSTEM
+
 # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
 # in PyTorch 1.12 and later.
 torch.backends.cuda.matmul.allow_tf32 = True
@@ -48,6 +50,8 @@ __all__ = [

 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

+SLIDING_WINDOW_MESSAGE = "The backend {} does not support sliding window attention. The TGI webserver was started with max_batch_prefill_tokens={}, which is larger than sliding_window={}. To use this model with the {} backend, please launch TGI with `--max-batch-prefill-tokens` smaller than {}."
+
 FLASH_ATTENTION = True

 try:
@@ -257,6 +261,7 @@ def get_model(
     speculate: Optional[int],
     dtype: Optional[str],
     trust_remote_code: bool,
+    max_batch_prefill_tokens: int,
 ) -> Model:
     global FLASH_ATTENTION
     if dtype is None:
@@ -410,9 +415,10 @@ def get_model(
             "Sharding is currently not supported with `exl2` quantization"
         )
     sliding_window = config_dict.get("sliding_window", -1)
+
     if sliding_window != -1 and not SUPPORTS_WINDOWING:
         logger.warning(
-            f"Flash attention is available, but doesn't support windowing which is required by model {model_id} for long contexts."
+            f"Flash attention is available, but the backend {SYSTEM} does not support sliding window attention, which is required by model {model_id} for long contexts."
        )

    if model_type == MAMBA:
@@ -701,7 +707,21 @@ def get_model(

     if model_type == MISTRAL:
         sliding_window = config_dict.get("sliding_window", -1)
-        if FLASH_ATTENTION:
+        if (
+            (sliding_window is not None and sliding_window != -1)
+            and not SUPPORTS_WINDOWING
+            and max_batch_prefill_tokens > sliding_window
+        ):
+            raise ValueError(
+                SLIDING_WINDOW_MESSAGE.format(
+                    SYSTEM,
+                    max_batch_prefill_tokens,
+                    sliding_window,
+                    SYSTEM,
+                    sliding_window,
+                )
+            )
+        elif FLASH_ATTENTION:
             return FlashMistral(
                 model_id,
                 revision,
@@ -724,7 +744,21 @@ def get_model(

     if model_type == MIXTRAL:
         sliding_window = config_dict.get("sliding_window", -1)
-        if FLASH_ATTENTION:
+        if (
+            (sliding_window is not None and sliding_window != -1)
+            and not SUPPORTS_WINDOWING
+            and max_batch_prefill_tokens > sliding_window
+        ):
+            raise ValueError(
+                SLIDING_WINDOW_MESSAGE.format(
+                    SYSTEM,
+                    max_batch_prefill_tokens,
+                    sliding_window,
+                    SYSTEM,
+                    sliding_window,
+                )
+            )
+        elif FLASH_ATTENTION:
             return FlashMixtral(
                 model_id,
                 revision,
@@ -746,8 +780,21 @@ def get_model(
             )

     if model_type == STARCODER2:
-        sliding_window = config_dict.get("sliding_window", -1)
-        if FLASH_ATTENTION:
+        if (
+            (sliding_window is not None and sliding_window != -1)
+            and not SUPPORTS_WINDOWING
+            and max_batch_prefill_tokens > sliding_window
+        ):
+            raise ValueError(
+                SLIDING_WINDOW_MESSAGE.format(
+                    SYSTEM,
+                    max_batch_prefill_tokens,
+                    sliding_window,
+                    SYSTEM,
+                    sliding_window,
+                )
+            )
+        elif FLASH_ATTENTION:
             return FlashStarcoder2(
                 model_id,
                 revision,
@@ -771,7 +818,21 @@ def get_model(
             )

     if model_type == QWEN2:
         sliding_window = config_dict.get("sliding_window", -1)
-        if (sliding_window is None or sliding_window != -1) and SUPPORTS_WINDOWING:
+        if (
+            (sliding_window is not None and sliding_window != -1)
+            and not SUPPORTS_WINDOWING
+            and max_batch_prefill_tokens > sliding_window
+        ):
+            raise ValueError(
+                SLIDING_WINDOW_MESSAGE.format(
+                    SYSTEM,
+                    max_batch_prefill_tokens,
+                    sliding_window,
+                    SYSTEM,
+                    sliding_window,
+                )
+            )
+        elif sliding_window is None or sliding_window != -1:
             return FlashQwen2(
                 model_id,
                 revision,
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 4118b3f6..51543005 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -199,6 +199,7 @@ def serve(
     dtype: Optional[str],
     trust_remote_code: bool,
     uds_path: Path,
+    max_batch_prefill_tokens: int,
 ):
     async def serve_inner(
         model_id: str,
@@ -229,6 +230,7 @@ def serve(
             speculate,
             dtype,
             trust_remote_code,
+            max_batch_prefill_tokens,
         )
     except Exception:
         logger.exception("Error when initializing model")
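
Note (illustration, not part of the patch): the check this patch repeats in `get_model` for Mistral, Mixtral, Starcoder2 and Qwen2 can be summarized by the standalone sketch below. `check_sliding_window` is a hypothetical helper introduced only for this note; `SLIDING_WINDOW_MESSAGE` mirrors the constant added in the diff, and `supports_windowing` / `system` stand in for the `SUPPORTS_WINDOWING` and `SYSTEM` values the server imports.

from typing import Optional

SLIDING_WINDOW_MESSAGE = (
    "The backend {} does not support sliding window attention. The TGI webserver "
    "was started with max_batch_prefill_tokens={}, which is larger than "
    "sliding_window={}. To use this model with the {} backend, please launch TGI "
    "with `--max-batch-prefill-tokens` smaller than {}."
)


def check_sliding_window(
    sliding_window: Optional[int],
    max_batch_prefill_tokens: int,
    supports_windowing: bool,
    system: str,
) -> None:
    """Refuse to load a sliding-window model when the backend cannot honor the
    window and a single prefill batch could exceed it (illustrative helper)."""
    if (
        sliding_window is not None
        and sliding_window != -1
        and not supports_windowing
        and max_batch_prefill_tokens > sliding_window
    ):
        raise ValueError(
            SLIDING_WINDOW_MESSAGE.format(
                system,
                max_batch_prefill_tokens,
                sliding_window,
                system,
                sliding_window,
            )
        )


# A backend without windowing support only fails when the prefill budget
# exceeds the model's sliding window:
check_sliding_window(4096, 2048, supports_windowing=False, system="rocm")  # passes
# check_sliding_window(4096, 8192, supports_windowing=False, system="rocm")  # raises ValueError

The kernel-side edits in rocm.py and xpu.py follow the same idea: the ValueError is now raised only when the query length (q.shape[0]) actually exceeds window_size_left, so prefills that fit inside the window still run even though these backends ignore the window itself.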