/// Builds the compute-type label for this deployment: `"{num_shard}-{gpu-name}"`.
///
/// The GPU name is taken from the first data row printed by
/// `nvidia-smi --query-gpu=gpu_name --format=csv`, lower-cased and with spaces
/// replaced by `-`.
///
/// Returns `None` when no usable name can be obtained: `nvidia-smi` cannot be
/// spawned, exits with a failure status, prints non-UTF-8 output, or prints no
/// non-empty GPU row. Callers are expected to simply skip the label in that case.
fn compute_type(num_shard: usize) -> Option<String> {
    let output = Command::new("nvidia-smi")
        .args(["--query-gpu=gpu_name", "--format=csv"])
        .output()
        .ok()?;
    // A failing `nvidia-smi` may still write diagnostics to stdout; never
    // mistake that text for a card name.
    if !output.status.success() {
        return None;
    }
    let stdout = String::from_utf8(output.stdout).ok()?;
    let card_name = first_gpu_name(&stdout)?;
    Some(format!("{num_shard}-{card_name}"))
}

/// Extracts the normalized name of the first GPU from `nvidia-smi` CSV output.
///
/// Row 0 is the CSV header and row 1 is the first GPU. Returns `None` when row 1
/// is missing or blank, so a header-only report never produces an empty name.
fn first_gpu_name(csv: &str) -> Option<String> {
    // `lines()` handles both `\n` and `\r\n`; `trim()` drops any padding left
    // around the value so no stray whitespace or control character survives
    // into the label.
    let name = csv.lines().nth(1)?.trim();
    if name.is_empty() {
        return None;
    }
    Some(name.replace(' ', "-").to_lowercase())
}
req.inputs.chars().count(); let mut headers = HeaderMap::new(); - headers.insert("x-compute-type",compute_type.parse().unwrap()); + headers.insert("x-compute-type", compute_type.parse().unwrap()); headers.insert( "x-compute-characters", compute_characters.to_string().parse().unwrap(), @@ -649,13 +649,22 @@ async fn chat_completions( ) }; - let (headers, response_stream) = - generate_stream_internal(infer, compute_type, Json(generate_request), on_message_callback).await; + let (headers, response_stream) = generate_stream_internal( + infer, + compute_type, + Json(generate_request), + on_message_callback, + ) + .await; let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); Ok((headers, sse).into_response()) } else { - let (headers, Json(generation)) = - generate(Extension(infer), Extension(compute_type), Json(generate_request)).await?; + let (headers, Json(generation)) = generate( + Extension(infer), + Extension(compute_type), + Json(generate_request), + ) + .await?; let current_time = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) @@ -943,7 +952,8 @@ pub async fn run( Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise }; - let compute_type = ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string())); + let compute_type = + ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string())); // Combine routes and layers let app = Router::new()