mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Setting the compute_type at launch time.
This commit is contained in:
parent
0424dabb01
commit
497d1518be
@ -982,7 +982,18 @@ fn spawn_shards(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Builds a compute-type label of the form `{num_shard}x{card-name}` by
/// asking `nvidia-smi` for the GPU name.
///
/// Runs `nvidia-smi --query-gpu=gpu_name --format=csv` and hands its stdout
/// to [`format_compute_type`].
///
/// Returns `None` when the command cannot be spawned, exits unsuccessfully,
/// prints non-UTF-8 output, or reports no usable GPU name.
fn compute_type(num_shard: usize) -> Option<String> {
    let output = Command::new("nvidia-smi")
        .args(["--query-gpu=gpu_name", "--format=csv"])
        .output()
        .ok()?;
    // A failing `nvidia-smi` may still print a diagnostic on stdout; do not
    // mistake that text for a GPU name.
    if !output.status.success() {
        return None;
    }
    let stdout = String::from_utf8(output.stdout).ok()?;
    format_compute_type(num_shard, &stdout)
}

/// Turns the CSV text printed by the GPU-name query into a compute-type label.
///
/// The first line is treated as the CSV header and skipped; the second line
/// is taken as the card name. Surrounding whitespace is trimmed, inner spaces
/// become `-`, and the name is lower-cased before being prefixed with
/// `{num_shard}x`.
///
/// Returns `None` when there is no second line or when it is blank, so callers
/// never receive a label with an empty card name such as `"2x"`.
fn format_compute_type(num_shard: usize, csv: &str) -> Option<String> {
    // `lines()` also strips a trailing `\r`, so CRLF output is handled.
    let fullname = csv.lines().nth(1)?.trim();
    if fullname.is_empty() {
        return None;
    }
    let cardname = fullname.replace(' ', "-").to_lowercase();
    Some(format!("{num_shard}x{cardname}"))
}
|
||||||
|
|
||||||
fn spawn_webserver(
|
fn spawn_webserver(
|
||||||
|
num_shard: usize,
|
||||||
args: Args,
|
args: Args,
|
||||||
shutdown: Arc<AtomicBool>,
|
shutdown: Arc<AtomicBool>,
|
||||||
shutdown_receiver: &mpsc::Receiver<()>,
|
shutdown_receiver: &mpsc::Receiver<()>,
|
||||||
@ -1072,6 +1083,13 @@ fn spawn_webserver(
|
|||||||
envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
|
envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Parse Compute type
|
||||||
|
if let Ok(compute_type) = env::var("COMPUTE_TYPE") {
|
||||||
|
envs.push(("COMPUTE_TYPE".into(), compute_type.into()))
|
||||||
|
}else if let Some(compute_type) = compute_type(num_shard){
|
||||||
|
envs.push(("COMPUTE_TYPE".into(), compute_type.into()))
|
||||||
|
}
|
||||||
|
|
||||||
let mut webserver = match Command::new("text-generation-router")
|
let mut webserver = match Command::new("text-generation-router")
|
||||||
.args(router_args)
|
.args(router_args)
|
||||||
.envs(envs)
|
.envs(envs)
|
||||||
@ -1266,7 +1284,7 @@ fn main() -> Result<(), LauncherError> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let mut webserver =
|
let mut webserver =
|
||||||
spawn_webserver(args, shutdown.clone(), &shutdown_receiver).map_err(|err| {
|
spawn_webserver(num_shard, args, shutdown.clone(), &shutdown_receiver).map_err(|err| {
|
||||||
shutdown_shards(shutdown.clone(), &shutdown_receiver);
|
shutdown_shards(shutdown.clone(), &shutdown_receiver);
|
||||||
err
|
err
|
||||||
})?;
|
})?;
|
||||||
|
Loading…
Reference in New Issue
Block a user