From ac0be8a6a4d58b4800918aa22d173679e1e2984b Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Thu, 25 Jan 2024 18:16:03 +0100 Subject: [PATCH] fix: read stderr in download (#1486) #1186 --- launcher/src/main.rs | 48 +++++++++++++++++++++++++++----------------- 1 file changed, 30 insertions(+), 18 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 25c780ed..95256178 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -6,7 +6,7 @@ use nix::unistd::Pid; use serde::Deserialize; use std::env; use std::ffi::OsString; -use std::io::{BufRead, BufReader, Lines, Read}; +use std::io::{BufRead, BufReader, Lines}; use std::os::unix::process::{CommandExt, ExitStatusExt}; use std::path::Path; use std::process::{Child, Command, ExitStatus, Stdio}; @@ -497,6 +497,9 @@ fn shard_manager( // Safetensors load fast envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into())); + // Disable progress bar + envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into())); + // Enable hf transfer for insane download speeds let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string()); envs.push(( @@ -581,6 +584,13 @@ fn shard_manager( thread::spawn(move || { log_lines(shard_stdout_reader.lines()); }); + // We read stderr in another thread as it seems that lines() can block in some cases + let (err_sender, err_receiver) = mpsc::channel(); + thread::spawn(move || { + for line in shard_stderr_reader.lines().flatten() { + err_sender.send(line).unwrap_or(()); + } + }); let mut ready = false; let start_time = Instant::now(); @@ -588,13 +598,6 @@ fn shard_manager( loop { // Process exited if let Some(exit_status) = p.try_wait().unwrap() { - // We read stderr in another thread as it seems that lines() can block in some cases - let (err_sender, err_receiver) = mpsc::channel(); - thread::spawn(move || { - for line in shard_stderr_reader.lines().flatten() { - err_sender.send(line).unwrap_or(()); - } - }); let mut err = String::new(); while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) { err = err + "\n" + &line; @@ -790,6 +793,9 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L // Copy current process env let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect(); + // Disable progress bar + envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into())); + // If huggingface_hub_cache is set, pass it to the download process // Useful when running inside a docker container if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache { @@ -840,12 +846,20 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L } }; - // Redirect STDOUT to the console - let download_stdout = download_process.stdout.take().unwrap(); - let stdout = BufReader::new(download_stdout); + let download_stdout = BufReader::new(download_process.stdout.take().unwrap()); thread::spawn(move || { - log_lines(stdout.lines()); + log_lines(download_stdout.lines()); + }); + + let download_stderr = BufReader::new(download_process.stderr.take().unwrap()); + + // We read stderr in another thread as it seems that lines() can block in some cases + let (err_sender, err_receiver) = mpsc::channel(); + thread::spawn(move || { + for line in download_stderr.lines().flatten() { + err_sender.send(line).unwrap_or(()); + } }); loop { @@ -856,12 +870,10 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L } let mut err = String::new(); - download_process - .stderr - .take() - .unwrap() - .read_to_string(&mut err) - .unwrap(); + while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) { + err = err + "\n" + &line; + } + if let Some(signal) = status.signal() { tracing::error!( "Download process was signaled to shutdown with signal {signal}: {err}"