From 6951962ffddf91fd8f9b4be20f53ad50310b0c86 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 4 Apr 2024 12:59:29 +0000 Subject: [PATCH] Clarify disabling. --- docs/source/basic_tutorials/launcher.md | 2 +- launcher/src/main.rs | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md index ce406ca4..86394ff7 100644 --- a/docs/source/basic_tutorials/launcher.md +++ b/docs/source/basic_tutorials/launcher.md @@ -209,7 +209,7 @@ Options: ## CUDA_GRAPHS ```shell --cuda-graphs - Specify the batch sizes to compute cuda graphs for + Specify the batch sizes to compute cuda graphs for. Use "0" to disable [env: CUDA_GRAPHS=] [default: 1,2,4,8,16,32,64,96,128] diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 8d50aa81..63676392 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -284,7 +284,8 @@ struct Args { #[clap(long, env)] max_batch_size: Option, - /// Specify the batch sizes to compute cuda graphs for + /// Specify the batch sizes to compute cuda graphs for. + /// Use "0" to disable. #[clap( long, env, @@ -954,7 +955,11 @@ fn spawn_shards( let disable_custom_kernels = args.disable_custom_kernels; let watermark_gamma = args.watermark_gamma; let watermark_delta = args.watermark_delta; - let cuda_graphs = args.cuda_graphs.clone(); + let cuda_graphs: Vec = args + .cuda_graphs + .iter() + .filter_map(|&c| if c > 0 { Some(c) } else { None }) + .collect(); let cuda_memory_fraction = args.cuda_memory_fraction; let rope_scaling = args.rope_scaling; let rope_factor = args.rope_factor;