diff --git a/launcher/src/main.rs b/launcher/src/main.rs index f0e45141..ff7e669a 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -982,7 +982,18 @@ fn spawn_shards( Ok(()) } +fn compute_type(num_shard: usize) -> Option{ + let output = Command::new("nvidia-smi").args(["--query-gpu=gpu_name", "--format=csv"]).output().ok()?; + let output = String::from_utf8(output.stdout).ok()?; + let fullname = output.split("\n").nth(1)?; + let cardname = fullname.replace(" ", "-").to_lowercase(); + let compute_type = format!("{num_shard}x{cardname}"); + Some(compute_type) + +} + fn spawn_webserver( + num_shard: usize, args: Args, shutdown: Arc, shutdown_receiver: &mpsc::Receiver<()>, @@ -1072,6 +1083,13 @@ fn spawn_webserver( envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) }; + // Parse Compute type + if let Ok(compute_type) = env::var("COMPUTE_TYPE") { + envs.push(("COMPUTE_TYPE".into(), compute_type.into())) + }else if let Some(compute_type) = compute_type(num_shard){ + envs.push(("COMPUTE_TYPE".into(), compute_type.into())) + } + let mut webserver = match Command::new("text-generation-router") .args(router_args) .envs(envs) @@ -1266,7 +1284,7 @@ fn main() -> Result<(), LauncherError> { } let mut webserver = - spawn_webserver(args, shutdown.clone(), &shutdown_receiver).map_err(|err| { + spawn_webserver(num_shard, args, shutdown.clone(), &shutdown_receiver).map_err(|err| { shutdown_shards(shutdown.clone(), &shutdown_receiver); err })?;