feat(server): improve download logging (#66)

This commit is contained in:
OlivierDehaene 2023-02-15 16:11:32 +01:00 committed by GitHub
parent 0fbc691946
commit c5a4a1faf3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 21 additions and 27 deletions

View File

@ -22,8 +22,8 @@ struct Args {
model_id: String, model_id: String,
#[clap(long, env)] #[clap(long, env)]
revision: Option<String>, revision: Option<String>,
#[clap(long, env)] #[clap(default_value = "1", long, env)]
num_shard: Option<usize>, num_shard: usize,
#[clap(long, env)] #[clap(long, env)]
quantize: bool, quantize: bool,
#[clap(default_value = "128", long, env)] #[clap(default_value = "128", long, env)]
@ -54,6 +54,16 @@ struct Args {
fn main() -> ExitCode { fn main() -> ExitCode {
// Pattern match configuration // Pattern match configuration
let args = Args::parse();
if args.json_output {
tracing_subscriber::fmt().json().init();
} else {
tracing_subscriber::fmt().compact().init();
}
tracing::info!("{:?}", args);
let Args { let Args {
model_id, model_id,
revision, revision,
@ -71,16 +81,7 @@ fn main() -> ExitCode {
weights_cache_override, weights_cache_override,
json_output, json_output,
otlp_endpoint, otlp_endpoint,
} = Args::parse(); } = args;
if json_output {
tracing_subscriber::fmt().json().init();
} else {
tracing_subscriber::fmt().compact().init();
}
// By default we only have one master shard
let num_shard = num_shard.unwrap_or(1);
// Signal handler // Signal handler
let running = Arc::new(AtomicBool::new(true)); let running = Arc::new(AtomicBool::new(true));
@ -123,7 +124,7 @@ fn main() -> ExitCode {
}; };
// Start process // Start process
tracing::info!("Starting download"); tracing::info!("Starting download process.");
let mut download_process = match Popen::create( let mut download_process = match Popen::create(
&download_argv, &download_argv,
PopenConfig { PopenConfig {
@ -184,7 +185,7 @@ fn main() -> ExitCode {
} }
} }
_ => { _ => {
tracing::error!("Download process exited with an unkown status."); tracing::error!("Download process exited with an unknown status.");
return ExitCode::FAILURE; return ExitCode::FAILURE;
} }
} }

View File

@ -83,14 +83,10 @@ def convert_files(pt_files: List[Path], st_files: List[Path]):
] ]
# We do this instead of using tqdm because we want to parse the logs with the launcher # We do this instead of using tqdm because we want to parse the logs with the launcher
logger.info("Converting weights...")
start_time = time.time() start_time = time.time()
for i, future in enumerate(concurrent.futures.as_completed(futures)): for i, future in enumerate(concurrent.futures.as_completed(futures)):
elapsed = timedelta(seconds=int(time.time() - start_time)) elapsed = timedelta(seconds=int(time.time() - start_time))
remaining = len(futures) - (i + 1) remaining = len(futures) - (i + 1)
if remaining != 0: eta = (elapsed / (i + 1)) * remaining if remaining > 0 else 0
eta = (elapsed / (i + 1)) * remaining
else:
eta = 0
logger.info(f"Convert: [{i + 1}/{len(futures)}] -- ETA: {eta}") logger.info(f"Convert: [{i + 1}/{len(futures)}] -- ETA: {eta}")

View File

@ -134,6 +134,7 @@ def download_weights(
logger.info(f"File {filename} already present in cache.") logger.info(f"File {filename} already present in cache.")
return local_file return local_file
logger.info(f"Starting {filename} download.")
start_time = time.time() start_time = time.time()
local_file = hf_hub_download( local_file = hf_hub_download(
filename=filename, filename=filename,
@ -144,7 +145,7 @@ def download_weights(
logger.info( logger.info(
f"Downloaded {filename} at {local_file} in {timedelta(seconds=int(time.time() - start_time))}." f"Downloaded {filename} at {local_file} in {timedelta(seconds=int(time.time() - start_time))}."
) )
return local_file return Path(local_file)
executor = ThreadPoolExecutor(max_workers=5) executor = ThreadPoolExecutor(max_workers=5)
futures = [ futures = [
@ -152,18 +153,14 @@ def download_weights(
] ]
# We do this instead of using tqdm because we want to parse the logs with the launcher # We do this instead of using tqdm because we want to parse the logs with the launcher
logger.info("Downloading weights...")
start_time = time.time() start_time = time.time()
files = [] files = []
for i, future in enumerate(concurrent.futures.as_completed(futures)): for i, future in enumerate(concurrent.futures.as_completed(futures)):
elapsed = timedelta(seconds=int(time.time() - start_time)) elapsed = timedelta(seconds=int(time.time() - start_time))
remaining = len(futures) - (i + 1) remaining = len(futures) - (i + 1)
if remaining != 0: eta = (elapsed / (i + 1)) * remaining if remaining > 0 else 0
eta = (elapsed / (i + 1)) * remaining
else:
eta = 0
logger.info(f"Download: [{i + 1}/{len(futures)}] -- ETA: {eta}") logger.info(f"Download: [{i + 1}/{len(futures)}] -- ETA: {eta}")
files.append(Path(future.result())) files.append(future.result())
return [Path(p) for p in files] return files