mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
Cleanup up a bit the launcher.
This commit is contained in:
parent
9df67b35bf
commit
390ec5aea8
@ -572,65 +572,11 @@ fn spawn_shards(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<(), LauncherError> {
|
fn spawn_webserver(
|
||||||
// Pattern match configuration
|
args: Args,
|
||||||
let args = Args::parse();
|
shutdown: Arc<Mutex<bool>>,
|
||||||
|
shutdown_receiver: &mpsc::Receiver<()>,
|
||||||
if args.json_output {
|
) -> Result<Popen, LauncherError> {
|
||||||
tracing_subscriber::fmt().json().init();
|
|
||||||
} else {
|
|
||||||
tracing_subscriber::fmt().compact().init();
|
|
||||||
}
|
|
||||||
|
|
||||||
tracing::info!("{:?}", args);
|
|
||||||
|
|
||||||
let num_shard = find_num_shards(args.sharded, args.num_shard);
|
|
||||||
if num_shard > 1 {
|
|
||||||
tracing::info!("Sharding model on {num_shard} processes");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Signal handler
|
|
||||||
let running = Arc::new(AtomicBool::new(true));
|
|
||||||
let r = running.clone();
|
|
||||||
ctrlc::set_handler(move || {
|
|
||||||
r.store(false, Ordering::SeqCst);
|
|
||||||
})
|
|
||||||
.expect("Error setting Ctrl-C handler");
|
|
||||||
|
|
||||||
// Check if model_id is a local model
|
|
||||||
let local_path = Path::new(&args.model_id);
|
|
||||||
let is_local_model = local_path.exists() && local_path.is_dir();
|
|
||||||
|
|
||||||
// Download weights for sharded models
|
|
||||||
if !is_local_model && args.weights_cache_override.is_none() && num_shard > 1 {
|
|
||||||
download_model(&args, running.clone())?;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Shared shutdown bool
|
|
||||||
let shutdown = Arc::new(Mutex::new(false));
|
|
||||||
// Shared shutdown channel
|
|
||||||
// When shutting down, the main thread will wait for all senders to be dropped
|
|
||||||
let (shutdown_sender, shutdown_receiver) = mpsc::channel();
|
|
||||||
|
|
||||||
// Shared channel to track shard status
|
|
||||||
let (status_sender, status_receiver) = mpsc::channel();
|
|
||||||
|
|
||||||
spawn_shards(
|
|
||||||
num_shard,
|
|
||||||
&args,
|
|
||||||
shutdown.clone(),
|
|
||||||
&shutdown_receiver,
|
|
||||||
shutdown_sender,
|
|
||||||
&status_receiver,
|
|
||||||
status_sender,
|
|
||||||
running.clone(),
|
|
||||||
)?;
|
|
||||||
// We might have received a termination signal
|
|
||||||
if !running.load(Ordering::SeqCst) {
|
|
||||||
shutdown_shards(shutdown, &shutdown_receiver);
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
// All shard started
|
// All shard started
|
||||||
// Start webserver
|
// Start webserver
|
||||||
tracing::info!("Starting Webserver");
|
tracing::info!("Starting Webserver");
|
||||||
@ -739,6 +685,70 @@ fn main() -> Result<(), LauncherError> {
|
|||||||
println!("{}", line.unwrap());
|
println!("{}", line.unwrap());
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
Ok(webserver)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> Result<(), LauncherError> {
|
||||||
|
// Pattern match configuration
|
||||||
|
let args = Args::parse();
|
||||||
|
|
||||||
|
if args.json_output {
|
||||||
|
tracing_subscriber::fmt().json().init();
|
||||||
|
} else {
|
||||||
|
tracing_subscriber::fmt().compact().init();
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!("{:?}", args);
|
||||||
|
|
||||||
|
let num_shard = find_num_shards(args.sharded, args.num_shard);
|
||||||
|
if num_shard > 1 {
|
||||||
|
tracing::info!("Sharding model on {num_shard} processes");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Signal handler
|
||||||
|
let running = Arc::new(AtomicBool::new(true));
|
||||||
|
let r = running.clone();
|
||||||
|
ctrlc::set_handler(move || {
|
||||||
|
r.store(false, Ordering::SeqCst);
|
||||||
|
})
|
||||||
|
.expect("Error setting Ctrl-C handler");
|
||||||
|
|
||||||
|
// Check if model_id is a local model
|
||||||
|
let local_path = Path::new(&args.model_id);
|
||||||
|
let is_local_model = local_path.exists() && local_path.is_dir();
|
||||||
|
|
||||||
|
// Download weights for sharded models
|
||||||
|
if !is_local_model && args.weights_cache_override.is_none() && num_shard > 1 {
|
||||||
|
download_model(&args, running.clone())?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shared shutdown bool
|
||||||
|
let shutdown = Arc::new(Mutex::new(false));
|
||||||
|
// Shared shutdown channel
|
||||||
|
// When shutting down, the main thread will wait for all senders to be dropped
|
||||||
|
let (shutdown_sender, shutdown_receiver) = mpsc::channel();
|
||||||
|
|
||||||
|
// Shared channel to track shard status
|
||||||
|
let (status_sender, status_receiver) = mpsc::channel();
|
||||||
|
|
||||||
|
spawn_shards(
|
||||||
|
num_shard,
|
||||||
|
&args,
|
||||||
|
shutdown.clone(),
|
||||||
|
&shutdown_receiver,
|
||||||
|
shutdown_sender,
|
||||||
|
&status_receiver,
|
||||||
|
status_sender,
|
||||||
|
running.clone(),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// We might have received a termination signal
|
||||||
|
if !running.load(Ordering::SeqCst) {
|
||||||
|
shutdown_shards(shutdown, &shutdown_receiver);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut webserver = spawn_webserver(args, shutdown.clone(), &shutdown_receiver)?;
|
||||||
|
|
||||||
// Default exit code
|
// Default exit code
|
||||||
let mut exit_code = Ok(());
|
let mut exit_code = Ok(());
|
||||||
|
Loading…
Reference in New Issue
Block a user