Setting the compute_type at launchtime.

This commit is contained in:
Nicolas Patry 2024-01-29 11:59:04 +01:00
parent 0424dabb01
commit 497d1518be

View File

@ -982,7 +982,18 @@ fn spawn_shards(
Ok(()) Ok(())
} }
fn compute_type(num_shard: usize) -> Option<String>{
let output = Command::new("nvidia-smi").args(["--query-gpu=gpu_name", "--format=csv"]).output().ok()?;
let output = String::from_utf8(output.stdout).ok()?;
let fullname = output.split("\n").nth(1)?;
let cardname = fullname.replace(" ", "-").to_lowercase();
let compute_type = format!("{num_shard}x{cardname}");
Some(compute_type)
}
fn spawn_webserver( fn spawn_webserver(
num_shard: usize,
args: Args, args: Args,
shutdown: Arc<AtomicBool>, shutdown: Arc<AtomicBool>,
shutdown_receiver: &mpsc::Receiver<()>, shutdown_receiver: &mpsc::Receiver<()>,
@ -1072,6 +1083,13 @@ fn spawn_webserver(
envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
}; };
// Parse Compute type
if let Ok(compute_type) = env::var("COMPUTE_TYPE") {
envs.push(("COMPUTE_TYPE".into(), compute_type.into()))
}else if let Some(compute_type) = compute_type(num_shard){
envs.push(("COMPUTE_TYPE".into(), compute_type.into()))
}
let mut webserver = match Command::new("text-generation-router") let mut webserver = match Command::new("text-generation-router")
.args(router_args) .args(router_args)
.envs(envs) .envs(envs)
@ -1266,7 +1284,7 @@ fn main() -> Result<(), LauncherError> {
} }
let mut webserver = let mut webserver =
spawn_webserver(args, shutdown.clone(), &shutdown_receiver).map_err(|err| { spawn_webserver(num_shard, args, shutdown.clone(), &shutdown_receiver).map_err(|err| {
shutdown_shards(shutdown.clone(), &shutdown_receiver); shutdown_shards(shutdown.clone(), &shutdown_receiver);
err err
})?; })?;