mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
update launcher
This commit is contained in:
parent
7a0bbf0994
commit
9034105553
@ -43,6 +43,10 @@ struct Args {
|
|||||||
#[clap(default_value = "29500", long, env)]
|
#[clap(default_value = "29500", long, env)]
|
||||||
master_port: usize,
|
master_port: usize,
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
|
huggingface_hub_cache: Option<String>,
|
||||||
|
#[clap(long, env)]
|
||||||
|
weights_cache_override: Option<String>,
|
||||||
|
#[clap(long, env)]
|
||||||
json_output: bool,
|
json_output: bool,
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
otlp_endpoint: Option<String>,
|
otlp_endpoint: Option<String>,
|
||||||
@ -63,6 +67,8 @@ fn main() -> ExitCode {
|
|||||||
shard_uds_path,
|
shard_uds_path,
|
||||||
master_addr,
|
master_addr,
|
||||||
master_port,
|
master_port,
|
||||||
|
huggingface_hub_cache,
|
||||||
|
weights_cache_override,
|
||||||
json_output,
|
json_output,
|
||||||
otlp_endpoint,
|
otlp_endpoint,
|
||||||
} = Args::parse();
|
} = Args::parse();
|
||||||
@ -85,8 +91,7 @@ fn main() -> ExitCode {
|
|||||||
.expect("Error setting Ctrl-C handler");
|
.expect("Error setting Ctrl-C handler");
|
||||||
|
|
||||||
// Download weights
|
// Download weights
|
||||||
if num_shard > 1 {
|
if weights_cache_override.is_none() {
|
||||||
// Only download weights if in sharded mode
|
|
||||||
let mut download_argv = vec![
|
let mut download_argv = vec![
|
||||||
"text-generation-server".to_string(),
|
"text-generation-server".to_string(),
|
||||||
"download-weights".to_string(),
|
"download-weights".to_string(),
|
||||||
@ -95,29 +100,28 @@ fn main() -> ExitCode {
|
|||||||
"INFO".to_string(),
|
"INFO".to_string(),
|
||||||
"--json-output".to_string(),
|
"--json-output".to_string(),
|
||||||
];
|
];
|
||||||
|
if num_shard == 1 {
|
||||||
|
download_argv.push("--extension".to_string());
|
||||||
|
download_argv.push(".bin".to_string());
|
||||||
|
} else {
|
||||||
|
download_argv.push("--extension".to_string());
|
||||||
|
download_argv.push(".safetensors".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
// Model optional revision
|
// Model optional revision
|
||||||
if let Some(revision) = revision.clone() {
|
if let Some(ref revision) = revision {
|
||||||
download_argv.push("--revision".to_string());
|
download_argv.push("--revision".to_string());
|
||||||
download_argv.push(revision)
|
download_argv.push(revision.to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut env = Vec::new();
|
let mut env = Vec::new();
|
||||||
|
|
||||||
// If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard
|
// If huggingface_hub_cache is set, pass it to the download process
|
||||||
// Useful when running inside a docker container
|
// Useful when running inside a docker container
|
||||||
if let Ok(huggingface_hub_cache) = env::var("HUGGINGFACE_HUB_CACHE") {
|
if let Some(ref huggingface_hub_cache) = huggingface_hub_cache {
|
||||||
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
|
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
|
||||||
};
|
};
|
||||||
|
|
||||||
// If the WEIGHTS_CACHE_OVERRIDE env var is set, pass it to the shard
|
|
||||||
// Useful when running inside a HuggingFace Inference Endpoint
|
|
||||||
if let Ok(weights_cache_override) = env::var("WEIGHTS_CACHE_OVERRIDE") {
|
|
||||||
env.push((
|
|
||||||
"WEIGHTS_CACHE_OVERRIDE".into(),
|
|
||||||
weights_cache_override.into(),
|
|
||||||
));
|
|
||||||
};
|
|
||||||
|
|
||||||
// Start process
|
// Start process
|
||||||
tracing::info!("Starting download");
|
tracing::info!("Starting download");
|
||||||
let mut download_process = match Popen::create(
|
let mut download_process = match Popen::create(
|
||||||
@ -196,6 +200,12 @@ fn main() -> ExitCode {
|
|||||||
}
|
}
|
||||||
sleep(Duration::from_millis(100));
|
sleep(Duration::from_millis(100));
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
tracing::info!(
|
||||||
|
"weights_cache_override is set to {:?}.",
|
||||||
|
weights_cache_override
|
||||||
|
);
|
||||||
|
tracing::info!("Skipping download.")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shared shutdown bool
|
// Shared shutdown bool
|
||||||
@ -213,6 +223,8 @@ fn main() -> ExitCode {
|
|||||||
let revision = revision.clone();
|
let revision = revision.clone();
|
||||||
let uds_path = shard_uds_path.clone();
|
let uds_path = shard_uds_path.clone();
|
||||||
let master_addr = master_addr.clone();
|
let master_addr = master_addr.clone();
|
||||||
|
let huggingface_hub_cache = huggingface_hub_cache.clone();
|
||||||
|
let weights_cache_override = weights_cache_override.clone();
|
||||||
let status_sender = status_sender.clone();
|
let status_sender = status_sender.clone();
|
||||||
let shutdown = shutdown.clone();
|
let shutdown = shutdown.clone();
|
||||||
let shutdown_sender = shutdown_sender.clone();
|
let shutdown_sender = shutdown_sender.clone();
|
||||||
@ -227,6 +239,8 @@ fn main() -> ExitCode {
|
|||||||
num_shard,
|
num_shard,
|
||||||
master_addr,
|
master_addr,
|
||||||
master_port,
|
master_port,
|
||||||
|
huggingface_hub_cache,
|
||||||
|
weights_cache_override,
|
||||||
otlp_endpoint,
|
otlp_endpoint,
|
||||||
status_sender,
|
status_sender,
|
||||||
shutdown,
|
shutdown,
|
||||||
@ -346,7 +360,7 @@ fn main() -> ExitCode {
|
|||||||
|
|
||||||
while running.load(Ordering::SeqCst) {
|
while running.load(Ordering::SeqCst) {
|
||||||
if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
|
if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
|
||||||
tracing::error!("Shard {} failed:\n{}", rank, err);
|
tracing::error!("Shard {rank} failed:\n{err}");
|
||||||
exit_code = ExitCode::FAILURE;
|
exit_code = ExitCode::FAILURE;
|
||||||
break;
|
break;
|
||||||
};
|
};
|
||||||
@ -389,6 +403,8 @@ fn shard_manager(
|
|||||||
world_size: usize,
|
world_size: usize,
|
||||||
master_addr: String,
|
master_addr: String,
|
||||||
master_port: usize,
|
master_port: usize,
|
||||||
|
huggingface_hub_cache: Option<String>,
|
||||||
|
weights_cache_override: Option<String>,
|
||||||
otlp_endpoint: Option<String>,
|
otlp_endpoint: Option<String>,
|
||||||
status_sender: mpsc::Sender<ShardStatus>,
|
status_sender: mpsc::Sender<ShardStatus>,
|
||||||
shutdown: Arc<Mutex<bool>>,
|
shutdown: Arc<Mutex<bool>>,
|
||||||
@ -442,15 +458,15 @@ fn shard_manager(
|
|||||||
("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()),
|
("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()),
|
||||||
];
|
];
|
||||||
|
|
||||||
// If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard
|
// If huggingface_hub_cache is some, pass it to the shard
|
||||||
// Useful when running inside a docker container
|
// Useful when running inside a docker container
|
||||||
if let Ok(huggingface_hub_cache) = env::var("HUGGINGFACE_HUB_CACHE") {
|
if let Some(huggingface_hub_cache) = huggingface_hub_cache {
|
||||||
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
|
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
|
||||||
};
|
};
|
||||||
|
|
||||||
// If the WEIGHTS_CACHE_OVERRIDE env var is set, pass it to the shard
|
// If weights_cache_override is some, pass it to the shard
|
||||||
// Useful when running inside a HuggingFace Inference Endpoint
|
// Useful when running inside a HuggingFace Inference Endpoint
|
||||||
if let Ok(weights_cache_override) = env::var("WEIGHTS_CACHE_OVERRIDE") {
|
if let Some(weights_cache_override) = weights_cache_override {
|
||||||
env.push((
|
env.push((
|
||||||
"WEIGHTS_CACHE_OVERRIDE".into(),
|
"WEIGHTS_CACHE_OVERRIDE".into(),
|
||||||
weights_cache_override.into(),
|
weights_cache_override.into(),
|
||||||
@ -469,7 +485,7 @@ fn shard_manager(
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Start process
|
// Start process
|
||||||
tracing::info!("Starting shard {}", rank);
|
tracing::info!("Starting shard {rank}");
|
||||||
let mut p = match Popen::create(
|
let mut p = match Popen::create(
|
||||||
&shard_argv,
|
&shard_argv,
|
||||||
PopenConfig {
|
PopenConfig {
|
||||||
@ -533,17 +549,17 @@ fn shard_manager(
|
|||||||
if *shutdown.lock().unwrap() {
|
if *shutdown.lock().unwrap() {
|
||||||
p.terminate().unwrap();
|
p.terminate().unwrap();
|
||||||
let _ = p.wait_timeout(Duration::from_secs(90));
|
let _ = p.wait_timeout(Duration::from_secs(90));
|
||||||
tracing::info!("Shard {} terminated", rank);
|
tracing::info!("Shard {rank} terminated");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shard is ready
|
// Shard is ready
|
||||||
if uds.exists() && !ready {
|
if uds.exists() && !ready {
|
||||||
tracing::info!("Shard {} ready in {:?}", rank, start_time.elapsed());
|
tracing::info!("Shard {rank} ready in {:?}", start_time.elapsed());
|
||||||
status_sender.send(ShardStatus::Ready).unwrap();
|
status_sender.send(ShardStatus::Ready).unwrap();
|
||||||
ready = true;
|
ready = true;
|
||||||
} else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
|
} else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
|
||||||
tracing::info!("Waiting for shard {} to be ready...", rank);
|
tracing::info!("Waiting for shard {rank} to be ready...");
|
||||||
wait_time = Instant::now();
|
wait_time = Instant::now();
|
||||||
}
|
}
|
||||||
sleep(Duration::from_millis(100));
|
sleep(Duration::from_millis(100));
|
||||||
|
Loading…
Reference in New Issue
Block a user