diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index a82ad12f..ee80eb00 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -1263,7 +1263,35 @@ fn num_cuda_devices() -> Option {
     let devices = match env::var("CUDA_VISIBLE_DEVICES") {
         Ok(devices) => devices,
         Err(_) => match env::var("NVIDIA_VISIBLE_DEVICES") {
-            Ok(devices) => devices,
+            Ok(devices) => {
+                if devices.trim() == "all" {
+                    // "all" is a keyword, not a device list: expand it into the
+                    // comma-separated GPU UUIDs reported by nvidia-smi so this
+                    // branch yields the same list format as every other branch.
+                    let output = Command::new("nvidia-smi")
+                        .args(["--query-gpu=uuid", "--format=csv,noheader"])
+                        .output()
+                        .ok()?;
+                    // A failing nvidia-smi tells us nothing about the GPUs.
+                    if !output.status.success() {
+                        return None;
+                    }
+
+                    let uuids = String::from_utf8_lossy(&output.stdout)
+                        .lines()
+                        .map(str::trim)
+                        .filter(|line| !line.is_empty())
+                        .collect::<Vec<_>>()
+                        .join(",");
+                    // No GPU listed: an empty string is not a valid device list.
+                    if uuids.is_empty() {
+                        return None;
+                    }
+                    uuids
+                } else {
+                    devices
+                }
+            }
             Err(_) => env::var("ZE_AFFINITY_MASK").ok()?,
         },
     };