From 908acc55b8f2729d77ca7be9591782d6c0bf6378 Mon Sep 17 00:00:00 2001
From: drbh
Date: Thu, 11 Apr 2024 20:37:35 +0000
Subject: [PATCH] fix: decrease default batch, refactors and include index in batch

---
 launcher/src/main.rs |   2 +-
 router/src/lib.rs    |   3 +
 router/src/main.rs   |   2 +-
 router/src/server.rs | 268 ++++++++++++++++++++++---------------------
 4 files changed, 143 insertions(+), 132 deletions(-)

diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 5cbd9387..d904f91b 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -416,7 +416,7 @@ struct Args {
     env: bool,
 
     /// Control the maximum number of inputs that a client can send in a single request
-    #[clap(default_value = "32", long, env)]
+    #[clap(default_value = "4", long, env)]
     max_client_batch_size: usize,
 }
 
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 2972e534..e7e1446e 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -293,6 +293,9 @@ mod prompt_serde {
         let value = Value::deserialize(deserializer)?;
         match value {
             Value::String(s) => Ok(vec![s]),
+            Value::Array(arr) if arr.is_empty() => Err(serde::de::Error::custom(
+                "Empty array detected. Do not use an empty array for the prompt.",
+            )),
             Value::Array(arr) => arr
                 .iter()
                 .map(|v| match v {
diff --git a/router/src/main.rs b/router/src/main.rs
index 6209f47f..b77117a1 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -78,7 +78,7 @@ struct Args {
     messages_api_enabled: bool,
     #[clap(long, env, default_value_t = false)]
     disable_grammar_support: bool,
-    #[clap(default_value = "32", long, env)]
+    #[clap(default_value = "4", long, env)]
     max_client_batch_size: usize,
 }
 
diff --git a/router/src/server.rs b/router/src/server.rs
index d804da2f..2ed6fc25 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -613,10 +613,10 @@ async fn completions(
         ));
     }
 
-    let mut generate_requests = Vec::new();
-    for prompt in req.prompt.iter() {
-        // build the request passing some parameters
-        let generate_request = GenerateRequest {
+    let generate_requests: Vec<GenerateRequest> = req
+        .prompt
+        .iter()
+        .map(|prompt| GenerateRequest {
             inputs: prompt.to_string(),
             parameters: GenerateParameters {
                 best_of: None,
@@ -638,9 +638,8 @@ async fn completions(
                 top_n_tokens: None,
                 grammar: None,
             },
-        };
-        generate_requests.push(generate_request);
-    }
+        })
+        .collect();
 
     let mut x_compute_type = "unknown".to_string();
     let mut x_compute_characters = 0u32;
@@ -767,130 +766,139 @@ async fn completions(
         };
 
         let sse = Sse::new(stream).keep_alive(KeepAlive::default());
-        return Ok((headers, sse).into_response());
+        Ok((headers, sse).into_response())
+    } else {
+        let current_time = std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+            .as_secs();
+
+        let responses = FuturesUnordered::new();
+        for (index, generate_request) in generate_requests.into_iter().enumerate() {
+            let infer_clone = infer.clone();
+            let compute_type_clone = compute_type.clone();
+            let response_future = async move {
+                let result = generate(
+                    Extension(infer_clone),
+                    Extension(compute_type_clone),
+                    Json(generate_request),
+                )
+                .await;
+                result.map(|(headers, generation)| (index, headers, generation))
+            };
+            responses.push(response_future);
+        }
+        let generate_responses = responses.try_collect::<Vec<_>>().await?;
+
+        let mut prompt_tokens = 0u32;
+        let mut completion_tokens = 0u32;
+        let mut total_tokens = 0u32;
+
+        let mut x_compute_time = 0u32;
+        let mut x_total_time = 0u32;
+        let mut x_validation_time = 0u32;
+        let mut x_queue_time = 0u32;
+        let mut x_inference_time = 0u32;
+        let mut x_time_per_token = 0u32;
+        let mut x_prompt_tokens = 0u32;
+        let mut x_generated_tokens = 0u32;
+
+        let choices = generate_responses
+            .into_iter()
+            .map(|(index, headers, Json(generation))| {
+                let details = generation.details.unwrap_or_default();
+                if index == 0 {
+                    x_compute_type = headers
+                        .get("x-compute-type")
+                        .and_then(|v| v.to_str().ok())
+                        .unwrap_or("unknown")
+                        .to_string();
+                }
+
+                // accumulate headers and usage from each response
+                x_compute_time += headers
+                    .get("x-compute-time")
+                    .and_then(|v| v.to_str().ok()?.parse().ok())
+                    .unwrap_or(0);
+                x_compute_characters += headers
+                    .get("x-compute-characters")
+                    .and_then(|v| v.to_str().ok()?.parse().ok())
+                    .unwrap_or(0);
+                x_total_time += headers
+                    .get("x-total-time")
+                    .and_then(|v| v.to_str().ok()?.parse().ok())
+                    .unwrap_or(0);
+                x_validation_time += headers
+                    .get("x-validation-time")
+                    .and_then(|v| v.to_str().ok()?.parse().ok())
+                    .unwrap_or(0);
+                x_queue_time += headers
+                    .get("x-queue-time")
+                    .and_then(|v| v.to_str().ok()?.parse().ok())
+                    .unwrap_or(0);
+                x_inference_time += headers
+                    .get("x-inference-time")
+                    .and_then(|v| v.to_str().ok()?.parse().ok())
+                    .unwrap_or(0);
+                x_time_per_token += headers
+                    .get("x-time-per-token")
+                    .and_then(|v| v.to_str().ok()?.parse().ok())
+                    .unwrap_or(0);
+                x_prompt_tokens += headers
+                    .get("x-prompt-tokens")
+                    .and_then(|v| v.to_str().ok()?.parse().ok())
+                    .unwrap_or(0);
+                x_generated_tokens += headers
+                    .get("x-generated-tokens")
+                    .and_then(|v| v.to_str().ok()?.parse().ok())
+                    .unwrap_or(0);
+
+                prompt_tokens += details.prefill.len() as u32;
+                completion_tokens += details.generated_tokens;
+                total_tokens += details.prefill.len() as u32 + details.generated_tokens;
+
+                CompletionComplete {
+                    finish_reason: details.finish_reason.to_string(),
+                    index: index as u32,
+                    logprobs: None,
+                    text: generation.generated_text,
+                }
+            })
+            .collect::<Vec<_>>();
+
+        let response = Completion {
+            id: "".to_string(),
+            object: "text_completion".to_string(),
+            created: current_time,
+            model: info.model_id.clone(),
+            system_fingerprint: format!(
+                "{}-{}",
+                info.version,
+                info.docker_label.unwrap_or("native")
+            ),
+            choices,
+            usage: Usage {
+                prompt_tokens,
+                completion_tokens,
+                total_tokens,
+            },
+        };
+
+        // headers similar to `generate` but aggregated
+        let mut headers = HeaderMap::new();
+        headers.insert("x-compute-type", x_compute_type.parse().unwrap());
+        headers.insert("x-compute-characters", x_compute_characters.into());
+        headers.insert("x-total-time", x_total_time.into());
+        headers.insert("x-validation-time", x_validation_time.into());
+        headers.insert("x-queue-time", x_queue_time.into());
+        headers.insert("x-inference-time", x_inference_time.into());
+        headers.insert("x-time-per-token", x_time_per_token.into());
+        headers.insert("x-prompt-tokens", x_prompt_tokens.into());
+        headers.insert("x-generated-tokens", x_generated_tokens.into());
+        headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
+
+        Ok((headers, Json(response)).into_response())
+    }
-
-    let current_time = std::time::SystemTime::now()
-        .duration_since(std::time::UNIX_EPOCH)
-        .unwrap_or_else(|_| std::time::Duration::from_secs(0))
-        .as_secs();
-
-    let responses = FuturesUnordered::new();
-    for generate_request in generate_requests.into_iter() {
-        responses.push(generate(
-            Extension(infer.clone()),
-            Extension(compute_type.clone()),
-            Json(generate_request),
-        ));
-    }
-
-    let generate_responses = responses.try_collect::<Vec<_>>().await?;
-
-    let mut prompt_tokens = 0u32;
-    let mut completion_tokens = 0u32;
-    let mut total_tokens = 0u32;
-
-    let mut x_compute_time = 0u32;
-    let mut x_total_time = 0u32;
-    let mut x_validation_time = 0u32;
-    let mut x_queue_time = 0u32;
-    let mut x_inference_time = 0u32;
-    let mut x_time_per_token = 0u32;
-    let mut x_prompt_tokens = 0u32;
-    let mut x_generated_tokens = 0u32;
-
-    let choices = generate_responses
-        .into_iter()
-        .enumerate()
-        .map(|(index, (headers, Json(generation)))| {
-            let details = generation.details.unwrap_or_default();
-            if index == 0 {
-                x_compute_type = headers
-                    .get("x-compute-type")
-                    .and_then(|v| v.to_str().ok())
-                    .unwrap_or("unknown")
-                    .to_string();
-            }
-
-            // accumulate headers and usage from each response
-            x_compute_time += headers
-                .get("x-compute-time")
-                .and_then(|v| v.to_str().ok()?.parse().ok())
-                .unwrap_or(0);
-            x_compute_characters += headers
-                .get("x-compute-characters")
-                .and_then(|v| v.to_str().ok()?.parse().ok())
-                .unwrap_or(0);
-            x_total_time += headers
-                .get("x-total-time")
-                .and_then(|v| v.to_str().ok()?.parse().ok())
-                .unwrap_or(0);
-            x_validation_time += headers
-                .get("x-validation-time")
-                .and_then(|v| v.to_str().ok()?.parse().ok())
-                .unwrap_or(0);
-            x_queue_time += headers
-                .get("x-queue-time")
-                .and_then(|v| v.to_str().ok()?.parse().ok())
-                .unwrap_or(0);
-            x_inference_time += headers
-                .get("x-inference-time")
-                .and_then(|v| v.to_str().ok()?.parse().ok())
-                .unwrap_or(0);
-            x_time_per_token += headers
-                .get("x-time-per-token")
-                .and_then(|v| v.to_str().ok()?.parse().ok())
-                .unwrap_or(0);
-            x_prompt_tokens += headers
-                .get("x-prompt-tokens")
-                .and_then(|v| v.to_str().ok()?.parse().ok())
-                .unwrap_or(0);
-            x_generated_tokens += headers
-                .get("x-generated-tokens")
-                .and_then(|v| v.to_str().ok()?.parse().ok())
-                .unwrap_or(0);
-
-            prompt_tokens += details.prefill.len() as u32;
-            completion_tokens += details.generated_tokens;
-            total_tokens += details.prefill.len() as u32 + details.generated_tokens;
-
-            CompletionComplete {
-                finish_reason: details.finish_reason.to_string(),
-                index: index as u32,
-                logprobs: None,
-                text: generation.generated_text,
-            }
-        })
-        .collect::<Vec<_>>();
-
-    let response = Completion {
-        id: "".to_string(),
-        object: "text_completion".to_string(),
-        created: current_time,
-        model: info.model_id.clone(),
-        system_fingerprint: format!("{}-{}", info.version, info.docker_label.unwrap_or("native")),
-        choices,
-        usage: Usage {
-            prompt_tokens,
-            completion_tokens,
-            total_tokens,
-        },
-    };
-
-    // headers similar to `generate` but aggregated
-    let mut headers = HeaderMap::new();
-    headers.insert("x-compute-type", x_compute_type.parse().unwrap());
-    headers.insert("x-compute-characters", x_compute_characters.into());
-    headers.insert("x-total-time", x_total_time.into());
-    headers.insert("x-validation-time", x_validation_time.into());
-    headers.insert("x-queue-time", x_queue_time.into());
-    headers.insert("x-inference-time", x_inference_time.into());
-    headers.insert("x-time-per-token", x_time_per_token.into());
-    headers.insert("x-prompt-tokens", x_prompt_tokens.into());
-    headers.insert("x-generated-tokens", x_generated_tokens.into());
-    headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
-
-    Ok((headers, Json(response)).into_response())
 }
 
 /// Generate tokens
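
Reviewer note (not part of the patch): the non-streaming path above fans the
per-prompt requests out through `FuturesUnordered`, which yields results in
completion order rather than submission order; tagging each future with its
`index` is what lets every `CompletionComplete` report which prompt it answers.
The sketch below isolates that pattern. It is illustrative only: `fan_out` and
the echo body are hypothetical stand-ins for the router's `generate` call, not
code from this patch.

    use futures::stream::{FuturesUnordered, TryStreamExt};

    // Hypothetical stand-in for the handler call: each future carries the
    // index it was spawned with, so out-of-order completions can still be
    // matched back to their prompt.
    async fn fan_out(prompts: Vec<String>) -> Result<Vec<(usize, String)>, std::io::Error> {
        let responses = FuturesUnordered::new();
        for (index, prompt) in prompts.into_iter().enumerate() {
            responses.push(async move {
                // Stand-in work; the real code awaits `generate(...)` here.
                Ok::<_, std::io::Error>((index, format!("echo: {prompt}")))
            });
        }
        // `try_collect` short-circuits on the first error, mirroring the `?`
        // after `try_collect::<Vec<_>>()` in the patch; the collected order
        // is completion order, hence the explicit index in each tuple.
        responses.try_collect::<Vec<_>>().await
    }

With the `router/src/lib.rs` change, a request body of `{"prompt": []}` is now
rejected at deserialization with "Empty array detected. Do not use an empty
array for the prompt." instead of silently producing zero generate requests.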