From 146e72c3be4f5bfe88b9baae4c2ff6a2ce41eda9 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Tue, 30 May 2023 12:52:18 +0200 Subject: [PATCH] fix(launcher): parse num cuda devices from CUDA_VISIBLE_DEVICES and NVIDIA_VISIBLE_DEVICES --- launcher/src/main.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index dc12c90f..a863ea4a 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -455,11 +455,11 @@ fn shutdown_shards(shutdown: Arc>, shutdown_receiver: &mpsc::Receive } fn num_cuda_devices() -> Option { - if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") { - let n_devices = cuda_visible_devices.split(',').count(); - return Some(n_devices); - } - None + let devices = env::var("CUDA_VISIBLE_DEVICES") + .map_err(|_| env::var("NVIDIA_VISIBLE_DEVICES")) + .ok()?; + let n_devices = devices.split(',').count(); + Some(n_devices) } #[derive(Deserialize)] @@ -509,9 +509,9 @@ fn find_num_shards(sharded: Option, num_shard: Option) -> usize { let num_shard = match (sharded, num_shard) { (Some(true), None) => { // try to default to the number of available GPUs - tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES"); - let n_devices = - num_cuda_devices().expect("--num-shard and CUDA_VISIBLE_DEVICES are not set"); + tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES"); + let n_devices = num_cuda_devices() + .expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES are not set"); if n_devices <= 1 { panic!("`sharded` is true but only found {n_devices} CUDA devices"); }