From 77758f603b87a15dec153ca4d5b6bfb6832f1cff Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 26 Apr 2023 14:43:36 +0200 Subject: [PATCH 1/2] chore(launcher): refactor logic (#242) Hopefully it's cleaner --- launcher/src/main.rs | 887 ++++++++++++++++++++++--------------------- 1 file changed, 447 insertions(+), 440 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 7450b3f4..b59b0cb4 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -4,7 +4,6 @@ use std::env; use std::ffi::OsString; use std::io::{BufRead, BufReader, Read}; use std::path::Path; -use std::process::ExitCode; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::mpsc::TryRecvError; use std::sync::Arc; @@ -73,445 +72,6 @@ struct Args { watermark_delta: Option, } -fn main() -> ExitCode { - // Pattern match configuration - let args = Args::parse(); - - if args.json_output { - tracing_subscriber::fmt().json().init(); - } else { - tracing_subscriber::fmt().compact().init(); - } - - tracing::info!("{:?}", args); - - let Args { - model_id, - revision, - sharded, - num_shard, - quantize, - max_concurrent_requests, - max_best_of, - max_stop_sequences, - max_input_length, - max_total_tokens, - max_batch_size, - max_batch_total_tokens, - waiting_served_ratio, - max_waiting_tokens, - port, - shard_uds_path, - master_addr, - master_port, - huggingface_hub_cache, - weights_cache_override, - disable_custom_kernels, - json_output, - otlp_endpoint, - cors_allow_origin, - watermark_gamma, - watermark_delta, - } = args; - - // get the number of shards given `sharded` and `num_shard` - let num_shard = if let Some(sharded) = sharded { - // sharded is set - match sharded { - // sharded is set and true - true => { - match num_shard { - None => { - // try to default to the number of available GPUs - tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES"); - let n_devices = num_cuda_devices() - .expect("--num-shard and CUDA_VISIBLE_DEVICES are not set"); - if n_devices <= 1 { - panic!("`sharded` is true but only found {n_devices} CUDA devices"); - } - n_devices - } - Some(num_shard) => { - // we can't have only one shard while sharded - if num_shard <= 1 { - panic!("`sharded` is true but `num_shard` <= 1"); - } - num_shard - } - } - } - // sharded is set and false - false => { - let num_shard = num_shard.unwrap_or(1); - // we can't have more than one shard while not sharded - if num_shard != 1 { - panic!("`sharded` is false but `num_shard` != 1"); - } - num_shard - } - } - } else { - match num_shard { - // get num_shard from CUDA_VISIBLE_DEVICES or default to a single shard - None => num_cuda_devices().unwrap_or(1), - Some(num_shard) => num_shard, - } - }; - if num_shard < 1 { - panic!("`num_shard` cannot be < 1"); - } - - if num_shard > 1 { - tracing::info!("Sharding model on {num_shard} processes"); - } - - // Signal handler - let running = Arc::new(AtomicBool::new(true)); - let r = running.clone(); - ctrlc::set_handler(move || { - r.store(false, Ordering::SeqCst); - }) - .expect("Error setting Ctrl-C handler"); - - // Check if model_id is a local model - let local_path = Path::new(&model_id); - let is_local_model = local_path.exists() && local_path.is_dir(); - - // Download weights for sharded models - if !is_local_model && weights_cache_override.is_none() && num_shard > 1 { - let mut download_argv = vec![ - "text-generation-server".to_string(), - "download-weights".to_string(), - model_id.clone(), - "--extension".to_string(), - ".safetensors".to_string(), - "--logger-level".to_string(), - "INFO".to_string(), - "--json-output".to_string(), - ]; - - // Model optional revision - if let Some(ref revision) = revision { - download_argv.push("--revision".to_string()); - download_argv.push(revision.to_string()) - } - - // Copy current process env - let mut env: Vec<(OsString, OsString)> = env::vars_os().collect(); - - // If huggingface_hub_cache is set, pass it to the shard - // Useful when running inside a docker container - if let Some(ref huggingface_hub_cache) = huggingface_hub_cache { - env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into())); - }; - - // Enable hf transfer for insane download speeds - let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string()); - env.push(( - "HF_HUB_ENABLE_HF_TRANSFER".into(), - enable_hf_transfer.into(), - )); - - // Parse Inference API token - if let Ok(api_token) = env::var("HF_API_TOKEN") { - env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) - }; - - // Start process - tracing::info!("Starting download process."); - let mut download_process = match Popen::create( - &download_argv, - PopenConfig { - stdout: Redirection::Pipe, - stderr: Redirection::Pipe, - // Needed for the shutdown procedure - setpgid: true, - env: Some(env), - ..Default::default() - }, - ) { - Ok(p) => p, - Err(err) => { - if let PopenError::IoError(ref err) = err { - if err.kind() == io::ErrorKind::NotFound { - tracing::error!("text-generation-server not found in PATH"); - tracing::error!("Please install it with `make install-server`") - } - } - return ExitCode::FAILURE; - } - }; - - // Redirect STDOUT to the console - let download_stdout = download_process.stdout.take().unwrap(); - thread::spawn(move || { - // Enter download tracing span - let stdout = BufReader::new(download_stdout); - let _span = tracing::span!(tracing::Level::INFO, "download").entered(); - for line in stdout.lines() { - // Parse loguru logs - if let Ok(log) = serde_json::from_str::(&line.unwrap()) { - log.trace(); - } - } - }); - - loop { - if let Some(status) = download_process.poll() { - match status { - ExitStatus::Exited(exit_code) => { - if exit_code == 0 { - tracing::info!("Successfully downloaded weights."); - break; - } else { - let mut err = String::new(); - download_process - .stderr - .take() - .unwrap() - .read_to_string(&mut err) - .unwrap(); - tracing::error!("Download encountered an error: {err}"); - return ExitCode::FAILURE; - } - } - _ => { - tracing::error!("Download process exited with an unknown status."); - return ExitCode::FAILURE; - } - } - } - if !running.load(Ordering::SeqCst) { - download_process.terminate().unwrap(); - tracing::info!("Waiting for download process to gracefully shutdown"); - download_process - .wait_timeout(Duration::from_secs(90)) - .unwrap(); - tracing::info!("Download process terminated"); - return ExitCode::SUCCESS; - } - sleep(Duration::from_millis(100)); - } - } - - // Shared shutdown bool - let shutdown = Arc::new(Mutex::new(false)); - // Shared shutdown channel - // When shutting down, the main thread will wait for all senders to be dropped - let (shutdown_sender, shutdown_receiver) = mpsc::channel(); - - // Shared channel to track shard status - let (status_sender, status_receiver) = mpsc::channel(); - - // Start shard processes - for rank in 0..num_shard { - let model_id = model_id.clone(); - let revision = revision.clone(); - let uds_path = shard_uds_path.clone(); - let master_addr = master_addr.clone(); - let huggingface_hub_cache = huggingface_hub_cache.clone(); - let weights_cache_override = weights_cache_override.clone(); - let status_sender = status_sender.clone(); - let shutdown = shutdown.clone(); - let shutdown_sender = shutdown_sender.clone(); - let otlp_endpoint = otlp_endpoint.clone(); - thread::spawn(move || { - shard_manager( - model_id, - revision, - quantize, - uds_path, - rank, - num_shard, - master_addr, - master_port, - huggingface_hub_cache, - weights_cache_override, - disable_custom_kernels, - watermark_gamma, - watermark_delta, - otlp_endpoint, - status_sender, - shutdown, - shutdown_sender, - ) - }); - } - drop(shutdown_sender); - - // Wait for shard to start - let mut shard_ready = 0; - while running.load(Ordering::SeqCst) { - match status_receiver.try_recv() { - Ok(ShardStatus::Ready) => { - shard_ready += 1; - if shard_ready == num_shard { - break; - } - } - Err(TryRecvError::Empty) => { - sleep(Duration::from_millis(100)); - } - Ok(ShardStatus::Failed((rank, err))) => { - tracing::error!("Shard {} failed to start:\n{}", rank, err); - shutdown_shards(shutdown, &shutdown_receiver); - return ExitCode::FAILURE; - } - Err(TryRecvError::Disconnected) => { - tracing::error!("Shard status channel disconnected"); - shutdown_shards(shutdown, &shutdown_receiver); - return ExitCode::FAILURE; - } - } - } - - // We might have received a termination signal - if !running.load(Ordering::SeqCst) { - shutdown_shards(shutdown, &shutdown_receiver); - return ExitCode::SUCCESS; - } - - // All shard started - // Start webserver - tracing::info!("Starting Webserver"); - let mut argv = vec![ - "text-generation-router".to_string(), - "--max-concurrent-requests".to_string(), - max_concurrent_requests.to_string(), - "--max-best-of".to_string(), - max_best_of.to_string(), - "--max-stop-sequences".to_string(), - max_stop_sequences.to_string(), - "--max-input-length".to_string(), - max_input_length.to_string(), - "--max-total-tokens".to_string(), - max_total_tokens.to_string(), - "--waiting-served-ratio".to_string(), - waiting_served_ratio.to_string(), - "--max-waiting-tokens".to_string(), - max_waiting_tokens.to_string(), - "--port".to_string(), - port.to_string(), - "--master-shard-uds-path".to_string(), - format!("{shard_uds_path}-0"), - "--tokenizer-name".to_string(), - model_id, - ]; - - // Deprecate max_batch_size - if let Some(max_batch_size) = max_batch_size { - argv.push("--max-batch-size".to_string()); - argv.push(max_batch_size.to_string()) - } else { - argv.push("--max-batch-total-tokens".to_string()); - argv.push(max_batch_total_tokens.to_string()) - } - - // Model optional revision - if let Some(ref revision) = revision { - argv.push("--revision".to_string()); - argv.push(revision.to_string()) - } - - if json_output { - argv.push("--json-output".to_string()); - } - - // OpenTelemetry - if let Some(otlp_endpoint) = otlp_endpoint { - argv.push("--otlp-endpoint".to_string()); - argv.push(otlp_endpoint); - } - - // CORS origins - for origin in cors_allow_origin.into_iter() { - argv.push("--cors-allow-origin".to_string()); - argv.push(origin); - } - - // Copy current process env - let mut env: Vec<(OsString, OsString)> = env::vars_os().collect(); - - // Parse Inference API token - if let Ok(api_token) = env::var("HF_API_TOKEN") { - env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) - }; - - let mut webserver = match Popen::create( - &argv, - PopenConfig { - stdout: Redirection::Pipe, - stderr: Redirection::Pipe, - // Needed for the shutdown procedure - setpgid: true, - env: Some(env), - ..Default::default() - }, - ) { - Ok(p) => p, - Err(err) => { - tracing::error!("Failed to start webserver: {}", err); - if let PopenError::IoError(err) = err { - if err.kind() == io::ErrorKind::NotFound { - tracing::error!("text-generation-router not found in PATH"); - tracing::error!("Please install it with `make install-router`") - } - } else { - tracing::error!("{}", err); - } - - shutdown_shards(shutdown, &shutdown_receiver); - return ExitCode::FAILURE; - } - }; - - // Redirect STDOUT and STDERR to the console - let webserver_stdout = webserver.stdout.take().unwrap(); - let webserver_stderr = webserver.stderr.take().unwrap(); - - thread::spawn(move || { - let stdout = BufReader::new(webserver_stdout); - let stderr = BufReader::new(webserver_stderr); - for line in stdout.lines() { - println!("{}", line.unwrap()); - } - for line in stderr.lines() { - println!("{}", line.unwrap()); - } - }); - - // Default exit code - let mut exit_code = ExitCode::SUCCESS; - - while running.load(Ordering::SeqCst) { - if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() { - tracing::error!("Shard {rank} failed:\n{err}"); - exit_code = ExitCode::FAILURE; - break; - }; - - match webserver.poll() { - Some(_) => { - tracing::error!("Webserver Crashed"); - shutdown_shards(shutdown, &shutdown_receiver); - return ExitCode::FAILURE; - } - None => { - sleep(Duration::from_millis(100)); - } - }; - } - - // Graceful termination - webserver.terminate().unwrap(); - tracing::info!("Waiting for webserver to gracefully shutdown"); - webserver.wait_timeout(Duration::from_secs(90)).unwrap(); - tracing::info!("Webserver terminated"); - shutdown_shards(shutdown, &shutdown_receiver); - - exit_code -} - #[derive(Debug)] enum ShardStatus { Ready, @@ -774,3 +334,450 @@ impl PythonLogMessage { } } } + +fn find_num_shards(sharded: Option, num_shard: Option) -> usize { + // get the number of shards given `sharded` and `num_shard` + let num_shard = match (sharded, num_shard) { + (Some(true), None) => { + // try to default to the number of available GPUs + tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES"); + let n_devices = + num_cuda_devices().expect("--num-shard and CUDA_VISIBLE_DEVICES are not set"); + if n_devices <= 1 { + panic!("`sharded` is true but only found {n_devices} CUDA devices"); + } + n_devices + } + (Some(true), Some(num_shard)) => { + // we can't have only one shard while sharded + if num_shard <= 1 { + panic!("`sharded` is true but `num_shard` <= 1"); + } + num_shard + } + (Some(false), Some(num_shard)) => num_shard, + (Some(false), None) => 1, + (None, None) => num_cuda_devices().unwrap_or(1), + (None, Some(num_shard)) => num_shard, + }; + if num_shard < 1 { + panic!("`num_shard` cannot be < 1"); + } + num_shard +} + +#[derive(Debug)] +enum LauncherError { + DownloadError, + ShardCannotStart, + ShardDisconnected, + ShardFailed, + WebserverFailed, + WebserverCannotStart, +} + +fn download_model(args: &Args, running: Arc) -> Result<(), LauncherError> { + let mut download_argv = vec![ + "text-generation-server".to_string(), + "download-weights".to_string(), + args.model_id.to_string(), + "--extension".to_string(), + ".safetensors".to_string(), + "--logger-level".to_string(), + "INFO".to_string(), + "--json-output".to_string(), + ]; + + // Model optional revision + if let Some(revision) = &args.revision { + download_argv.push("--revision".to_string()); + download_argv.push(revision.to_string()) + } + + // Copy current process env + let mut env: Vec<(OsString, OsString)> = env::vars_os().collect(); + + // If huggingface_hub_cache is set, pass it to the shard + // Useful when running inside a docker container + if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache { + env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into())); + }; + + // Enable hf transfer for insane download speeds + let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string()); + env.push(( + "HF_HUB_ENABLE_HF_TRANSFER".into(), + enable_hf_transfer.into(), + )); + + // Parse Inference API token + if let Ok(api_token) = env::var("HF_API_TOKEN") { + env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) + }; + + // Start process + tracing::info!("Starting download process."); + let mut download_process = match Popen::create( + &download_argv, + PopenConfig { + stdout: Redirection::Pipe, + stderr: Redirection::Pipe, + // Needed for the shutdown procedure + setpgid: true, + env: Some(env), + ..Default::default() + }, + ) { + Ok(p) => p, + Err(err) => { + if let PopenError::IoError(ref err) = err { + if err.kind() == io::ErrorKind::NotFound { + tracing::error!("text-generation-server not found in PATH"); + tracing::error!("Please install it with `make install-server`") + } + } + return Err(LauncherError::DownloadError); + } + }; + + // Redirect STDOUT to the console + let download_stdout = download_process.stdout.take().unwrap(); + thread::spawn(move || { + // Enter download tracing span + let stdout = BufReader::new(download_stdout); + let _span = tracing::span!(tracing::Level::INFO, "download").entered(); + for line in stdout.lines() { + // Parse loguru logs + if let Ok(log) = serde_json::from_str::(&line.unwrap()) { + log.trace(); + } + } + }); + + loop { + if let Some(status) = download_process.poll() { + match status { + ExitStatus::Exited(exit_code) => { + if exit_code == 0 { + tracing::info!("Successfully downloaded weights."); + break; + } else { + let mut err = String::new(); + download_process + .stderr + .take() + .unwrap() + .read_to_string(&mut err) + .unwrap(); + tracing::error!("Download encountered an error: {err}"); + return Err(LauncherError::DownloadError); + } + } + _ => { + tracing::error!("Download process exited with an unknown status."); + return Err(LauncherError::DownloadError); + } + } + } + if !running.load(Ordering::SeqCst) { + download_process.terminate().unwrap(); + tracing::info!("Waiting for download process to gracefully shutdown"); + download_process + .wait_timeout(Duration::from_secs(90)) + .unwrap(); + tracing::info!("Download process terminated"); + return Ok(()); + } + sleep(Duration::from_millis(100)); + } + Ok(()) +} + +fn spawn_shards( + num_shard: usize, + args: &Args, + shutdown: Arc>, + shutdown_receiver: &mpsc::Receiver<()>, + shutdown_sender: mpsc::Sender<()>, + status_receiver: &mpsc::Receiver, + status_sender: mpsc::Sender, + running: Arc, +) -> Result<(), LauncherError> { + // Start shard processes + for rank in 0..num_shard { + let model_id = args.model_id.clone(); + let revision = args.revision.clone(); + let uds_path = args.shard_uds_path.clone(); + let master_addr = args.master_addr.clone(); + let huggingface_hub_cache = args.huggingface_hub_cache.clone(); + let weights_cache_override = args.weights_cache_override.clone(); + let status_sender = status_sender.clone(); + let shutdown = shutdown.clone(); + let shutdown_sender = shutdown_sender.clone(); + let otlp_endpoint = args.otlp_endpoint.clone(); + let quantize = args.quantize.clone(); + let master_port = args.master_port.clone(); + let disable_custom_kernels = args.disable_custom_kernels.clone(); + let watermark_gamma = args.watermark_gamma.clone(); + let watermark_delta = args.watermark_delta.clone(); + thread::spawn(move || { + shard_manager( + model_id, + revision, + quantize, + uds_path, + rank, + num_shard, + master_addr, + master_port, + huggingface_hub_cache, + weights_cache_override, + disable_custom_kernels, + watermark_gamma, + watermark_delta, + otlp_endpoint, + status_sender, + shutdown, + shutdown_sender, + ) + }); + } + drop(shutdown_sender); + + // Wait for shard to start + let mut shard_ready = 0; + while running.load(Ordering::SeqCst) { + match status_receiver.try_recv() { + Ok(ShardStatus::Ready) => { + shard_ready += 1; + if shard_ready == num_shard { + break; + } + } + Err(TryRecvError::Empty) => { + sleep(Duration::from_millis(100)); + } + Ok(ShardStatus::Failed((rank, err))) => { + tracing::error!("Shard {} failed to start:\n{}", rank, err); + shutdown_shards(shutdown, &shutdown_receiver); + return Err(LauncherError::ShardCannotStart); + } + Err(TryRecvError::Disconnected) => { + tracing::error!("Shard status channel disconnected"); + shutdown_shards(shutdown, &shutdown_receiver); + return Err(LauncherError::ShardDisconnected); + } + } + } + Ok(()) +} + +fn spawn_webserver( + args: Args, + shutdown: Arc>, + shutdown_receiver: &mpsc::Receiver<()>, +) -> Result { + // All shard started + // Start webserver + tracing::info!("Starting Webserver"); + let mut argv = vec![ + "text-generation-router".to_string(), + "--max-concurrent-requests".to_string(), + args.max_concurrent_requests.to_string(), + "--max-best-of".to_string(), + args.max_best_of.to_string(), + "--max-stop-sequences".to_string(), + args.max_stop_sequences.to_string(), + "--max-input-length".to_string(), + args.max_input_length.to_string(), + "--max-total-tokens".to_string(), + args.max_total_tokens.to_string(), + "--waiting-served-ratio".to_string(), + args.waiting_served_ratio.to_string(), + "--max-waiting-tokens".to_string(), + args.max_waiting_tokens.to_string(), + "--port".to_string(), + args.port.to_string(), + "--master-shard-uds-path".to_string(), + format!("{}-0", args.shard_uds_path), + "--tokenizer-name".to_string(), + args.model_id, + ]; + + // Deprecate max_batch_size + if let Some(max_batch_size) = args.max_batch_size { + argv.push("--max-batch-size".to_string()); + argv.push(max_batch_size.to_string()) + } else { + argv.push("--max-batch-total-tokens".to_string()); + argv.push(args.max_batch_total_tokens.to_string()) + } + + // Model optional revision + if let Some(ref revision) = args.revision { + argv.push("--revision".to_string()); + argv.push(revision.to_string()) + } + + if args.json_output { + argv.push("--json-output".to_string()); + } + + // OpenTelemetry + if let Some(otlp_endpoint) = args.otlp_endpoint { + argv.push("--otlp-endpoint".to_string()); + argv.push(otlp_endpoint); + } + + // CORS origins + for origin in args.cors_allow_origin.into_iter() { + argv.push("--cors-allow-origin".to_string()); + argv.push(origin); + } + + // Copy current process env + let mut env: Vec<(OsString, OsString)> = env::vars_os().collect(); + + // Parse Inference API token + if let Ok(api_token) = env::var("HF_API_TOKEN") { + env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) + }; + + let mut webserver = match Popen::create( + &argv, + PopenConfig { + stdout: Redirection::Pipe, + stderr: Redirection::Pipe, + // Needed for the shutdown procedure + setpgid: true, + env: Some(env), + ..Default::default() + }, + ) { + Ok(p) => p, + Err(err) => { + tracing::error!("Failed to start webserver: {}", err); + if let PopenError::IoError(err) = err { + if err.kind() == io::ErrorKind::NotFound { + tracing::error!("text-generation-router not found in PATH"); + tracing::error!("Please install it with `make install-router`") + } + } else { + tracing::error!("{}", err); + } + + shutdown_shards(shutdown, &shutdown_receiver); + return Err(LauncherError::WebserverCannotStart); + } + }; + + // Redirect STDOUT and STDERR to the console + let webserver_stdout = webserver.stdout.take().unwrap(); + let webserver_stderr = webserver.stderr.take().unwrap(); + + thread::spawn(move || { + let stdout = BufReader::new(webserver_stdout); + let stderr = BufReader::new(webserver_stderr); + for line in stdout.lines() { + println!("{}", line.unwrap()); + } + for line in stderr.lines() { + println!("{}", line.unwrap()); + } + }); + Ok(webserver) +} + +fn main() -> Result<(), LauncherError> { + // Pattern match configuration + let args = Args::parse(); + + if args.json_output { + tracing_subscriber::fmt().json().init(); + } else { + tracing_subscriber::fmt().compact().init(); + } + + tracing::info!("{:?}", args); + + let num_shard = find_num_shards(args.sharded, args.num_shard); + if num_shard > 1 { + tracing::info!("Sharding model on {num_shard} processes"); + } + + // Signal handler + let running = Arc::new(AtomicBool::new(true)); + let r = running.clone(); + ctrlc::set_handler(move || { + r.store(false, Ordering::SeqCst); + }) + .expect("Error setting Ctrl-C handler"); + + // Check if model_id is a local model + let local_path = Path::new(&args.model_id); + let is_local_model = local_path.exists() && local_path.is_dir(); + + // Download weights for sharded models + if !is_local_model && args.weights_cache_override.is_none() && num_shard > 1 { + download_model(&args, running.clone())?; + } + + // Shared shutdown bool + let shutdown = Arc::new(Mutex::new(false)); + // Shared shutdown channel + // When shutting down, the main thread will wait for all senders to be dropped + let (shutdown_sender, shutdown_receiver) = mpsc::channel(); + + // Shared channel to track shard status + let (status_sender, status_receiver) = mpsc::channel(); + + spawn_shards( + num_shard, + &args, + shutdown.clone(), + &shutdown_receiver, + shutdown_sender, + &status_receiver, + status_sender, + running.clone(), + )?; + + // We might have received a termination signal + if !running.load(Ordering::SeqCst) { + shutdown_shards(shutdown, &shutdown_receiver); + return Ok(()); + } + + let mut webserver = spawn_webserver(args, shutdown.clone(), &shutdown_receiver)?; + + // Default exit code + let mut exit_code = Ok(()); + + while running.load(Ordering::SeqCst) { + if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() { + tracing::error!("Shard {rank} failed:\n{err}"); + exit_code = Err(LauncherError::ShardFailed); + break; + }; + + match webserver.poll() { + Some(_) => { + tracing::error!("Webserver Crashed"); + shutdown_shards(shutdown, &shutdown_receiver); + return Err(LauncherError::WebserverFailed); + } + None => { + sleep(Duration::from_millis(100)); + } + }; + } + + // Graceful termination + webserver.terminate().unwrap(); + tracing::info!("Waiting for webserver to gracefully shutdown"); + webserver.wait_timeout(Duration::from_secs(90)).unwrap(); + tracing::info!("Webserver terminated"); + shutdown_shards(shutdown, &shutdown_receiver); + + exit_code +} From c4fb09f2aef0730d88ce7e4ce54eacd7839b48cc Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 26 Apr 2023 16:14:40 +0200 Subject: [PATCH 2/2] feat(router): add tests to validation (#237) --- .github/workflows/tests.yaml | 3 ++ router/src/lib.rs | 16 +++++++ router/src/queue.rs | 11 +++++ router/src/server.rs | 1 + router/src/validation.rs | 81 +++++++++++++++++++++++++++++++----- 5 files changed, 102 insertions(+), 10 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index e82e8b20..a2c2b7fb 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -67,6 +67,9 @@ jobs: run: | pip install pytest HF_HUB_ENABLE_HF_TRANSFER=1 pytest -sv server/tests + - name: Run Clippy + run: | + cargo clippy - name: Run Rust tests run: | cargo test diff --git a/router/src/lib.rs b/router/src/lib.rs index 7a1707d9..85b13cfa 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -276,3 +276,19 @@ pub(crate) struct ErrorResponse { pub error: String, pub error_type: String, } + +#[cfg(test)] +mod tests{ + use std::io::Write; + use tokenizers::Tokenizer; + + pub(crate) async fn get_tokenizer() -> Tokenizer{ + if !std::path::Path::new("tokenizer.json").exists(){ + let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json").await.unwrap().bytes().await.unwrap(); + let mut file = std::fs::File::create("tokenizer.json").unwrap(); + file.write_all(&content).unwrap(); + } + Tokenizer::from_file("tokenizer.json").unwrap() + } +} + diff --git a/router/src/queue.rs b/router/src/queue.rs index d970ebf1..d3f118d8 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -141,6 +141,7 @@ impl State { // Get the next batch fn next_batch(&mut self, min_size: Option, token_budget: u32) -> Option { + if self.entries.is_empty() { return None; } @@ -430,7 +431,17 @@ mod tests { let (entry3, _guard3) = default_entry(); queue.append(entry3); + // Not enough requests pending assert!(queue.next_batch(Some(2), 2).await.is_none()); + // Not enough token budget + assert!(queue.next_batch(Some(1), 0).await.is_none()); + // Ok + let (entries2, batch2, _) = queue.next_batch(Some(1), 2).await.unwrap(); + assert_eq!(entries2.len(), 1); + assert!(entries2.contains_key(&2)); + assert!(entries2.get(&2).unwrap().batch_time.is_some()); + assert_eq!(batch2.id, 1); + assert_eq!(batch2.size, 1); } #[tokio::test] diff --git a/router/src/server.rs b/router/src/server.rs index 9540ba18..09b5c3ba 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -741,3 +741,4 @@ impl From for Event { .unwrap() } } + diff --git a/router/src/validation.rs b/router/src/validation.rs index 983c2612..ff2fe89d 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -382,7 +382,8 @@ pub enum ValidationError { #[cfg(test)] mod tests{ use super::*; - use std::io::Write; + use crate::default_parameters; + use crate::tests::get_tokenizer; #[tokio::test] async fn test_validation_max_new_tokens(){ @@ -401,15 +402,6 @@ mod tests{ } } - async fn get_tokenizer() -> Tokenizer{ - if !std::path::Path::new("tokenizer.json").exists(){ - let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json").await.unwrap().bytes().await.unwrap(); - let mut file = std::fs::File::create("tokenizer.json").unwrap(); - file.write_all(&content).unwrap(); - } - Tokenizer::from_file("tokenizer.json").unwrap() - } - #[tokio::test] async fn test_validation_input_length(){ let tokenizer = Some(get_tokenizer().await); @@ -426,4 +418,73 @@ mod tests{ _ => panic!("Unexpected not max new tokens") } } + + #[tokio::test] + async fn test_validation_best_of_sampling(){ + let tokenizer = Some(get_tokenizer().await); + let max_best_of = 2; + let max_stop_sequence = 3; + let max_input_length = 4; + let max_total_tokens = 5; + let workers = 1; + let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens); + match validation.validate(GenerateRequest{ + inputs: "Hello".to_string(), + parameters: GenerateParameters{ + best_of: Some(2), + do_sample: false, + ..default_parameters() + } + }).await{ + Err(ValidationError::BestOfSampling) => (), + _ => panic!("Unexpected not best of sampling") + } + + } + + #[tokio::test] + async fn test_validation_top_p(){ + let tokenizer = Some(get_tokenizer().await); + let max_best_of = 2; + let max_stop_sequence = 3; + let max_input_length = 4; + let max_total_tokens = 5; + let workers = 1; + let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens); + match validation.validate(GenerateRequest{ + inputs: "Hello".to_string(), + parameters: GenerateParameters{ + top_p: Some(1.0), + ..default_parameters() + } + }).await{ + Err(ValidationError::TopP) => (), + _ => panic!("Unexpected top_p") + } + + match validation.validate(GenerateRequest{ + inputs: "Hello".to_string(), + parameters: GenerateParameters{ + top_p: Some(0.99), + max_new_tokens: 1, + ..default_parameters() + } + }).await{ + Ok(_) => (), + _ => panic!("Unexpected top_p error") + } + + let valid_request = validation.validate(GenerateRequest{ + inputs: "Hello".to_string(), + parameters: GenerateParameters{ + top_p: None, + max_new_tokens: 1, + ..default_parameters() + } + }).await.unwrap(); + // top_p == 1.0 is invalid for users to ask for but it's the default resolved value. + assert_eq!(valid_request.parameters.top_p, 1.0); + + + } }