diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md
index f9b76ed15..627aff933 100644
--- a/docs/source/basic_tutorials/launcher.md
+++ b/docs/source/basic_tutorials/launcher.md
@@ -168,7 +168,7 @@ Options:
 ## MAX_BATCH_PREFILL_TOKENS
 ```shell
       --max-batch-prefill-tokens <MAX_BATCH_PREFILL_TOKENS>
-          Limits the number of tokens for the prefill operation. Since this operation take the most memory and is compute bound, it is interesting to limit the number of requests that can be sent
+          Limits the number of tokens for the prefill operation. Since this operation takes the most memory and is compute bound, it is useful to limit the number of requests that can be sent. Defaults to min(max_input_length + 50, 16384) to give a bit of room
 
           [env: MAX_BATCH_PREFILL_TOKENS=]
 
@@ -215,10 +215,9 @@ Options:
 ## CUDA_GRAPHS
 ```shell
       --cuda-graphs <CUDA_GRAPHS>
-          Specify the batch sizes to compute cuda graphs for. Use "0" to disable
+          Specify the batch sizes to compute cuda graphs for. Use "0" to disable. Defaults to "1,2,4,8,16,32"
 
           [env: CUDA_GRAPHS=]
-          [default: 1,2,4,8,16,32,64,96,128]
 
 ```
 ## HOSTNAME
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index d0f9e3cf1..a4bb2e3d7 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -256,6 +256,7 @@ struct Args {
     /// Limits the number of tokens for the prefill operation.
     /// Since this operation take the most memory and is compute bound, it is interesting
     /// to limit the number of requests that can be sent.
+    /// Defaults to min(max_input_length + 50, 16384) to give a bit of room.
     #[clap(long, env)]
     max_batch_prefill_tokens: Option<u32>,
 
@@ -306,13 +307,9 @@ struct Args {
 
     /// Specify the batch sizes to compute cuda graphs for.
     /// Use "0" to disable.
-    #[clap(
-        long,
-        env,
-        value_delimiter = ',',
-        default_value = "1,2,4,8,16,32,64,96,128"
-    )]
-    cuda_graphs: Vec<usize>,
+    /// Defaults to "1,2,4,8,16,32".
+    #[clap(long, env, value_delimiter = ',')]
+    cuda_graphs: Option<Vec<usize>>,
 
     /// The IP address to listen on
     #[clap(default_value = "0.0.0.0", long, env)]
@@ -956,6 +953,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
 fn spawn_shards(
     num_shard: usize,
     args: &Args,
+    cuda_graphs: Vec<usize>,
     shutdown: Arc<AtomicBool>,
     shutdown_receiver: &mpsc::Receiver<()>,
     shutdown_sender: mpsc::Sender<()>,
@@ -983,11 +981,7 @@ fn spawn_shards(
     let disable_custom_kernels = args.disable_custom_kernels;
     let watermark_gamma = args.watermark_gamma;
     let watermark_delta = args.watermark_delta;
-    let cuda_graphs: Vec<usize> = args
-        .cuda_graphs
-        .iter()
-        .filter_map(|&c| if c > 0 { Some(c) } else { None })
-        .collect();
+    let cuda_graphs_clone = cuda_graphs.clone();
     let cuda_memory_fraction = args.cuda_memory_fraction;
     let rope_scaling = args.rope_scaling;
     let rope_factor = args.rope_factor;
@@ -1009,7 +1003,7 @@ fn spawn_shards(
             disable_custom_kernels,
             watermark_gamma,
             watermark_delta,
-            cuda_graphs,
+            cuda_graphs_clone,
             cuda_memory_fraction,
             rope_scaling,
             rope_factor,
@@ -1363,6 +1357,27 @@ fn main() -> Result<(), LauncherError> {
         )));
     }
 
+    let cuda_graphs = match (&args.cuda_graphs, &args.quantize) {
+        (Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(),
+        #[allow(deprecated)]
+        (
+            None,
+            Some(
+                Quantization::Bitsandbytes
+                | Quantization::BitsandbytesNF4
+                | Quantization::BitsandbytesFP4,
+            ),
+        ) => {
+            tracing::info!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
+            vec![]
+        }
+        _ => {
+            let cuda_graphs = vec![1, 2, 4, 8, 16, 32];
+            tracing::info!("Using default cuda graphs {cuda_graphs:?}");
+            cuda_graphs
+        }
+    };
+
     if args.validation_workers == 0 {
         return Err(LauncherError::ArgumentValidation(
             "`validation_workers` must be > 0".to_string(),
@@ -1437,6 +1452,7 @@ fn main() -> Result<(), LauncherError> {
         spawn_shards(
             num_shard,
             &args,
+            cuda_graphs,
             shutdown.clone(),
             &shutdown_receiver,
             shutdown_sender,
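Two of the defaults introduced above are worth spelling out. First, the prefill-token fallback described in the new doc comment: the code that actually computes it lives elsewhere in the launcher and is not part of this excerpt, so the sketch below is hypothetical; `default_max_batch_prefill_tokens` and its parameters are invented for illustration, and only the formula min(max_input_length + 50, 16384) comes from the diff.

```rust
/// Hypothetical sketch of the documented prefill-token fallback. The actual
/// resolution code is elsewhere in the launcher and not part of this diff;
/// only the formula min(max_input_length + 50, 16384) comes from the doc text.
fn default_max_batch_prefill_tokens(cli_value: Option<u32>, max_input_length: u32) -> u32 {
    // "+ 50" gives the prefill batch a bit of room beyond a single
    // maximum-length input; 16384 caps the memory-hungry prefill step.
    cli_value.unwrap_or_else(|| (max_input_length + 50).min(16384))
}

fn main() {
    assert_eq!(default_max_batch_prefill_tokens(None, 1024), 1074);
    assert_eq!(default_max_batch_prefill_tokens(None, 32768), 16384);
    // An explicit --max-batch-prefill-tokens value is taken as-is.
    assert_eq!(default_max_batch_prefill_tokens(Some(4096), 1024), 4096);
}
```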
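Second, the CUDA-graph resolution added to `main()`. Here is a minimal, self-contained sketch of that match: `Quantization` is trimmed to the variants the match inspects, `resolve_cuda_graphs` is a hypothetical helper name (the launcher inlines the logic in `main`), and the `tracing` log lines are dropped.

```rust
/// Trimmed-down stand-in for the launcher's `Quantization` enum; only the
/// variants the match below inspects are kept.
enum Quantization {
    Bitsandbytes,
    BitsandbytesNF4,
    BitsandbytesFP4,
}

/// Resolve the effective CUDA graph batch sizes from the optional CLI value
/// and the quantization mode, mirroring the match added in `main()`.
fn resolve_cuda_graphs(
    cli_value: &Option<Vec<usize>>,
    quantize: &Option<Quantization>,
) -> Vec<usize> {
    match (cli_value, quantize) {
        // An explicit --cuda-graphs / CUDA_GRAPHS value always wins; zeros
        // are filtered out so "0" disables graphs entirely.
        (Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(),
        // Bitsandbytes kernels don't work with CUDA graphs, so those
        // quantizers default to "disabled".
        (
            None,
            Some(
                Quantization::Bitsandbytes
                | Quantization::BitsandbytesNF4
                | Quantization::BitsandbytesFP4,
            ),
        ) => vec![],
        // Everything else falls back to the documented default.
        _ => vec![1, 2, 4, 8, 16, 32],
    }
}

fn main() {
    // No CLI value, no quantization: documented default.
    assert_eq!(resolve_cuda_graphs(&None, &None), vec![1, 2, 4, 8, 16, 32]);
    // No CLI value with bitsandbytes: graphs deactivated.
    assert!(resolve_cuda_graphs(&None, &Some(Quantization::Bitsandbytes)).is_empty());
    // Explicit "0" disables graphs regardless of quantization.
    assert!(resolve_cuda_graphs(&Some(vec![0]), &None).is_empty());
    // Explicit sizes are honored as-is.
    assert_eq!(resolve_cuda_graphs(&Some(vec![1, 8]), &None), vec![1, 8]);
}
```

Placing the explicit-value arm first keeps the CLI flag authoritative: a user-supplied list overrides both the bitsandbytes deactivation and the built-in default, and filtering out zeros is what makes the documented `--cuda-graphs 0` escape hatch disable graphs entirely.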