From 90341055535c41e53b527e5ec2af746779ca20ff Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Tue, 14 Feb 2023 12:35:36 +0100 Subject: [PATCH] update launcher --- launcher/src/main.rs | 62 ++++++++++++++++++++++++++++---------------- 1 file changed, 39 insertions(+), 23 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index b5ef5592..0848dd9a 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -43,6 +43,10 @@ struct Args { #[clap(default_value = "29500", long, env)] master_port: usize, #[clap(long, env)] + huggingface_hub_cache: Option, + #[clap(long, env)] + weights_cache_override: Option, + #[clap(long, env)] json_output: bool, #[clap(long, env)] otlp_endpoint: Option, @@ -63,6 +67,8 @@ fn main() -> ExitCode { shard_uds_path, master_addr, master_port, + huggingface_hub_cache, + weights_cache_override, json_output, otlp_endpoint, } = Args::parse(); @@ -85,8 +91,7 @@ fn main() -> ExitCode { .expect("Error setting Ctrl-C handler"); // Download weights - if num_shard > 1 { - // Only download weights if in sharded mode + if weights_cache_override.is_none() { let mut download_argv = vec![ "text-generation-server".to_string(), "download-weights".to_string(), @@ -95,29 +100,28 @@ fn main() -> ExitCode { "INFO".to_string(), "--json-output".to_string(), ]; + if num_shard == 1 { + download_argv.push("--extension".to_string()); + download_argv.push(".bin".to_string()); + } else { + download_argv.push("--extension".to_string()); + download_argv.push(".safetensors".to_string()); + } + // Model optional revision - if let Some(revision) = revision.clone() { + if let Some(ref revision) = revision { download_argv.push("--revision".to_string()); - download_argv.push(revision) + download_argv.push(revision.to_string()) } let mut env = Vec::new(); // If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard // Useful when running inside a docker container - if let Ok(huggingface_hub_cache) = env::var("HUGGINGFACE_HUB_CACHE") { + if let Some(ref huggingface_hub_cache) = huggingface_hub_cache { env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into())); }; - // If the WEIGHTS_CACHE_OVERRIDE env var is set, pass it to the shard - // Useful when running inside a HuggingFace Inference Endpoint - if let Ok(weights_cache_override) = env::var("WEIGHTS_CACHE_OVERRIDE") { - env.push(( - "WEIGHTS_CACHE_OVERRIDE".into(), - weights_cache_override.into(), - )); - }; - // Start process tracing::info!("Starting download"); let mut download_process = match Popen::create( @@ -196,6 +200,12 @@ fn main() -> ExitCode { } sleep(Duration::from_millis(100)); } + } else { + tracing::info!( + "weights_cache_override is set to {:?}.", + weights_cache_override + ); + tracing::info!("Skipping download.") } // Shared shutdown bool @@ -213,6 +223,8 @@ fn main() -> ExitCode { let revision = revision.clone(); let uds_path = shard_uds_path.clone(); let master_addr = master_addr.clone(); + let huggingface_hub_cache = huggingface_hub_cache.clone(); + let weights_cache_override = weights_cache_override.clone(); let status_sender = status_sender.clone(); let shutdown = shutdown.clone(); let shutdown_sender = shutdown_sender.clone(); @@ -227,6 +239,8 @@ fn main() -> ExitCode { num_shard, master_addr, master_port, + huggingface_hub_cache, + weights_cache_override, otlp_endpoint, status_sender, shutdown, @@ -346,7 +360,7 @@ fn main() -> ExitCode { while running.load(Ordering::SeqCst) { if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() { - tracing::error!("Shard {} failed:\n{}", rank, err); + tracing::error!("Shard {rank} failed:\n{err}"); exit_code = ExitCode::FAILURE; break; }; @@ -389,6 +403,8 @@ fn shard_manager( world_size: usize, master_addr: String, master_port: usize, + huggingface_hub_cache: Option, + weights_cache_override: Option, otlp_endpoint: Option, status_sender: mpsc::Sender, shutdown: Arc>, @@ -442,15 +458,15 @@ fn shard_manager( ("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()), ]; - // If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard + // If huggingface_hub_cache is some, pass it to the shard // Useful when running inside a docker container - if let Ok(huggingface_hub_cache) = env::var("HUGGINGFACE_HUB_CACHE") { + if let Some(huggingface_hub_cache) = huggingface_hub_cache { env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into())); }; - // If the WEIGHTS_CACHE_OVERRIDE env var is set, pass it to the shard + // If weights_cache_override is some, pass it to the shard // Useful when running inside a HuggingFace Inference Endpoint - if let Ok(weights_cache_override) = env::var("WEIGHTS_CACHE_OVERRIDE") { + if let Some(weights_cache_override) = weights_cache_override { env.push(( "WEIGHTS_CACHE_OVERRIDE".into(), weights_cache_override.into(), @@ -469,7 +485,7 @@ fn shard_manager( }; // Start process - tracing::info!("Starting shard {}", rank); + tracing::info!("Starting shard {rank}"); let mut p = match Popen::create( &shard_argv, PopenConfig { @@ -533,17 +549,17 @@ fn shard_manager( if *shutdown.lock().unwrap() { p.terminate().unwrap(); let _ = p.wait_timeout(Duration::from_secs(90)); - tracing::info!("Shard {} terminated", rank); + tracing::info!("Shard {rank} terminated"); return; } // Shard is ready if uds.exists() && !ready { - tracing::info!("Shard {} ready in {:?}", rank, start_time.elapsed()); + tracing::info!("Shard {rank} ready in {:?}", start_time.elapsed()); status_sender.send(ShardStatus::Ready).unwrap(); ready = true; } else if !ready && wait_time.elapsed() > Duration::from_secs(10) { - tracing::info!("Waiting for shard {} to be ready...", rank); + tracing::info!("Waiting for shard {rank} to be ready..."); wait_time = Instant::now(); } sleep(Duration::from_millis(100));