diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 7450b3f4..b59b0cb4 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -4,7 +4,6 @@ use std::env;
 use std::ffi::OsString;
 use std::io::{BufRead, BufReader, Read};
 use std::path::Path;
-use std::process::ExitCode;
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::mpsc::TryRecvError;
 use std::sync::Arc;
@@ -73,445 +72,6 @@ struct Args {
     watermark_delta: Option<f32>,
 }
 
-fn main() -> ExitCode {
-    // Pattern match configuration
-    let args = Args::parse();
-
-    if args.json_output {
-        tracing_subscriber::fmt().json().init();
-    } else {
-        tracing_subscriber::fmt().compact().init();
-    }
-
-    tracing::info!("{:?}", args);
-
-    let Args {
-        model_id,
-        revision,
-        sharded,
-        num_shard,
-        quantize,
-        max_concurrent_requests,
-        max_best_of,
-        max_stop_sequences,
-        max_input_length,
-        max_total_tokens,
-        max_batch_size,
-        max_batch_total_tokens,
-        waiting_served_ratio,
-        max_waiting_tokens,
-        port,
-        shard_uds_path,
-        master_addr,
-        master_port,
-        huggingface_hub_cache,
-        weights_cache_override,
-        disable_custom_kernels,
-        json_output,
-        otlp_endpoint,
-        cors_allow_origin,
-        watermark_gamma,
-        watermark_delta,
-    } = args;
-
-    // get the number of shards given `sharded` and `num_shard`
-    let num_shard = if let Some(sharded) = sharded {
-        // sharded is set
-        match sharded {
-            // sharded is set and true
-            true => {
-                match num_shard {
-                    None => {
-                        // try to default to the number of available GPUs
-                        tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES");
-                        let n_devices = num_cuda_devices()
-                            .expect("--num-shard and CUDA_VISIBLE_DEVICES are not set");
-                        if n_devices <= 1 {
-                            panic!("`sharded` is true but only found {n_devices} CUDA devices");
-                        }
-                        n_devices
-                    }
-                    Some(num_shard) => {
-                        // we can't have only one shard while sharded
-                        if num_shard <= 1 {
-                            panic!("`sharded` is true but `num_shard` <= 1");
-                        }
-                        num_shard
-                    }
-                }
-            }
-            // sharded is set and false
-            false => {
-                let num_shard = num_shard.unwrap_or(1);
-                // we can't have more than one shard while not sharded
-                if num_shard != 1 {
-                    panic!("`sharded` is false but `num_shard` != 1");
-                }
-                num_shard
-            }
-        }
-    } else {
-        match num_shard {
-            // get num_shard from CUDA_VISIBLE_DEVICES or default to a single shard
-            None => num_cuda_devices().unwrap_or(1),
-            Some(num_shard) => num_shard,
-        }
-    };
-    if num_shard < 1 {
-        panic!("`num_shard` cannot be < 1");
-    }
-
-    if num_shard > 1 {
-        tracing::info!("Sharding model on {num_shard} processes");
-    }
-
-    // Signal handler
-    let running = Arc::new(AtomicBool::new(true));
-    let r = running.clone();
-    ctrlc::set_handler(move || {
-        r.store(false, Ordering::SeqCst);
-    })
-    .expect("Error setting Ctrl-C handler");
-
-    // Check if model_id is a local model
-    let local_path = Path::new(&model_id);
-    let is_local_model = local_path.exists() && local_path.is_dir();
-
-    // Download weights for sharded models
-    if !is_local_model && weights_cache_override.is_none() && num_shard > 1 {
-        let mut download_argv = vec![
-            "text-generation-server".to_string(),
-            "download-weights".to_string(),
-            model_id.clone(),
-            "--extension".to_string(),
-            ".safetensors".to_string(),
-            "--logger-level".to_string(),
-            "INFO".to_string(),
-            "--json-output".to_string(),
-        ];
-
-        // Model optional revision
-        if let Some(ref revision) = revision {
-            download_argv.push("--revision".to_string());
-            download_argv.push(revision.to_string())
-        }
-
-        // Copy current process env
-        let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
-
-        // If huggingface_hub_cache is set, pass it to the shard
-        // Useful when running inside a docker container
-        if let Some(ref huggingface_hub_cache) = huggingface_hub_cache {
-            env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
-        };
-
-        // Enable hf transfer for insane download speeds
-        let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
-        env.push((
-            "HF_HUB_ENABLE_HF_TRANSFER".into(),
-            enable_hf_transfer.into(),
-        ));
-
-        // Parse Inference API token
-        if let Ok(api_token) = env::var("HF_API_TOKEN") {
-            env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
-        };
-
-        // Start process
-        tracing::info!("Starting download process.");
-        let mut download_process = match Popen::create(
-            &download_argv,
-            PopenConfig {
-                stdout: Redirection::Pipe,
-                stderr: Redirection::Pipe,
-                // Needed for the shutdown procedure
-                setpgid: true,
-                env: Some(env),
-                ..Default::default()
-            },
-        ) {
-            Ok(p) => p,
-            Err(err) => {
-                if let PopenError::IoError(ref err) = err {
-                    if err.kind() == io::ErrorKind::NotFound {
-                        tracing::error!("text-generation-server not found in PATH");
-                        tracing::error!("Please install it with `make install-server`")
-                    }
-                }
-                return ExitCode::FAILURE;
-            }
-        };
-
-        // Redirect STDOUT to the console
-        let download_stdout = download_process.stdout.take().unwrap();
-        thread::spawn(move || {
-            // Enter download tracing span
-            let stdout = BufReader::new(download_stdout);
-            let _span = tracing::span!(tracing::Level::INFO, "download").entered();
-            for line in stdout.lines() {
-                // Parse loguru logs
-                if let Ok(log) = serde_json::from_str::<PythonLogMessage>(&line.unwrap()) {
-                    log.trace();
-                }
-            }
-        });
-
-        loop {
-            if let Some(status) = download_process.poll() {
-                match status {
-                    ExitStatus::Exited(exit_code) => {
-                        if exit_code == 0 {
-                            tracing::info!("Successfully downloaded weights.");
-                            break;
-                        } else {
-                            let mut err = String::new();
-                            download_process
-                                .stderr
-                                .take()
-                                .unwrap()
-                                .read_to_string(&mut err)
-                                .unwrap();
-                            tracing::error!("Download encountered an error: {err}");
-                            return ExitCode::FAILURE;
-                        }
-                    }
-                    _ => {
-                        tracing::error!("Download process exited with an unknown status.");
-                        return ExitCode::FAILURE;
-                    }
-                }
-            }
-            if !running.load(Ordering::SeqCst) {
-                download_process.terminate().unwrap();
-                tracing::info!("Waiting for download process to gracefully shutdown");
-                download_process
-                    .wait_timeout(Duration::from_secs(90))
-                    .unwrap();
-                tracing::info!("Download process terminated");
-                return ExitCode::SUCCESS;
-            }
-            sleep(Duration::from_millis(100));
-        }
-    }
-
-    // Shared shutdown bool
-    let shutdown = Arc::new(Mutex::new(false));
-    // Shared shutdown channel
-    // When shutting down, the main thread will wait for all senders to be dropped
-    let (shutdown_sender, shutdown_receiver) = mpsc::channel();
-
-    // Shared channel to track shard status
-    let (status_sender, status_receiver) = mpsc::channel();
-
-    // Start shard processes
-    for rank in 0..num_shard {
-        let model_id = model_id.clone();
-        let revision = revision.clone();
-        let uds_path = shard_uds_path.clone();
-        let master_addr = master_addr.clone();
-        let huggingface_hub_cache = huggingface_hub_cache.clone();
-        let weights_cache_override = weights_cache_override.clone();
-        let status_sender = status_sender.clone();
-        let shutdown = shutdown.clone();
-        let shutdown_sender = shutdown_sender.clone();
-        let otlp_endpoint = otlp_endpoint.clone();
-        thread::spawn(move || {
-            shard_manager(
-                model_id,
-                revision,
-                quantize,
-                uds_path,
-                rank,
-                num_shard,
-                master_addr,
-                master_port,
-                huggingface_hub_cache,
-                weights_cache_override,
-                disable_custom_kernels,
-                watermark_gamma,
-                watermark_delta,
-                otlp_endpoint,
-                status_sender,
-                shutdown,
-                shutdown_sender,
-            )
-        });
-    }
-    drop(shutdown_sender);
-
-    // Wait for shard to start
-    let mut shard_ready = 0;
-    while running.load(Ordering::SeqCst) {
-        match status_receiver.try_recv() {
-            Ok(ShardStatus::Ready) => {
-                shard_ready += 1;
-                if shard_ready == num_shard {
-                    break;
-                }
-            }
-            Err(TryRecvError::Empty) => {
-                sleep(Duration::from_millis(100));
-            }
-            Ok(ShardStatus::Failed((rank, err))) => {
-                tracing::error!("Shard {} failed to start:\n{}", rank, err);
-                shutdown_shards(shutdown, &shutdown_receiver);
-                return ExitCode::FAILURE;
-            }
-            Err(TryRecvError::Disconnected) => {
-                tracing::error!("Shard status channel disconnected");
-                shutdown_shards(shutdown, &shutdown_receiver);
-                return ExitCode::FAILURE;
-            }
-        }
-    }
-
-    // We might have received a termination signal
-    if !running.load(Ordering::SeqCst) {
-        shutdown_shards(shutdown, &shutdown_receiver);
-        return ExitCode::SUCCESS;
-    }
-
-    // All shard started
-    // Start webserver
-    tracing::info!("Starting Webserver");
-    let mut argv = vec![
-        "text-generation-router".to_string(),
-        "--max-concurrent-requests".to_string(),
-        max_concurrent_requests.to_string(),
-        "--max-best-of".to_string(),
-        max_best_of.to_string(),
-        "--max-stop-sequences".to_string(),
-        max_stop_sequences.to_string(),
-        "--max-input-length".to_string(),
-        max_input_length.to_string(),
-        "--max-total-tokens".to_string(),
-        max_total_tokens.to_string(),
-        "--waiting-served-ratio".to_string(),
-        waiting_served_ratio.to_string(),
-        "--max-waiting-tokens".to_string(),
-        max_waiting_tokens.to_string(),
-        "--port".to_string(),
-        port.to_string(),
-        "--master-shard-uds-path".to_string(),
-        format!("{shard_uds_path}-0"),
-        "--tokenizer-name".to_string(),
-        model_id,
-    ];
-
-    // Deprecate max_batch_size
-    if let Some(max_batch_size) = max_batch_size {
-        argv.push("--max-batch-size".to_string());
-        argv.push(max_batch_size.to_string())
-    } else {
-        argv.push("--max-batch-total-tokens".to_string());
-        argv.push(max_batch_total_tokens.to_string())
-    }
-
-    // Model optional revision
-    if let Some(ref revision) = revision {
-        argv.push("--revision".to_string());
-        argv.push(revision.to_string())
-    }
-
-    if json_output {
-        argv.push("--json-output".to_string());
-    }
-
-    // OpenTelemetry
-    if let Some(otlp_endpoint) = otlp_endpoint {
-        argv.push("--otlp-endpoint".to_string());
-        argv.push(otlp_endpoint);
-    }
-
-    // CORS origins
-    for origin in cors_allow_origin.into_iter() {
-        argv.push("--cors-allow-origin".to_string());
-        argv.push(origin);
-    }
-
-    // Copy current process env
-    let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
-
-    // Parse Inference API token
-    if let Ok(api_token) = env::var("HF_API_TOKEN") {
-        env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
-    };
-
-    let mut webserver = match Popen::create(
-        &argv,
-        PopenConfig {
-            stdout: Redirection::Pipe,
-            stderr: Redirection::Pipe,
-            // Needed for the shutdown procedure
-            setpgid: true,
-            env: Some(env),
-            ..Default::default()
-        },
-    ) {
-        Ok(p) => p,
-        Err(err) => {
-            tracing::error!("Failed to start webserver: {}", err);
-            if let PopenError::IoError(err) = err {
-                if err.kind() == io::ErrorKind::NotFound {
-                    tracing::error!("text-generation-router not found in PATH");
-                    tracing::error!("Please install it with `make install-router`")
-                }
-            } else {
-                tracing::error!("{}", err);
-            }
-
-            shutdown_shards(shutdown, &shutdown_receiver);
-            return ExitCode::FAILURE;
-        }
-    };
-
-    // Redirect STDOUT and STDERR to the console
-    let webserver_stdout = webserver.stdout.take().unwrap();
-    let webserver_stderr = webserver.stderr.take().unwrap();
-
-    thread::spawn(move || {
-        let stdout = BufReader::new(webserver_stdout);
-        let stderr = BufReader::new(webserver_stderr);
-        for line in stdout.lines() {
-            println!("{}", line.unwrap());
-        }
-        for line in stderr.lines() {
-            println!("{}", line.unwrap());
-        }
-    });
-
-    // Default exit code
-    let mut exit_code = ExitCode::SUCCESS;
-
-    while running.load(Ordering::SeqCst) {
-        if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
-            tracing::error!("Shard {rank} failed:\n{err}");
-            exit_code = ExitCode::FAILURE;
-            break;
-        };
-
-        match webserver.poll() {
-            Some(_) => {
-                tracing::error!("Webserver Crashed");
-                shutdown_shards(shutdown, &shutdown_receiver);
-                return ExitCode::FAILURE;
-            }
-            None => {
-                sleep(Duration::from_millis(100));
-            }
-        };
-    }
-
-    // Graceful termination
-    webserver.terminate().unwrap();
-    tracing::info!("Waiting for webserver to gracefully shutdown");
-    webserver.wait_timeout(Duration::from_secs(90)).unwrap();
-    tracing::info!("Webserver terminated");
-    shutdown_shards(shutdown, &shutdown_receiver);
-
-    exit_code
-}
-
 #[derive(Debug)]
 enum ShardStatus {
     Ready,
@@ -774,3 +334,450 @@ impl PythonLogMessage {
         }
     }
 }
+
+fn find_num_shards(sharded: Option<bool>, num_shard: Option<usize>) -> usize {
+    // get the number of shards given `sharded` and `num_shard`
+    let num_shard = match (sharded, num_shard) {
+        (Some(true), None) => {
+            // try to default to the number of available GPUs
+            tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES");
+            let n_devices =
+                num_cuda_devices().expect("--num-shard and CUDA_VISIBLE_DEVICES are not set");
+            if n_devices <= 1 {
+                panic!("`sharded` is true but only found {n_devices} CUDA devices");
+            }
+            n_devices
+        }
+        (Some(true), Some(num_shard)) => {
+            // we can't have only one shard while sharded
+            if num_shard <= 1 {
+                panic!("`sharded` is true but `num_shard` <= 1");
+            }
+            num_shard
+        }
+        (Some(false), Some(num_shard)) => num_shard,
+        (Some(false), None) => 1,
+        (None, None) => num_cuda_devices().unwrap_or(1),
+        (None, Some(num_shard)) => num_shard,
+    };
+    if num_shard < 1 {
+        panic!("`num_shard` cannot be < 1");
+    }
+    num_shard
+}
+
+#[derive(Debug)]
+enum LauncherError {
+    DownloadError,
+    ShardCannotStart,
+    ShardDisconnected,
+    ShardFailed,
+    WebserverFailed,
+    WebserverCannotStart,
+}
+
+fn download_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
+    let mut download_argv = vec![
+        "text-generation-server".to_string(),
+        "download-weights".to_string(),
+        args.model_id.to_string(),
+        "--extension".to_string(),
+        ".safetensors".to_string(),
+        "--logger-level".to_string(),
+        "INFO".to_string(),
+        "--json-output".to_string(),
+    ];
+
+    // Model optional revision
+    if let Some(revision) = &args.revision {
+        download_argv.push("--revision".to_string());
+        download_argv.push(revision.to_string())
+    }
+
+    // Copy current process env
+    let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
+
+    // If huggingface_hub_cache is set, pass it to the shard
+    // Useful when running inside a docker container
+    if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache {
+        env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
+    };
+
+    // Enable hf transfer for insane download speeds
+    let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string()); + env.push(( + "HF_HUB_ENABLE_HF_TRANSFER".into(), + enable_hf_transfer.into(), + )); + + // Parse Inference API token + if let Ok(api_token) = env::var("HF_API_TOKEN") { + env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) + }; + + // Start process + tracing::info!("Starting download process."); + let mut download_process = match Popen::create( + &download_argv, + PopenConfig { + stdout: Redirection::Pipe, + stderr: Redirection::Pipe, + // Needed for the shutdown procedure + setpgid: true, + env: Some(env), + ..Default::default() + }, + ) { + Ok(p) => p, + Err(err) => { + if let PopenError::IoError(ref err) = err { + if err.kind() == io::ErrorKind::NotFound { + tracing::error!("text-generation-server not found in PATH"); + tracing::error!("Please install it with `make install-server`") + } + } + return Err(LauncherError::DownloadError); + } + }; + + // Redirect STDOUT to the console + let download_stdout = download_process.stdout.take().unwrap(); + thread::spawn(move || { + // Enter download tracing span + let stdout = BufReader::new(download_stdout); + let _span = tracing::span!(tracing::Level::INFO, "download").entered(); + for line in stdout.lines() { + // Parse loguru logs + if let Ok(log) = serde_json::from_str::(&line.unwrap()) { + log.trace(); + } + } + }); + + loop { + if let Some(status) = download_process.poll() { + match status { + ExitStatus::Exited(exit_code) => { + if exit_code == 0 { + tracing::info!("Successfully downloaded weights."); + break; + } else { + let mut err = String::new(); + download_process + .stderr + .take() + .unwrap() + .read_to_string(&mut err) + .unwrap(); + tracing::error!("Download encountered an error: {err}"); + return Err(LauncherError::DownloadError); + } + } + _ => { + tracing::error!("Download process exited with an unknown status."); + return Err(LauncherError::DownloadError); + } + } + } + if !running.load(Ordering::SeqCst) { + download_process.terminate().unwrap(); + tracing::info!("Waiting for download process to gracefully shutdown"); + download_process + .wait_timeout(Duration::from_secs(90)) + .unwrap(); + tracing::info!("Download process terminated"); + return Ok(()); + } + sleep(Duration::from_millis(100)); + } + Ok(()) +} + +fn spawn_shards( + num_shard: usize, + args: &Args, + shutdown: Arc>, + shutdown_receiver: &mpsc::Receiver<()>, + shutdown_sender: mpsc::Sender<()>, + status_receiver: &mpsc::Receiver, + status_sender: mpsc::Sender, + running: Arc, +) -> Result<(), LauncherError> { + // Start shard processes + for rank in 0..num_shard { + let model_id = args.model_id.clone(); + let revision = args.revision.clone(); + let uds_path = args.shard_uds_path.clone(); + let master_addr = args.master_addr.clone(); + let huggingface_hub_cache = args.huggingface_hub_cache.clone(); + let weights_cache_override = args.weights_cache_override.clone(); + let status_sender = status_sender.clone(); + let shutdown = shutdown.clone(); + let shutdown_sender = shutdown_sender.clone(); + let otlp_endpoint = args.otlp_endpoint.clone(); + let quantize = args.quantize.clone(); + let master_port = args.master_port.clone(); + let disable_custom_kernels = args.disable_custom_kernels.clone(); + let watermark_gamma = args.watermark_gamma.clone(); + let watermark_delta = args.watermark_delta.clone(); + thread::spawn(move || { + shard_manager( + model_id, + revision, + quantize, + uds_path, + rank, + num_shard, + master_addr, + master_port, + huggingface_hub_cache, + 
+                weights_cache_override,
+                disable_custom_kernels,
+                watermark_gamma,
+                watermark_delta,
+                otlp_endpoint,
+                status_sender,
+                shutdown,
+                shutdown_sender,
+            )
+        });
+    }
+    drop(shutdown_sender);
+
+    // Wait for shard to start
+    let mut shard_ready = 0;
+    while running.load(Ordering::SeqCst) {
+        match status_receiver.try_recv() {
+            Ok(ShardStatus::Ready) => {
+                shard_ready += 1;
+                if shard_ready == num_shard {
+                    break;
+                }
+            }
+            Err(TryRecvError::Empty) => {
+                sleep(Duration::from_millis(100));
+            }
+            Ok(ShardStatus::Failed((rank, err))) => {
+                tracing::error!("Shard {} failed to start:\n{}", rank, err);
+                shutdown_shards(shutdown, &shutdown_receiver);
+                return Err(LauncherError::ShardCannotStart);
+            }
+            Err(TryRecvError::Disconnected) => {
+                tracing::error!("Shard status channel disconnected");
+                shutdown_shards(shutdown, &shutdown_receiver);
+                return Err(LauncherError::ShardDisconnected);
+            }
+        }
+    }
+    Ok(())
+}
+
+fn spawn_webserver(
+    args: Args,
+    shutdown: Arc<Mutex<bool>>,
+    shutdown_receiver: &mpsc::Receiver<()>,
+) -> Result<Popen, LauncherError> {
+    // All shard started
+    // Start webserver
+    tracing::info!("Starting Webserver");
+    let mut argv = vec![
+        "text-generation-router".to_string(),
+        "--max-concurrent-requests".to_string(),
+        args.max_concurrent_requests.to_string(),
+        "--max-best-of".to_string(),
+        args.max_best_of.to_string(),
+        "--max-stop-sequences".to_string(),
+        args.max_stop_sequences.to_string(),
+        "--max-input-length".to_string(),
+        args.max_input_length.to_string(),
+        "--max-total-tokens".to_string(),
+        args.max_total_tokens.to_string(),
+        "--waiting-served-ratio".to_string(),
+        args.waiting_served_ratio.to_string(),
+        "--max-waiting-tokens".to_string(),
+        args.max_waiting_tokens.to_string(),
+        "--port".to_string(),
+        args.port.to_string(),
+        "--master-shard-uds-path".to_string(),
+        format!("{}-0", args.shard_uds_path),
+        "--tokenizer-name".to_string(),
+        args.model_id,
+    ];
+
+    // Deprecate max_batch_size
+    if let Some(max_batch_size) = args.max_batch_size {
+        argv.push("--max-batch-size".to_string());
+        argv.push(max_batch_size.to_string())
+    } else {
+        argv.push("--max-batch-total-tokens".to_string());
+        argv.push(args.max_batch_total_tokens.to_string())
+    }
+
+    // Model optional revision
+    if let Some(ref revision) = args.revision {
+        argv.push("--revision".to_string());
+        argv.push(revision.to_string())
+    }
+
+    if args.json_output {
+        argv.push("--json-output".to_string());
+    }
+
+    // OpenTelemetry
+    if let Some(otlp_endpoint) = args.otlp_endpoint {
+        argv.push("--otlp-endpoint".to_string());
+        argv.push(otlp_endpoint);
+    }
+
+    // CORS origins
+    for origin in args.cors_allow_origin.into_iter() {
+        argv.push("--cors-allow-origin".to_string());
+        argv.push(origin);
+    }
+
+    // Copy current process env
+    let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
+
+    // Parse Inference API token
+    if let Ok(api_token) = env::var("HF_API_TOKEN") {
+        env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
+    };
+
+    let mut webserver = match Popen::create(
+        &argv,
+        PopenConfig {
+            stdout: Redirection::Pipe,
+            stderr: Redirection::Pipe,
+            // Needed for the shutdown procedure
+            setpgid: true,
+            env: Some(env),
+            ..Default::default()
+        },
+    ) {
+        Ok(p) => p,
+        Err(err) => {
+            tracing::error!("Failed to start webserver: {}", err);
+            if let PopenError::IoError(err) = err {
+                if err.kind() == io::ErrorKind::NotFound {
+                    tracing::error!("text-generation-router not found in PATH");
+                    tracing::error!("Please install it with `make install-router`")
+                }
+            } else {
+                tracing::error!("{}", err);
+            }
+
+            shutdown_shards(shutdown, &shutdown_receiver);
+            return Err(LauncherError::WebserverCannotStart);
+        }
+    };
+
+    // Redirect STDOUT and STDERR to the console
+    let webserver_stdout = webserver.stdout.take().unwrap();
+    let webserver_stderr = webserver.stderr.take().unwrap();
+
+    thread::spawn(move || {
+        let stdout = BufReader::new(webserver_stdout);
+        let stderr = BufReader::new(webserver_stderr);
+        for line in stdout.lines() {
+            println!("{}", line.unwrap());
+        }
+        for line in stderr.lines() {
+            println!("{}", line.unwrap());
+        }
+    });
+    Ok(webserver)
+}
+
+fn main() -> Result<(), LauncherError> {
+    // Pattern match configuration
+    let args = Args::parse();
+
+    if args.json_output {
+        tracing_subscriber::fmt().json().init();
+    } else {
+        tracing_subscriber::fmt().compact().init();
+    }
+
+    tracing::info!("{:?}", args);
+
+    let num_shard = find_num_shards(args.sharded, args.num_shard);
+    if num_shard > 1 {
+        tracing::info!("Sharding model on {num_shard} processes");
+    }
+
+    // Signal handler
+    let running = Arc::new(AtomicBool::new(true));
+    let r = running.clone();
+    ctrlc::set_handler(move || {
+        r.store(false, Ordering::SeqCst);
+    })
+    .expect("Error setting Ctrl-C handler");
+
+    // Check if model_id is a local model
+    let local_path = Path::new(&args.model_id);
+    let is_local_model = local_path.exists() && local_path.is_dir();
+
+    // Download weights for sharded models
+    if !is_local_model && args.weights_cache_override.is_none() && num_shard > 1 {
+        download_model(&args, running.clone())?;
+    }
+
+    // Shared shutdown bool
+    let shutdown = Arc::new(Mutex::new(false));
+    // Shared shutdown channel
+    // When shutting down, the main thread will wait for all senders to be dropped
+    let (shutdown_sender, shutdown_receiver) = mpsc::channel();
+
+    // Shared channel to track shard status
+    let (status_sender, status_receiver) = mpsc::channel();
+
+    spawn_shards(
+        num_shard,
+        &args,
+        shutdown.clone(),
+        &shutdown_receiver,
+        shutdown_sender,
+        &status_receiver,
+        status_sender,
+        running.clone(),
+    )?;
+
+    // We might have received a termination signal
+    if !running.load(Ordering::SeqCst) {
+        shutdown_shards(shutdown, &shutdown_receiver);
+        return Ok(());
+    }
+
+    let mut webserver = spawn_webserver(args, shutdown.clone(), &shutdown_receiver)?;
+
+    // Default exit code
+    let mut exit_code = Ok(());
+
+    while running.load(Ordering::SeqCst) {
+        if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
+            tracing::error!("Shard {rank} failed:\n{err}");
+            exit_code = Err(LauncherError::ShardFailed);
+            break;
+        };
+
+        match webserver.poll() {
+            Some(_) => {
+                tracing::error!("Webserver Crashed");
+                shutdown_shards(shutdown, &shutdown_receiver);
+                return Err(LauncherError::WebserverFailed);
+            }
+            None => {
+                sleep(Duration::from_millis(100));
+            }
+        };
+    }
+
+    // Graceful termination
+    webserver.terminate().unwrap();
+    tracing::info!("Waiting for webserver to gracefully shutdown");
+    webserver.wait_timeout(Duration::from_secs(90)).unwrap();
+    tracing::info!("Webserver terminated");
+    shutdown_shards(shutdown, &shutdown_receiver);
+
+    exit_code
+}
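With `main` now returning `Result<(), LauncherError>`, the explicit `ExitCode` values from the old code are handled by std's `Termination` impl for `Result`: an `Err` value is printed with its `Debug` formatting and the process exits with a non-zero status, which matches the old `ExitCode::FAILURE` paths. A minimal standalone sketch of that behaviour, using a trimmed-down one-variant error purely for illustration:

    // Sketch only: a cut-down error enum, not the full LauncherError from the diff.
    #[derive(Debug)]
    enum LauncherError {
        WebserverFailed,
    }

    fn main() -> Result<(), LauncherError> {
        // Returning `Err` prints `Error: WebserverFailed` to stderr and the process
        // exits with a non-zero status, the same effect as `ExitCode::FAILURE`.
        Err(LauncherError::WebserverFailed)
    }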