From 390ec5aea80d55e6b1f4aa8b77e59262b9a48d93 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 25 Apr 2023 20:52:37 +0200 Subject: [PATCH] Cleanup up a bit the launcher. --- launcher/src/main.rs | 128 +++++++++++++++++++++++-------------------- 1 file changed, 69 insertions(+), 59 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 3cd925d4..b59b0cb4 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -572,65 +572,11 @@ fn spawn_shards( Ok(()) } -fn main() -> Result<(), LauncherError> { - // Pattern match configuration - let args = Args::parse(); - - if args.json_output { - tracing_subscriber::fmt().json().init(); - } else { - tracing_subscriber::fmt().compact().init(); - } - - tracing::info!("{:?}", args); - - let num_shard = find_num_shards(args.sharded, args.num_shard); - if num_shard > 1 { - tracing::info!("Sharding model on {num_shard} processes"); - } - - // Signal handler - let running = Arc::new(AtomicBool::new(true)); - let r = running.clone(); - ctrlc::set_handler(move || { - r.store(false, Ordering::SeqCst); - }) - .expect("Error setting Ctrl-C handler"); - - // Check if model_id is a local model - let local_path = Path::new(&args.model_id); - let is_local_model = local_path.exists() && local_path.is_dir(); - - // Download weights for sharded models - if !is_local_model && args.weights_cache_override.is_none() && num_shard > 1 { - download_model(&args, running.clone())?; - } - - // Shared shutdown bool - let shutdown = Arc::new(Mutex::new(false)); - // Shared shutdown channel - // When shutting down, the main thread will wait for all senders to be dropped - let (shutdown_sender, shutdown_receiver) = mpsc::channel(); - - // Shared channel to track shard status - let (status_sender, status_receiver) = mpsc::channel(); - - spawn_shards( - num_shard, - &args, - shutdown.clone(), - &shutdown_receiver, - shutdown_sender, - &status_receiver, - status_sender, - running.clone(), - )?; - // We might have received a termination signal - if !running.load(Ordering::SeqCst) { - shutdown_shards(shutdown, &shutdown_receiver); - return Ok(()); - } - +fn spawn_webserver( + args: Args, + shutdown: Arc>, + shutdown_receiver: &mpsc::Receiver<()>, +) -> Result { // All shard started // Start webserver tracing::info!("Starting Webserver"); @@ -739,6 +685,70 @@ fn main() -> Result<(), LauncherError> { println!("{}", line.unwrap()); } }); + Ok(webserver) +} + +fn main() -> Result<(), LauncherError> { + // Pattern match configuration + let args = Args::parse(); + + if args.json_output { + tracing_subscriber::fmt().json().init(); + } else { + tracing_subscriber::fmt().compact().init(); + } + + tracing::info!("{:?}", args); + + let num_shard = find_num_shards(args.sharded, args.num_shard); + if num_shard > 1 { + tracing::info!("Sharding model on {num_shard} processes"); + } + + // Signal handler + let running = Arc::new(AtomicBool::new(true)); + let r = running.clone(); + ctrlc::set_handler(move || { + r.store(false, Ordering::SeqCst); + }) + .expect("Error setting Ctrl-C handler"); + + // Check if model_id is a local model + let local_path = Path::new(&args.model_id); + let is_local_model = local_path.exists() && local_path.is_dir(); + + // Download weights for sharded models + if !is_local_model && args.weights_cache_override.is_none() && num_shard > 1 { + download_model(&args, running.clone())?; + } + + // Shared shutdown bool + let shutdown = Arc::new(Mutex::new(false)); + // Shared shutdown channel + // When shutting down, the main thread will wait for all senders to be dropped + let (shutdown_sender, shutdown_receiver) = mpsc::channel(); + + // Shared channel to track shard status + let (status_sender, status_receiver) = mpsc::channel(); + + spawn_shards( + num_shard, + &args, + shutdown.clone(), + &shutdown_receiver, + shutdown_sender, + &status_receiver, + status_sender, + running.clone(), + )?; + + // We might have received a termination signal + if !running.load(Ordering::SeqCst) { + shutdown_shards(shutdown, &shutdown_receiver); + return Ok(()); + } + + let mut webserver = spawn_webserver(args, shutdown.clone(), &shutdown_receiver)?; // Default exit code let mut exit_code = Ok(());