mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
Setting the compute_type at launchtime.
This commit is contained in:
parent
0424dabb01
commit
497d1518be
@ -982,7 +982,18 @@ fn spawn_shards(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn compute_type(num_shard: usize) -> Option<String>{
|
||||
let output = Command::new("nvidia-smi").args(["--query-gpu=gpu_name", "--format=csv"]).output().ok()?;
|
||||
let output = String::from_utf8(output.stdout).ok()?;
|
||||
let fullname = output.split("\n").nth(1)?;
|
||||
let cardname = fullname.replace(" ", "-").to_lowercase();
|
||||
let compute_type = format!("{num_shard}x{cardname}");
|
||||
Some(compute_type)
|
||||
|
||||
}
|
||||
|
||||
fn spawn_webserver(
|
||||
num_shard: usize,
|
||||
args: Args,
|
||||
shutdown: Arc<AtomicBool>,
|
||||
shutdown_receiver: &mpsc::Receiver<()>,
|
||||
@ -1072,6 +1083,13 @@ fn spawn_webserver(
|
||||
envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
|
||||
};
|
||||
|
||||
// Parse Compute type
|
||||
if let Ok(compute_type) = env::var("COMPUTE_TYPE") {
|
||||
envs.push(("COMPUTE_TYPE".into(), compute_type.into()))
|
||||
}else if let Some(compute_type) = compute_type(num_shard){
|
||||
envs.push(("COMPUTE_TYPE".into(), compute_type.into()))
|
||||
}
|
||||
|
||||
let mut webserver = match Command::new("text-generation-router")
|
||||
.args(router_args)
|
||||
.envs(envs)
|
||||
@ -1266,7 +1284,7 @@ fn main() -> Result<(), LauncherError> {
|
||||
}
|
||||
|
||||
let mut webserver =
|
||||
spawn_webserver(args, shutdown.clone(), &shutdown_receiver).map_err(|err| {
|
||||
spawn_webserver(num_shard, args, shutdown.clone(), &shutdown_receiver).map_err(|err| {
|
||||
shutdown_shards(shutdown.clone(), &shutdown_receiver);
|
||||
err
|
||||
})?;
|
||||
|
Loading…
Reference in New Issue
Block a user