This commit is contained in:
OlivierDehaene 2023-03-08 13:03:19 +01:00
parent b761d02713
commit 466963238c

View File

@ -115,7 +115,8 @@ fn main() -> ExitCode {
None => { None => {
// try to default to the number of available GPUs // try to default to the number of available GPUs
tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES"); tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES");
let n_devices = num_cuda_devices().expect("--num-shard and CUDA_VISIBLE_DEVICES are not set"); let n_devices = num_cuda_devices()
.expect("--num-shard and CUDA_VISIBLE_DEVICES are not set");
if n_devices <= 1 { if n_devices <= 1 {
panic!("`sharded` is true but only found {n_devices} CUDA devices"); panic!("`sharded` is true but only found {n_devices} CUDA devices");
} }
@ -144,7 +145,7 @@ fn main() -> ExitCode {
match num_shard { match num_shard {
// get num_shard from CUDA_VISIBLE_DEVICES or default to a single shard // get num_shard from CUDA_VISIBLE_DEVICES or default to a single shard
None => num_cuda_devices().unwrap_or(1), None => num_cuda_devices().unwrap_or(1),
Some(num_shard) => num_shard Some(num_shard) => num_shard,
} }
}; };
tracing::info!("Sharding model on {num_shard} processes"); tracing::info!("Sharding model on {num_shard} processes");
@ -155,7 +156,7 @@ fn main() -> ExitCode {
ctrlc::set_handler(move || { ctrlc::set_handler(move || {
r.store(false, Ordering::SeqCst); r.store(false, Ordering::SeqCst);
}) })
.expect("Error setting Ctrl-C handler"); .expect("Error setting Ctrl-C handler");
// Check if model_id is a local model // Check if model_id is a local model
let local_path = Path::new(&model_id); let local_path = Path::new(&model_id);
@ -671,10 +672,9 @@ fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receive
let _ = shutdown_receiver.recv(); let _ = shutdown_receiver.recv();
} }
fn num_cuda_devices() -> Option<usize> { fn num_cuda_devices() -> Option<usize> {
if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") { if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") {
let n_devices = cuda_visible_devices.split(",").count(); let n_devices = cuda_visible_devices.split(',').count();
return Some(n_devices); return Some(n_devices);
} }
None None