diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 96ad18f9..1865cf90 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -115,13 +115,11 @@ fn main() -> ExitCode { None => { // try to default to the number of available GPUs tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES"); - let cuda_visible_devices = env::var("CUDA_VISIBLE_DEVICES") + let n_devices = num_cuda_devices() .expect("--num-shard and CUDA_VISIBLE_DEVICES are not set"); - let n_devices = cuda_visible_devices.split(",").count(); if n_devices <= 1 { panic!("`sharded` is true but only found {n_devices} CUDA devices"); } - tracing::info!("Sharding on {n_devices} found CUDA devices"); n_devices } Some(num_shard) => { @@ -144,9 +142,19 @@ fn main() -> ExitCode { } } } else { - // default to a single shard - num_shard.unwrap_or(1) + match num_shard { + // get num_shard from CUDA_VISIBLE_DEVICES or default to a single shard + None => num_cuda_devices().unwrap_or(1), + Some(num_shard) => num_shard, + } }; + if num_shard < 1 { + panic!("`num_shard` cannot be < 1"); + } + + if num_shard > 1 { + tracing::info!("Sharding model on {num_shard} processes"); + } // Signal handler let running = Arc::new(AtomicBool::new(true)); @@ -669,3 +677,11 @@ fn shutdown_shards(shutdown: Arc>, shutdown_receiver: &mpsc::Receive // This will block till all shutdown_sender are dropped let _ = shutdown_receiver.recv(); } + +fn num_cuda_devices() -> Option { + if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") { + let n_devices = cuda_visible_devices.split(',').count(); + return Some(n_devices); + } + None +}