diff --git a/launcher/src/main.rs b/launcher/src/main.rs index ed6cf8e5..b42ed0c5 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -306,6 +306,14 @@ fn shard_manager( )); }; + // If the CUDA_VISIBLE_DEVICES env var is set, pass it to the shard + if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") { + env.push(( + "CUDA_VISIBLE_DEVICES".parse().unwrap(), + cuda_visible_devices.parse().unwrap(), + )); + }; + // Start process tracing::info!("Starting shard {}", rank); let mut p = match Popen::create(