fix: count gpu uuids if NVIDIA_VISIBLE_DEVICES env set to all (#3230)

This commit is contained in:
drbh 2025-05-16 11:48:58 -04:00 committed by GitHub
parent 18cbecfb38
commit 58934c8b61
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -1263,7 +1263,23 @@ fn num_cuda_devices() -> Option<usize> {
let devices = match env::var("CUDA_VISIBLE_DEVICES") {
Ok(devices) => devices,
Err(_) => match env::var("NVIDIA_VISIBLE_DEVICES") {
Ok(devices) => devices,
Ok(devices) => {
if devices.trim() == "all" {
// Count the number of all GPUs via nvidia-smi
let output = Command::new("nvidia-smi")
.args(["--query-gpu=uuid", "--format=csv,noheader"])
.output()
.ok()?;
String::from_utf8_lossy(&output.stdout)
.lines()
.filter(|line| !line.trim().is_empty())
.count()
.to_string()
} else {
devices
}
}
Err(_) => env::var("ZE_AFFINITY_MASK").ok()?,
},
};