mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 11:24:53 +00:00
Cleanup up a bit the launcher.
This commit is contained in:
parent
9df67b35bf
commit
390ec5aea8
@ -572,65 +572,11 @@ fn spawn_shards(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() -> Result<(), LauncherError> {
|
||||
// Pattern match configuration
|
||||
let args = Args::parse();
|
||||
|
||||
if args.json_output {
|
||||
tracing_subscriber::fmt().json().init();
|
||||
} else {
|
||||
tracing_subscriber::fmt().compact().init();
|
||||
}
|
||||
|
||||
tracing::info!("{:?}", args);
|
||||
|
||||
let num_shard = find_num_shards(args.sharded, args.num_shard);
|
||||
if num_shard > 1 {
|
||||
tracing::info!("Sharding model on {num_shard} processes");
|
||||
}
|
||||
|
||||
// Signal handler
|
||||
let running = Arc::new(AtomicBool::new(true));
|
||||
let r = running.clone();
|
||||
ctrlc::set_handler(move || {
|
||||
r.store(false, Ordering::SeqCst);
|
||||
})
|
||||
.expect("Error setting Ctrl-C handler");
|
||||
|
||||
// Check if model_id is a local model
|
||||
let local_path = Path::new(&args.model_id);
|
||||
let is_local_model = local_path.exists() && local_path.is_dir();
|
||||
|
||||
// Download weights for sharded models
|
||||
if !is_local_model && args.weights_cache_override.is_none() && num_shard > 1 {
|
||||
download_model(&args, running.clone())?;
|
||||
}
|
||||
|
||||
// Shared shutdown bool
|
||||
let shutdown = Arc::new(Mutex::new(false));
|
||||
// Shared shutdown channel
|
||||
// When shutting down, the main thread will wait for all senders to be dropped
|
||||
let (shutdown_sender, shutdown_receiver) = mpsc::channel();
|
||||
|
||||
// Shared channel to track shard status
|
||||
let (status_sender, status_receiver) = mpsc::channel();
|
||||
|
||||
spawn_shards(
|
||||
num_shard,
|
||||
&args,
|
||||
shutdown.clone(),
|
||||
&shutdown_receiver,
|
||||
shutdown_sender,
|
||||
&status_receiver,
|
||||
status_sender,
|
||||
running.clone(),
|
||||
)?;
|
||||
// We might have received a termination signal
|
||||
if !running.load(Ordering::SeqCst) {
|
||||
shutdown_shards(shutdown, &shutdown_receiver);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
fn spawn_webserver(
|
||||
args: Args,
|
||||
shutdown: Arc<Mutex<bool>>,
|
||||
shutdown_receiver: &mpsc::Receiver<()>,
|
||||
) -> Result<Popen, LauncherError> {
|
||||
// All shard started
|
||||
// Start webserver
|
||||
tracing::info!("Starting Webserver");
|
||||
@ -739,6 +685,70 @@ fn main() -> Result<(), LauncherError> {
|
||||
println!("{}", line.unwrap());
|
||||
}
|
||||
});
|
||||
Ok(webserver)
|
||||
}
|
||||
|
||||
fn main() -> Result<(), LauncherError> {
|
||||
// Pattern match configuration
|
||||
let args = Args::parse();
|
||||
|
||||
if args.json_output {
|
||||
tracing_subscriber::fmt().json().init();
|
||||
} else {
|
||||
tracing_subscriber::fmt().compact().init();
|
||||
}
|
||||
|
||||
tracing::info!("{:?}", args);
|
||||
|
||||
let num_shard = find_num_shards(args.sharded, args.num_shard);
|
||||
if num_shard > 1 {
|
||||
tracing::info!("Sharding model on {num_shard} processes");
|
||||
}
|
||||
|
||||
// Signal handler
|
||||
let running = Arc::new(AtomicBool::new(true));
|
||||
let r = running.clone();
|
||||
ctrlc::set_handler(move || {
|
||||
r.store(false, Ordering::SeqCst);
|
||||
})
|
||||
.expect("Error setting Ctrl-C handler");
|
||||
|
||||
// Check if model_id is a local model
|
||||
let local_path = Path::new(&args.model_id);
|
||||
let is_local_model = local_path.exists() && local_path.is_dir();
|
||||
|
||||
// Download weights for sharded models
|
||||
if !is_local_model && args.weights_cache_override.is_none() && num_shard > 1 {
|
||||
download_model(&args, running.clone())?;
|
||||
}
|
||||
|
||||
// Shared shutdown bool
|
||||
let shutdown = Arc::new(Mutex::new(false));
|
||||
// Shared shutdown channel
|
||||
// When shutting down, the main thread will wait for all senders to be dropped
|
||||
let (shutdown_sender, shutdown_receiver) = mpsc::channel();
|
||||
|
||||
// Shared channel to track shard status
|
||||
let (status_sender, status_receiver) = mpsc::channel();
|
||||
|
||||
spawn_shards(
|
||||
num_shard,
|
||||
&args,
|
||||
shutdown.clone(),
|
||||
&shutdown_receiver,
|
||||
shutdown_sender,
|
||||
&status_receiver,
|
||||
status_sender,
|
||||
running.clone(),
|
||||
)?;
|
||||
|
||||
// We might have received a termination signal
|
||||
if !running.load(Ordering::SeqCst) {
|
||||
shutdown_shards(shutdown, &shutdown_receiver);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut webserver = spawn_webserver(args, shutdown.clone(), &shutdown_receiver)?;
|
||||
|
||||
// Default exit code
|
||||
let mut exit_code = Ok(());
|
||||
|
Loading…
Reference in New Issue
Block a user