mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-24 08:22:07 +00:00
feat(server): improve download logging (#66)
This commit is contained in:
parent
0fbc691946
commit
c5a4a1faf3
@ -22,8 +22,8 @@ struct Args {
|
|||||||
model_id: String,
|
model_id: String,
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
revision: Option<String>,
|
revision: Option<String>,
|
||||||
#[clap(long, env)]
|
#[clap(default_value = "1", long, env)]
|
||||||
num_shard: Option<usize>,
|
num_shard: usize,
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
quantize: bool,
|
quantize: bool,
|
||||||
#[clap(default_value = "128", long, env)]
|
#[clap(default_value = "128", long, env)]
|
||||||
@ -54,6 +54,16 @@ struct Args {
|
|||||||
|
|
||||||
fn main() -> ExitCode {
|
fn main() -> ExitCode {
|
||||||
// Pattern match configuration
|
// Pattern match configuration
|
||||||
|
let args = Args::parse();
|
||||||
|
|
||||||
|
if args.json_output {
|
||||||
|
tracing_subscriber::fmt().json().init();
|
||||||
|
} else {
|
||||||
|
tracing_subscriber::fmt().compact().init();
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!("{:?}", args);
|
||||||
|
|
||||||
let Args {
|
let Args {
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
@ -71,16 +81,7 @@ fn main() -> ExitCode {
|
|||||||
weights_cache_override,
|
weights_cache_override,
|
||||||
json_output,
|
json_output,
|
||||||
otlp_endpoint,
|
otlp_endpoint,
|
||||||
} = Args::parse();
|
} = args;
|
||||||
|
|
||||||
if json_output {
|
|
||||||
tracing_subscriber::fmt().json().init();
|
|
||||||
} else {
|
|
||||||
tracing_subscriber::fmt().compact().init();
|
|
||||||
}
|
|
||||||
|
|
||||||
// By default we only have one master shard
|
|
||||||
let num_shard = num_shard.unwrap_or(1);
|
|
||||||
|
|
||||||
// Signal handler
|
// Signal handler
|
||||||
let running = Arc::new(AtomicBool::new(true));
|
let running = Arc::new(AtomicBool::new(true));
|
||||||
@ -123,7 +124,7 @@ fn main() -> ExitCode {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Start process
|
// Start process
|
||||||
tracing::info!("Starting download");
|
tracing::info!("Starting download process.");
|
||||||
let mut download_process = match Popen::create(
|
let mut download_process = match Popen::create(
|
||||||
&download_argv,
|
&download_argv,
|
||||||
PopenConfig {
|
PopenConfig {
|
||||||
@ -184,7 +185,7 @@ fn main() -> ExitCode {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
tracing::error!("Download process exited with an unkown status.");
|
tracing::error!("Download process exited with an unknown status.");
|
||||||
return ExitCode::FAILURE;
|
return ExitCode::FAILURE;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -83,14 +83,10 @@ def convert_files(pt_files: List[Path], st_files: List[Path]):
|
|||||||
]
|
]
|
||||||
|
|
||||||
# We do this instead of using tqdm because we want to parse the logs with the launcher
|
# We do this instead of using tqdm because we want to parse the logs with the launcher
|
||||||
logger.info("Converting weights...")
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
for i, future in enumerate(concurrent.futures.as_completed(futures)):
|
for i, future in enumerate(concurrent.futures.as_completed(futures)):
|
||||||
elapsed = timedelta(seconds=int(time.time() - start_time))
|
elapsed = timedelta(seconds=int(time.time() - start_time))
|
||||||
remaining = len(futures) - (i + 1)
|
remaining = len(futures) - (i + 1)
|
||||||
if remaining != 0:
|
eta = (elapsed / (i + 1)) * remaining if remaining > 0 else 0
|
||||||
eta = (elapsed / (i + 1)) * remaining
|
|
||||||
else:
|
|
||||||
eta = 0
|
|
||||||
|
|
||||||
logger.info(f"Convert: [{i + 1}/{len(futures)}] -- ETA: {eta}")
|
logger.info(f"Convert: [{i + 1}/{len(futures)}] -- ETA: {eta}")
|
||||||
|
@ -134,6 +134,7 @@ def download_weights(
|
|||||||
logger.info(f"File {filename} already present in cache.")
|
logger.info(f"File {filename} already present in cache.")
|
||||||
return local_file
|
return local_file
|
||||||
|
|
||||||
|
logger.info(f"Starting {filename} download.")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
local_file = hf_hub_download(
|
local_file = hf_hub_download(
|
||||||
filename=filename,
|
filename=filename,
|
||||||
@ -144,7 +145,7 @@ def download_weights(
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"Downloaded {filename} at {local_file} in {timedelta(seconds=int(time.time() - start_time))}."
|
f"Downloaded {filename} at {local_file} in {timedelta(seconds=int(time.time() - start_time))}."
|
||||||
)
|
)
|
||||||
return local_file
|
return Path(local_file)
|
||||||
|
|
||||||
executor = ThreadPoolExecutor(max_workers=5)
|
executor = ThreadPoolExecutor(max_workers=5)
|
||||||
futures = [
|
futures = [
|
||||||
@ -152,18 +153,14 @@ def download_weights(
|
|||||||
]
|
]
|
||||||
|
|
||||||
# We do this instead of using tqdm because we want to parse the logs with the launcher
|
# We do this instead of using tqdm because we want to parse the logs with the launcher
|
||||||
logger.info("Downloading weights...")
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
files = []
|
files = []
|
||||||
for i, future in enumerate(concurrent.futures.as_completed(futures)):
|
for i, future in enumerate(concurrent.futures.as_completed(futures)):
|
||||||
elapsed = timedelta(seconds=int(time.time() - start_time))
|
elapsed = timedelta(seconds=int(time.time() - start_time))
|
||||||
remaining = len(futures) - (i + 1)
|
remaining = len(futures) - (i + 1)
|
||||||
if remaining != 0:
|
eta = (elapsed / (i + 1)) * remaining if remaining > 0 else 0
|
||||||
eta = (elapsed / (i + 1)) * remaining
|
|
||||||
else:
|
|
||||||
eta = 0
|
|
||||||
|
|
||||||
logger.info(f"Download: [{i + 1}/{len(futures)}] -- ETA: {eta}")
|
logger.info(f"Download: [{i + 1}/{len(futures)}] -- ETA: {eta}")
|
||||||
files.append(Path(future.result()))
|
files.append(future.result())
|
||||||
|
|
||||||
return [Path(p) for p in files]
|
return files
|
||||||
|
Loading…
Reference in New Issue
Block a user