From a7a98c0253bab7b4daf09fff3028c812eb8d558f Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 29 Jan 2024 12:11:57 +0100 Subject: [PATCH] Fmt --- launcher/src/main.rs | 14 ++++++++------ router/src/server.rs | 24 +++++++++++++++++------- 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 9f7791a7..7d041b03 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -982,14 +982,16 @@ fn spawn_shards( Ok(()) } -fn compute_type(num_shard: usize) -> Option{ - let output = Command::new("nvidia-smi").args(["--query-gpu=gpu_name", "--format=csv"]).output().ok()?; +fn compute_type(num_shard: usize) -> Option { + let output = Command::new("nvidia-smi") + .args(["--query-gpu=gpu_name", "--format=csv"]) + .output() + .ok()?; let output = String::from_utf8(output.stdout).ok()?; let fullname = output.split("\n").nth(1)?; let cardname = fullname.replace(" ", "-").to_lowercase(); let compute_type = format!("{num_shard}-{cardname}"); Some(compute_type) - } fn spawn_webserver( @@ -1086,7 +1088,7 @@ fn spawn_webserver( // Parse Compute type if let Ok(compute_type) = env::var("COMPUTE_TYPE") { envs.push(("COMPUTE_TYPE".into(), compute_type.into())) - }else if let Some(compute_type) = compute_type(num_shard){ + } else if let Some(compute_type) = compute_type(num_shard) { envs.push(("COMPUTE_TYPE".into(), compute_type.into())) } @@ -1283,8 +1285,8 @@ fn main() -> Result<(), LauncherError> { return Ok(()); } - let mut webserver = - spawn_webserver(num_shard, args, shutdown.clone(), &shutdown_receiver).map_err(|err| { + let mut webserver = spawn_webserver(num_shard, args, shutdown.clone(), &shutdown_receiver) + .map_err(|err| { shutdown_shards(shutdown.clone(), &shutdown_receiver); err })?; diff --git a/router/src/server.rs b/router/src/server.rs index 39d1de38..52ed03df 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -67,7 +67,7 @@ async fn compat_generate( // switch on stream if req.stream { - Ok(generate_stream(infer,compute_type, Json(req.into())) + Ok(generate_stream(infer, compute_type, Json(req.into())) .await .into_response()) } else { @@ -372,7 +372,7 @@ async fn generate_stream_internal( let compute_characters = req.inputs.chars().count(); let mut headers = HeaderMap::new(); - headers.insert("x-compute-type",compute_type.parse().unwrap()); + headers.insert("x-compute-type", compute_type.parse().unwrap()); headers.insert( "x-compute-characters", compute_characters.to_string().parse().unwrap(), @@ -649,13 +649,22 @@ async fn chat_completions( ) }; - let (headers, response_stream) = - generate_stream_internal(infer, compute_type, Json(generate_request), on_message_callback).await; + let (headers, response_stream) = generate_stream_internal( + infer, + compute_type, + Json(generate_request), + on_message_callback, + ) + .await; let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); Ok((headers, sse).into_response()) } else { - let (headers, Json(generation)) = - generate(Extension(infer), Extension(compute_type), Json(generate_request)).await?; + let (headers, Json(generation)) = generate( + Extension(infer), + Extension(compute_type), + Json(generate_request), + ) + .await?; let current_time = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) @@ -943,7 +952,8 @@ pub async fn run( Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise }; - let compute_type = ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string())); + let compute_type = + ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string())); // Combine routes and layers let app = Router::new()