diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 0acd4ac4..0fe9e6b1 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -476,7 +476,7 @@ fn shard_manager(
     rope_factor: Option<f32>,
     max_total_tokens: usize,
     max_batch_size: Option<usize>,
-    max_batch_prefill_tokens: u32,
+    max_input_tokens: usize,
     otlp_endpoint: Option<String>,
     log_level: LevelFilter,
     status_sender: mpsc::Sender<ShardStatus>,
@@ -550,8 +550,8 @@ fn shard_manager(
     }
 
     // In case we use sliding window, we may ignore the sliding in flash for some backends depending on the parameter.
-    shard_args.push("--max-batch-prefill-tokens".to_string());
-    shard_args.push(max_batch_prefill_tokens.to_string());
+    shard_args.push("--max-input-tokens".to_string());
+    shard_args.push(max_input_tokens.to_string());
 
     // Copy current process env
     let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
@@ -1009,7 +1009,7 @@ fn spawn_shards(
     args: &Args,
     cuda_graphs: Vec<usize>,
     max_total_tokens: usize,
-    max_batch_prefill_tokens: u32,
+    max_input_tokens: usize,
     max_log_level: LevelFilter,
     shutdown: Arc<AtomicBool>,
     shutdown_receiver: &mpsc::Receiver<()>,
@@ -1067,7 +1067,7 @@ fn spawn_shards(
                 rope_factor,
                 max_total_tokens,
                 max_batch_size,
-                max_batch_prefill_tokens,
+                max_input_tokens,
                 otlp_endpoint,
                 max_log_level,
                 status_sender,
@@ -1542,7 +1542,7 @@ fn main() -> Result<(), LauncherError> {
         &args,
         cuda_graphs,
         max_total_tokens,
-        max_batch_prefill_tokens,
+        max_input_tokens,
         max_log_level,
         shutdown.clone(),
         &shutdown_receiver,
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index b62b57c0..27080d8e 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -41,7 +41,7 @@ def serve(
     logger_level: str = "INFO",
     json_output: bool = False,
     otlp_endpoint: Optional[str] = None,
-    max_batch_prefill_tokens: Optional[int] = None,
+    max_input_tokens: Optional[int] = None,
 ):
     if sharded:
         assert (
@@ -98,7 +98,7 @@ def serve(
         dtype,
         trust_remote_code,
         uds_path,
-        max_batch_prefill_tokens,
+        max_input_tokens,
     )
 
 
diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py
index 2bee4b51..91ed5818 100644
--- a/server/text_generation_server/layers/attention/rocm.py
+++ b/server/text_generation_server/layers/attention/rocm.py
@@ -169,10 +169,8 @@ if ENGINE == "ck":
     ):
         if window_size_left <= 0 and window_size_left != -1:
             raise ValueError("`window_size_left` must be > 0 or -1")
-        if window_size_left != -1 and q.shape[0] > window_size_left:
-            raise ValueError(
-                f"ROCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
-            )
+
+        # We do not need to check window_size_left (not supported) here, as it is already checked ahead of time at model load.
         return flash_attn_2_cuda.varlen_fwd(
             q,
             k,
@@ -204,10 +202,7 @@ elif ENGINE == "triton":
         window_size_left=-1,
         causal=True,
     ):
-        if window_size_left != -1 and q.shape[0] > window_size_left:
-            raise ValueError(
-                f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
-            )
+        # We do not need to check window_size_left (not supported) here, as it is already checked ahead of time at model load.
         output, _ = triton_attention(
             q,
             k,
diff --git a/server/text_generation_server/layers/attention/xpu.py b/server/text_generation_server/layers/attention/xpu.py
index 99c4699b..8b6cb87b 100644
--- a/server/text_generation_server/layers/attention/xpu.py
+++ b/server/text_generation_server/layers/attention/xpu.py
@@ -14,10 +14,7 @@ def attention(
     softmax_scale,
     window_size_left=-1,
 ):
-    if window_size_left != -1 and q.shape[0] > window_size_left:
-        raise ValueError(
-            f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
-        )
+    # We do not need to check window_size_left (not supported) here, as it is already checked ahead of time at model load.
     return ipex.llm.functional.varlen_attention(
         q,
         k,
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 3f488206..f39cb4a9 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -50,7 +50,7 @@ __all__ = [
 
 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 
-SLIDING_WINDOW_MESSAGE = "The backend {} does not support sliding window attention. TGI webserver was started max_batch_prefill_tokens={} larger than sliding_window={}. To use this model with the {} backend, please launch TGI with the argument `--max-batch-prefill-tokens` smaller than {}."
+SLIDING_WINDOW_MESSAGE = "The backend {} does not support sliding window attention. The TGI webserver was started with max_input_tokens={} larger than sliding_window={}. To use this model with the {} backend, please launch TGI with the argument `--max-input-tokens` smaller than {}."
 
 FLASH_ATTENTION = True
 
@@ -261,7 +261,7 @@ def get_model(
     speculate: Optional[int],
     dtype: Optional[str],
     trust_remote_code: bool,
-    max_batch_prefill_tokens: int,
+    max_input_tokens: int,
 ) -> Model:
     global FLASH_ATTENTION
     if dtype is None:
@@ -416,9 +416,13 @@ def get_model(
     )
 
     sliding_window = config_dict.get("sliding_window", -1)
-    if sliding_window != -1 and not SUPPORTS_WINDOWING:
-        logger.warning(
-            f"Flash attention is available, but on the backend {SYSTEM} doesn't support sliding window attention, which is required by model {model_id} for long contexts."
+    if (
+        (sliding_window is not None and sliding_window != -1)
+        and not SUPPORTS_WINDOWING
+        and max_input_tokens > sliding_window
+    ):
+        raise ValueError(
+            f"The backend {SYSTEM} does not support sliding window attention. The TGI webserver was started with max_input_tokens={max_input_tokens} larger than sliding_window={sliding_window}. To use this model with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than {sliding_window}."
         )
 
     if model_type == MAMBA:
@@ -706,22 +710,7 @@ def get_model(
         )
 
     if model_type == MISTRAL:
-        sliding_window = config_dict.get("sliding_window", -1)
-        if (
-            (sliding_window is not None and sliding_window != -1)
-            and not SUPPORTS_WINDOWING
-            and max_batch_prefill_tokens > sliding_window
-        ):
-            raise ValueError(
-                SLIDING_WINDOW_MESSAGE.format(
-                    SYSTEM,
-                    max_batch_prefill_tokens,
-                    sliding_window,
-                    SYSTEM,
-                    sliding_window,
-                )
-            )
-        elif FLASH_ATTENTION:
+        if FLASH_ATTENTION:
             return FlashMistral(
                 model_id,
                 revision,
@@ -743,22 +732,7 @@ def get_model(
         )
 
     if model_type == MIXTRAL:
-        sliding_window = config_dict.get("sliding_window", -1)
-        if (
-            (sliding_window is not None and sliding_window != -1)
-            and not SUPPORTS_WINDOWING
-            and max_batch_prefill_tokens > sliding_window
-        ):
-            raise ValueError(
-                SLIDING_WINDOW_MESSAGE.format(
-                    SYSTEM,
-                    max_batch_prefill_tokens,
-                    sliding_window,
-                    SYSTEM,
-                    sliding_window,
-                )
-            )
-        elif FLASH_ATTENTION:
+        if FLASH_ATTENTION:
             return FlashMixtral(
                 model_id,
                 revision,
@@ -780,21 +754,7 @@ def get_model(
         )
 
     if model_type == STARCODER2:
-        if (
-            (sliding_window is not None and sliding_window != -1)
-            and not SUPPORTS_WINDOWING
-            and max_batch_prefill_tokens > sliding_window
-        ):
-            raise ValueError(
-                SLIDING_WINDOW_MESSAGE.format(
-                    SYSTEM,
-                    max_batch_prefill_tokens,
-                    sliding_window,
-                    SYSTEM,
-                    sliding_window,
-                )
-            )
-        elif FLASH_ATTENTION:
+        if FLASH_ATTENTION:
             return FlashStarcoder2(
                 model_id,
                 revision,
@@ -817,22 +777,7 @@ def get_model(
         )
 
     if model_type == QWEN2:
-        sliding_window = config_dict.get("sliding_window", -1)
-        if (
-            (sliding_window is not None and sliding_window != -1)
-            and not SUPPORTS_WINDOWING
-            and max_batch_prefill_tokens > sliding_window
-        ):
-            raise ValueError(
-                SLIDING_WINDOW_MESSAGE.format(
-                    SYSTEM,
-                    max_batch_prefill_tokens,
-                    sliding_window,
-                    SYSTEM,
-                    sliding_window,
-                )
-            )
-        elif sliding_window is None or sliding_window != -1:
+        if FLASH_ATTENTION:
             return FlashQwen2(
                 model_id,
                 revision,
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 51543005..569b6925 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -199,7 +199,7 @@ def serve(
     dtype: Optional[str],
     trust_remote_code: bool,
     uds_path: Path,
-    max_batch_prefill_tokens: int,
+    max_input_tokens: int,
 ):
     async def serve_inner(
         model_id: str,
@@ -230,7 +230,7 @@ def serve(
                 speculate,
                 dtype,
                 trust_remote_code,
-                max_batch_prefill_tokens,
+                max_input_tokens,
             )
         except Exception:
             logger.exception("Error when initializing model")
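
For context, the per-call guards removed from rocm.py and xpu.py are subsumed by the single load-time check added to get_model in models/__init__.py. Below is a minimal standalone sketch of that check, built only from values already present in the diff; the helper name validate_sliding_window and the example numbers are illustrative and not part of the patch.

    from typing import Optional


    def validate_sliding_window(
        system: str,                    # the SYSTEM constant, e.g. "rocm" or "xpu"
        supports_windowing: bool,       # SUPPORTS_WINDOWING for the active attention backend
        sliding_window: Optional[int],  # config_dict.get("sliding_window", -1)
        max_input_tokens: int,          # forwarded by the launcher via --max-input-tokens
    ) -> None:
        # Fail only when a sliding window is configured, the backend cannot window,
        # and prompts could actually exceed the window.
        if (
            (sliding_window is not None and sliding_window != -1)
            and not supports_windowing
            and max_input_tokens > sliding_window
        ):
            raise ValueError(
                f"The backend {system} does not support sliding window attention. "
                f"The TGI webserver was started with max_input_tokens={max_input_tokens} "
                f"larger than sliding_window={sliding_window}. Please relaunch with "
                f"`--max-input-tokens` smaller than {sliding_window}."
            )


    # Example: a Mistral-style config (sliding_window=4096) on a backend without windowing.
    validate_sliding_window("rocm", False, 4096, 2048)    # passes: prompts fit inside the window
    # validate_sliding_window("rocm", False, 4096, 8192)  # raises ValueError

Because the condition also requires max_input_tokens > sliding_window, backends without windowing support can still serve these models as long as requests stay within the configured window, which is why the unconditional per-call ValueError in the attention kernels is no longer needed.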