Cleanup up a bit the launcher.

This commit is contained in:
Nicolas Patry 2023-04-25 20:52:37 +02:00
parent 9df67b35bf
commit 390ec5aea8
No known key found for this signature in database
GPG Key ID: 6AE76ACC68FFBAF9

View File

@ -572,65 +572,11 @@ fn spawn_shards(
Ok(())
}
fn main() -> Result<(), LauncherError> {
// Pattern match configuration
let args = Args::parse();
if args.json_output {
tracing_subscriber::fmt().json().init();
} else {
tracing_subscriber::fmt().compact().init();
}
tracing::info!("{:?}", args);
let num_shard = find_num_shards(args.sharded, args.num_shard);
if num_shard > 1 {
tracing::info!("Sharding model on {num_shard} processes");
}
// Signal handler
let running = Arc::new(AtomicBool::new(true));
let r = running.clone();
ctrlc::set_handler(move || {
r.store(false, Ordering::SeqCst);
})
.expect("Error setting Ctrl-C handler");
// Check if model_id is a local model
let local_path = Path::new(&args.model_id);
let is_local_model = local_path.exists() && local_path.is_dir();
// Download weights for sharded models
if !is_local_model && args.weights_cache_override.is_none() && num_shard > 1 {
download_model(&args, running.clone())?;
}
// Shared shutdown bool
let shutdown = Arc::new(Mutex::new(false));
// Shared shutdown channel
// When shutting down, the main thread will wait for all senders to be dropped
let (shutdown_sender, shutdown_receiver) = mpsc::channel();
// Shared channel to track shard status
let (status_sender, status_receiver) = mpsc::channel();
spawn_shards(
num_shard,
&args,
shutdown.clone(),
&shutdown_receiver,
shutdown_sender,
&status_receiver,
status_sender,
running.clone(),
)?;
// We might have received a termination signal
if !running.load(Ordering::SeqCst) {
shutdown_shards(shutdown, &shutdown_receiver);
return Ok(());
}
fn spawn_webserver(
args: Args,
shutdown: Arc<Mutex<bool>>,
shutdown_receiver: &mpsc::Receiver<()>,
) -> Result<Popen, LauncherError> {
// All shard started
// Start webserver
tracing::info!("Starting Webserver");
@ -739,6 +685,70 @@ fn main() -> Result<(), LauncherError> {
println!("{}", line.unwrap());
}
});
Ok(webserver)
}
fn main() -> Result<(), LauncherError> {
// Pattern match configuration
let args = Args::parse();
if args.json_output {
tracing_subscriber::fmt().json().init();
} else {
tracing_subscriber::fmt().compact().init();
}
tracing::info!("{:?}", args);
let num_shard = find_num_shards(args.sharded, args.num_shard);
if num_shard > 1 {
tracing::info!("Sharding model on {num_shard} processes");
}
// Signal handler
let running = Arc::new(AtomicBool::new(true));
let r = running.clone();
ctrlc::set_handler(move || {
r.store(false, Ordering::SeqCst);
})
.expect("Error setting Ctrl-C handler");
// Check if model_id is a local model
let local_path = Path::new(&args.model_id);
let is_local_model = local_path.exists() && local_path.is_dir();
// Download weights for sharded models
if !is_local_model && args.weights_cache_override.is_none() && num_shard > 1 {
download_model(&args, running.clone())?;
}
// Shared shutdown bool
let shutdown = Arc::new(Mutex::new(false));
// Shared shutdown channel
// When shutting down, the main thread will wait for all senders to be dropped
let (shutdown_sender, shutdown_receiver) = mpsc::channel();
// Shared channel to track shard status
let (status_sender, status_receiver) = mpsc::channel();
spawn_shards(
num_shard,
&args,
shutdown.clone(),
&shutdown_receiver,
shutdown_sender,
&status_receiver,
status_sender,
running.clone(),
)?;
// We might have received a termination signal
if !running.load(Ordering::SeqCst) {
shutdown_shards(shutdown, &shutdown_receiver);
return Ok(());
}
let mut webserver = spawn_webserver(args, shutdown.clone(), &shutdown_receiver)?;
// Default exit code
let mut exit_code = Ok(());