From a9ea60684b6445b2507e147c6aeed0edb0b25eb7 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 29 Jan 2024 12:30:50 +0100 Subject: [PATCH] Create the compute type at launch time (if not provided in the env). (#1505) # What does this PR do? Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --- launcher/src/main.rs | 24 ++++++++++++++++++++++-- router/src/server.rs | 24 +++++++++++++++++------- 2 files changed, 39 insertions(+), 9 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index f0e45141..054e546c 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -982,7 +982,20 @@ fn spawn_shards( Ok(()) } +fn compute_type(num_shard: usize) -> Option { + let output = Command::new("nvidia-smi") + .args(["--query-gpu=gpu_name", "--format=csv"]) + .output() + .ok()?; + let output = String::from_utf8(output.stdout).ok()?; + let fullname = output.split('\n').nth(1)?; + let cardname = fullname.replace(' ', "-").to_lowercase(); + let compute_type = format!("{num_shard}-{cardname}"); + Some(compute_type) +} + fn spawn_webserver( + num_shard: usize, args: Args, shutdown: Arc, shutdown_receiver: &mpsc::Receiver<()>, @@ -1072,6 +1085,13 @@ fn spawn_webserver( envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) }; + // Parse Compute type + if let Ok(compute_type) = env::var("COMPUTE_TYPE") { + envs.push(("COMPUTE_TYPE".into(), compute_type.into())) + } else if let Some(compute_type) = compute_type(num_shard) { + envs.push(("COMPUTE_TYPE".into(), compute_type.into())) + } + let mut webserver = match Command::new("text-generation-router") .args(router_args) .envs(envs) @@ -1265,8 +1285,8 @@ fn main() -> Result<(), LauncherError> { return Ok(()); } - let mut webserver = - spawn_webserver(args, shutdown.clone(), &shutdown_receiver).map_err(|err| { + let mut webserver = spawn_webserver(num_shard, args, shutdown.clone(), &shutdown_receiver) + .map_err(|err| { shutdown_shards(shutdown.clone(), &shutdown_receiver); err })?; diff --git a/router/src/server.rs b/router/src/server.rs index 39d1de38..52ed03df 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -67,7 +67,7 @@ async fn compat_generate( // switch on stream if req.stream { - Ok(generate_stream(infer,compute_type, Json(req.into())) + Ok(generate_stream(infer, compute_type, Json(req.into())) .await .into_response()) } else { @@ -372,7 +372,7 @@ async fn generate_stream_internal( let compute_characters = req.inputs.chars().count(); let mut headers = HeaderMap::new(); - headers.insert("x-compute-type",compute_type.parse().unwrap()); + headers.insert("x-compute-type", compute_type.parse().unwrap()); headers.insert( "x-compute-characters", compute_characters.to_string().parse().unwrap(), @@ -649,13 +649,22 @@ async fn chat_completions( ) }; - let (headers, response_stream) = - generate_stream_internal(infer, compute_type, Json(generate_request), on_message_callback).await; + let (headers, response_stream) = generate_stream_internal( + infer, + compute_type, + Json(generate_request), + on_message_callback, + ) + .await; let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); Ok((headers, sse).into_response()) } else { - let (headers, Json(generation)) = - generate(Extension(infer), Extension(compute_type), Json(generate_request)).await?; + let (headers, Json(generation)) = generate( + Extension(infer), + Extension(compute_type), + Json(generate_request), + ) + .await?; let current_time = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) @@ -943,7 +952,8 @@ pub async fn run( Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise }; - let compute_type = ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string())); + let compute_type = + ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string())); // Combine routes and layers let app = Router::new()