From 57606c447b1ee72156b3bffd5531258b71b4fabe Mon Sep 17 00:00:00 2001
From: drbh
Date: Wed, 10 Apr 2024 04:07:03 +0000
Subject: [PATCH] feat: handle batch completions requests

---
 docs/source/basic_tutorials/launcher.md       |   9 +
 .../models/test_completion_prompts.py         |  56 +++
 launcher/src/main.rs                          |   6 +
 router/src/lib.rs                             |  44 ++-
 router/src/main.rs                            |   4 +
 router/src/server.rs                          | 372 ++++++++++++------
 6 files changed, 346 insertions(+), 145 deletions(-)
 create mode 100644 integration-tests/models/test_completion_prompts.py

diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md
index d9b272db..687257bb 100644
--- a/docs/source/basic_tutorials/launcher.md
+++ b/docs/source/basic_tutorials/launcher.md
@@ -398,6 +398,15 @@ Options:
   -e, --env
           Display a lot of information about your runtime environment
 
+```
+## MAX_CLIENT_BATCH_SIZE
+```shell
+      --max-client-batch-size <MAX_CLIENT_BATCH_SIZE>
+          Control the maximum number of inputs that a client can send in a single request
+
+          [env: MAX_CLIENT_BATCH_SIZE=]
+          [default: 32]
+
 ```
 ## HELP
 ```shell
diff --git a/integration-tests/models/test_completion_prompts.py b/integration-tests/models/test_completion_prompts.py
new file mode 100644
index 00000000..d14c98f0
--- /dev/null
+++ b/integration-tests/models/test_completion_prompts.py
@@ -0,0 +1,56 @@
+import pytest
+import requests
+
+
+@pytest.fixture(scope="module")
+def flash_llama_completion_handle(launcher):
+    with launcher(
+        "TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
+    ) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def flash_llama_completion(flash_llama_completion_handle):
+    await flash_llama_completion_handle.health(300)
+    return flash_llama_completion_handle.client
+
+
+# NOTE: since `v1/completions` is a deprecated interface/endpoint we do not provide a convenience
+# method for it. Instead, we use the `requests` library to make the HTTP request directly.
+
+
+def test_flash_llama_completion_single_prompt(flash_llama_completion, response_snapshot):
+    response = requests.post(
+        f"{flash_llama_completion.base_url}/v1/completions",
+        json={
+            "model": "tgi",
+            "prompt": "Say this is a test",
+            "max_tokens": 5,
+            "seed": 0,
+        },
+        headers=flash_llama_completion.headers,
+        stream=False,
+    )
+    response = response.json()
+    assert len(response["choices"]) == 1
+
+
+def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
+    response = requests.post(
+        f"{flash_llama_completion.base_url}/v1/completions",
+        json={
+            "model": "tgi",
+            "prompt": ["Say", "this", "is", "a", "test"],
+            "max_tokens": 5,
+            "seed": 0,
+        },
+        headers=flash_llama_completion.headers,
+        stream=False,
+    )
+    response = response.json()
+    assert len(response["choices"]) == 5
+
+    all_indexes = [choice["index"] for choice in response["choices"]]
+    all_indexes.sort()
+    assert all_indexes == [0, 1, 2, 3, 4]
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index be2426ee..5cbd9387 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -414,6 +414,10 @@ struct Args {
     /// Display a lot of information about your runtime environment
     #[clap(long, short, action)]
     env: bool,
+
+    /// Control the maximum number of inputs that a client can send in a single request
+    #[clap(default_value = "32", long, env)]
+    max_client_batch_size: usize,
 }
 
 #[derive(Debug)]
@@ -1078,6 +1082,8 @@ fn spawn_webserver(
     // Start webserver
     tracing::info!("Starting Webserver");
     let mut router_args = vec![
+        "--max-client-batch-size".to_string(),
+        args.max_client_batch_size.to_string(),
         "--max-concurrent-requests".to_string(),
         args.max_concurrent_requests.to_string(),
         "--max-best-of".to_string(),
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 632b2cac..2972e534 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -155,6 +155,8 @@ pub struct Info {
     pub max_batch_size: Option<usize>,
     #[schema(example = "2")]
     pub validation_workers: usize,
+    #[schema(example = "32")]
+    pub max_client_batch_size: usize,
     /// Router Info
     #[schema(example = "0.5.0")]
     pub version: &'static str,
@@ -284,28 +286,20 @@ mod prompt_serde {
     use serde::{self, Deserialize, Deserializer};
     use serde_json::Value;
 
-    pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
+    pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
     where
         D: Deserializer<'de>,
    {
         let value = Value::deserialize(deserializer)?;
         match value {
-            Value::String(s) => Ok(s),
-            Value::Array(arr) => {
-                if arr.len() == 1 {
-                    match arr[0].as_str() {
-                        Some(s) => Ok(s.to_string()),
-                        None => Err(serde::de::Error::custom(
-                            "Array contains non-string elements",
-                        )),
-                    }
-                } else {
-                    Err(serde::de::Error::custom(
-                        "Array contains non-string element. Expected string. In general arrays should not be used for prompts. Please use a string instead if possible.",
-                    ))
-                }
-            }
-
+            Value::String(s) => Ok(vec![s]),
+            Value::Array(arr) => arr
+                .iter()
+                .map(|v| match v {
+                    Value::String(s) => Ok(s.to_owned()),
+                    _ => Err(serde::de::Error::custom("Expected a string")),
+                })
+                .collect(),
             _ => Err(serde::de::Error::custom(
                 "Expected a string or an array of strings",
             )),
@@ -323,7 +317,7 @@ pub struct CompletionRequest {
     /// The prompt to generate completions for.
     #[schema(example = "What is Deep Learning?")]
     #[serde(deserialize_with = "prompt_serde::deserialize")]
-    pub prompt: String,
+    pub prompt: Vec<String>,
 
     /// The maximum number of tokens that can be generated in the chat completion.
     #[serde(default)]
@@ -962,6 +956,20 @@ pub(crate) struct Details {
     pub top_tokens: Vec<Vec<Token>>,
 }
 
+impl Default for Details {
+    fn default() -> Self {
+        Self {
+            finish_reason: FinishReason::Length,
+            generated_tokens: 0,
+            seed: None,
+            prefill: Vec::new(),
+            tokens: Vec::new(),
+            best_of_sequences: None,
+            top_tokens: Vec::new(),
+        }
+    }
+}
+
 #[derive(Serialize, ToSchema)]
 pub(crate) struct GenerateResponse {
     #[schema(example = "test")]
diff --git a/router/src/main.rs b/router/src/main.rs
index f3a6c46f..6209f47f 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -78,6 +78,8 @@ struct Args {
     messages_api_enabled: bool,
     #[clap(long, env, default_value_t = false)]
     disable_grammar_support: bool,
+    #[clap(default_value = "32", long, env)]
+    max_client_batch_size: usize,
 }
 
 #[tokio::main]
@@ -112,6 +114,7 @@ async fn main() -> Result<(), RouterError> {
         ngrok_edge,
         messages_api_enabled,
         disable_grammar_support,
+        max_client_batch_size,
     } = args;
 
     // Launch Tokio runtime
@@ -393,6 +396,7 @@ async fn main() -> Result<(), RouterError> {
         tokenizer_config,
         messages_api_enabled,
         disable_grammar_support,
+        max_client_batch_size,
     )
     .await?;
     Ok(())
diff --git a/router/src/server.rs b/router/src/server.rs
index 3f033a9d..51851dfd 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -548,11 +548,7 @@ async fn generate_stream_internal(
 path = "/v1/completions",
 request_body = CompletionRequest,
 responses(
-(status = 200, description = "Generated Chat Completion",
-    content(
-        ("application/json" = Completion),
-        ("text/event-stream" = CompletionCompleteChunk),
-    )),
+(status = 200, description = "Generated Text", body = ChatCompletionChunk),
 (status = 424, description = "Generation Error", body = ErrorResponse,
     example = json ! ({"error": "Request failed during generation"})),
 (status = 429, description = "Model is overloaded", body = ErrorResponse,
@@ -600,126 +596,250 @@ async fn completions(
         ));
     }
 
-    // build the request passing some parameters
-    let generate_request = GenerateRequest {
-        inputs: req.prompt.to_string(),
-        parameters: GenerateParameters {
-            best_of: None,
-            temperature: req.temperature,
-            repetition_penalty: req.repetition_penalty,
-            frequency_penalty: req.frequency_penalty,
-            top_k: None,
-            top_p: req.top_p,
-            typical_p: None,
-            do_sample: true,
-            max_new_tokens,
-            return_full_text: None,
-            stop: Vec::new(),
-            truncate: None,
-            watermark: false,
-            details: true,
-            decoder_input_details: !stream,
-            seed,
-            top_n_tokens: None,
-            grammar: None,
+    if req.prompt.len() > info.max_client_batch_size {
+        metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+        return Err((
+            StatusCode::UNPROCESSABLE_ENTITY,
+            Json(ErrorResponse {
+                error: format!(
+                    "Number of prompts exceeds the maximum allowed batch size of {}",
+                    info.max_client_batch_size
+                ),
+                error_type: "batch size exceeded".to_string(),
+            }),
+        ));
+    }
+
+    let mut generate_requests = Vec::new();
+    for prompt in req.prompt.iter() {
+        // build the request passing some parameters
+        let generate_request = GenerateRequest {
+            inputs: prompt.to_string(),
+            parameters: GenerateParameters {
+                best_of: None,
+                temperature: req.temperature,
+                repetition_penalty: req.repetition_penalty,
+                frequency_penalty: req.frequency_penalty,
+                top_k: None,
+                top_p: req.top_p,
+                typical_p: None,
+                do_sample: true,
+                max_new_tokens,
+                return_full_text: None,
+                stop: Vec::new(),
+                truncate: None,
+                watermark: false,
+                details: true,
+                decoder_input_details: !stream,
+                seed,
+                top_n_tokens: None,
+                grammar: None,
+            },
+        };
+
+        generate_requests.push(generate_request);
+    }
+
+    if stream {
+        let mut response_streams = Vec::new();
+        for (index, generate_request) in generate_requests.into_iter().enumerate() {
+            let model_id = info.model_id.clone();
+            let system_fingerprint =
+                format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
+            let on_message_callback = move |stream_token: StreamResponse| {
+                let event = Event::default();
+
+                let current_time = std::time::SystemTime::now()
+                    .duration_since(std::time::UNIX_EPOCH)
+                    .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+                    .as_secs();
+
+                event
+                    .json_data(CompletionCompleteChunk {
+                        id: "".to_string(),
+                        object: "text_completion".to_string(),
+                        created: current_time,
+
+                        choices: vec![CompletionComplete {
+                            finish_reason: "".to_string(),
+                            index: index as u32,
+                            logprobs: None,
+                            text: stream_token.token.text,
+                        }],
+
+                        model: model_id.clone(),
+                        system_fingerprint: system_fingerprint.clone(),
+                    })
+                    .map_or_else(
+                        |e| {
+                            println!("Failed to serialize CompletionCompleteChunk: {:?}", e);
+                            Event::default()
+                        },
+                        |data| data,
+                    )
+            };
+
+            let (headers, response_stream) = generate_stream_internal(
+                infer.clone(),
+                compute_type.clone(),
+                Json(generate_request),
+                on_message_callback,
+            )
+            .await;
+
+            response_streams.push((headers, Box::pin(response_stream)));
+        }
+
+        let stream = async_stream::stream! {
+            for response_stream in response_streams {
+                let (_headers, mut inner_stream) = response_stream;
+                while let Some(event) = inner_stream.next().await {
+                    yield event;
+                }
+            }
+        };
+
+        let sse = Sse::new(stream).keep_alive(KeepAlive::default());
+        return Ok((HeaderMap::new(), sse).into_response());
+    }
+
+    let current_time = std::time::SystemTime::now()
+        .duration_since(std::time::UNIX_EPOCH)
+        .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+        .as_secs();
+
+    let mut responses = FuturesOrdered::new();
+    for generate_request in generate_requests.into_iter() {
+        responses.push_back(generate(
+            Extension(infer.clone()),
+            Extension(compute_type.clone()),
+            Json(generate_request),
+        ));
+    }
+
+    let generate_responses = responses.try_collect::<Vec<_>>().await?;
+
+    let mut prompt_tokens = 0u32;
+    let mut completion_tokens = 0u32;
+    let mut total_tokens = 0u32;
+
+    let mut headers = HeaderMap::new();
+
+    let mut x_compute_type: Option<String> = None;
+    let mut x_compute_time = 0u32;
+    let mut x_compute_characters = 0u32;
+    let mut x_total_time = 0u32;
+    let mut x_validation_time = 0u32;
+    let mut x_queue_time = 0u32;
+    let mut x_inference_time = 0u32;
+    let mut x_time_per_token = 0u32;
+    let mut x_prompt_tokens = 0u32;
+    let mut x_generated_tokens = 0u32;
+
+    // helper closure to extract a header value or default to 0
+    let extract_or_zero = |headers: &HeaderMap, key: &str| {
+        headers
+            .get(key)
+            .and_then(|v| v.to_str().ok())
+            .unwrap_or("0")
+            .parse::<u32>()
+            .unwrap_or(0)
+    };
+
+    let choices = generate_responses
+        .into_iter()
+        .enumerate()
+        .map(|(index, (headers, Json(generation)))| {
+            let details = generation.details.unwrap_or_default();
+            if x_compute_type.is_none() {
+                x_compute_type = Some(
+                    headers
+                        .get("x-compute-type")
+                        .unwrap()
+                        .to_str()
+                        .unwrap()
+                        .to_string(),
+                );
+            }
+
+            // update headers
+            x_compute_time += extract_or_zero(&headers, "x-compute-time");
+            x_compute_characters += extract_or_zero(&headers, "x-compute-characters");
+            x_total_time += extract_or_zero(&headers, "x-total-time");
+            x_validation_time += extract_or_zero(&headers, "x-validation-time");
+            x_queue_time += extract_or_zero(&headers, "x-queue-time");
+            x_inference_time += extract_or_zero(&headers, "x-inference-time");
+            x_time_per_token += extract_or_zero(&headers, "x-time-per-token");
+            x_prompt_tokens += extract_or_zero(&headers, "x-prompt-tokens");
+            x_generated_tokens += extract_or_zero(&headers, "x-generated-tokens");
+
+            // update usage
+            prompt_tokens += details.prefill.len() as u32;
+            completion_tokens += details.generated_tokens;
+            total_tokens += details.prefill.len() as u32 + details.generated_tokens;
+
+            CompletionComplete {
+                finish_reason: details.finish_reason.to_string(),
+                index: index as u32,
+                logprobs: None,
+                text: generation.generated_text,
+            }
+        })
+        .collect::<Vec<_>>();
+
+    // Headers similar to `generate` but aggregated
+    headers.insert(
+        "x-compute-type",
+        x_compute_type
+            .unwrap_or("unknown".to_string())
+            .parse()
+            .unwrap(),
+    );
+    headers.insert(
+        "x-compute-time",
+        x_compute_time.to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-compute-characters",
+        x_compute_characters.to_string().parse().unwrap(),
+    );
+    headers.insert("x-total-time", x_total_time.to_string().parse().unwrap());
+    headers.insert(
+        "x-validation-time",
+        x_validation_time.to_string().parse().unwrap(),
+    );
+    headers.insert("x-queue-time", x_queue_time.to_string().parse().unwrap());
+    headers.insert(
+        "x-inference-time",
+        x_inference_time.to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-time-per-token",
+        x_time_per_token.to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-prompt-tokens",
+        x_prompt_tokens.to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-generated-tokens",
+        x_generated_tokens.to_string().parse().unwrap(),
+    );
+
+    let response = Completion {
+        id: "".to_string(),
+        object: "text_completion".to_string(),
+        created: current_time,
+        model: info.model_id.clone(),
+        system_fingerprint: format!("{}-{}", info.version, info.docker_label.unwrap_or("native")),
+        choices,
+        usage: Usage {
+            prompt_tokens,
+            completion_tokens,
+            total_tokens,
         },
     };
 
-    if stream {
-        let on_message_callback = move |stream_token: StreamResponse| {
-            let event = Event::default();
-
-            let current_time = std::time::SystemTime::now()
-                .duration_since(std::time::UNIX_EPOCH)
-                .unwrap_or_else(|_| std::time::Duration::from_secs(0))
-                .as_secs();
-
-            event
-                .json_data(CompletionCompleteChunk {
-                    id: "".to_string(),
-                    object: "text_completion".to_string(),
-                    created: current_time,
-
-                    choices: vec![CompletionComplete {
-                        finish_reason: "".to_string(),
-                        index: 0,
-                        logprobs: None,
-                        text: stream_token.token.text,
-                    }],
-
-                    model: info.model_id.clone(),
-                    system_fingerprint: format!(
-                        "{}-{}",
-                        info.version,
-                        info.docker_label.unwrap_or("native")
-                    ),
-                })
-                .map_or_else(
-                    |e| {
-                        println!("Failed to serialize CompletionCompleteChunk: {:?}", e);
-                        Event::default()
-                    },
-                    |data| data,
-                )
-        };
-
-        let (headers, response_stream) = generate_stream_internal(
-            infer,
-            compute_type,
-            Json(generate_request),
-            on_message_callback,
-        )
-        .await;
-
-        let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
-        Ok((headers, sse).into_response())
-    } else {
-        let (headers, Json(generation)) = generate(
-            Extension(infer),
-            Extension(compute_type),
-            Json(generate_request),
-        )
-        .await?;
-
-        let current_time = std::time::SystemTime::now()
-            .duration_since(std::time::UNIX_EPOCH)
-            .unwrap_or_else(|_| std::time::Duration::from_secs(0))
-            .as_secs();
-
-        let details = generation.details.ok_or((
-            // this should never happen but handle if details are missing unexpectedly
-            StatusCode::INTERNAL_SERVER_ERROR,
-            Json(ErrorResponse {
-                error: "No details in generation".to_string(),
-                error_type: "no details".to_string(),
-            }),
-        ))?;
-
-        let response = Completion {
-            id: "".to_string(),
-            object: "text_completion".to_string(),
-            created: current_time,
-            model: info.model_id.clone(),
-            system_fingerprint: format!(
-                "{}-{}",
-                info.version,
-                info.docker_label.unwrap_or("native")
-            ),
-            choices: vec![CompletionComplete {
-                finish_reason: details.finish_reason.to_string(),
-                index: 0,
-                logprobs: None,
-                text: generation.generated_text,
-            }],
-            usage: Usage {
-                prompt_tokens: details.prefill.len() as u32,
-                completion_tokens: details.generated_tokens,
-                total_tokens: details.prefill.len() as u32 + details.generated_tokens,
-            },
-        };
-
-        Ok((headers, Json(response)).into_response())
-    }
+    Ok((headers, Json(response)).into_response())
 }
 
 /// Generate tokens
@@ -729,11 +849,7 @@ async fn completions(
 path = "/v1/chat/completions",
 request_body = ChatRequest,
 responses(
-(status = 200, description = "Generated Chat Completion",
-    content(
-        ("application/json" = ChatCompletion),
-        ("text/event-stream" = ChatCompletionChunk),
-    )),
+(status = 200, description = "Generated Text", body = ChatCompletionChunk),
 (status = 424, description = "Generation Error", body = ErrorResponse,
     example = json ! ({"error": "Request failed during generation"})),
 (status = 429, description = "Model is overloaded", body = ErrorResponse,
@@ -1152,6 +1268,7 @@ pub async fn run(
     tokenizer_config: HubTokenizerConfig,
     messages_api_enabled: bool,
     grammar_support: bool,
+    max_client_batch_size: usize,
 ) -> Result<(), axum::BoxError> {
     // OpenAPI documentation
     #[derive(OpenApi)]
@@ -1326,6 +1443,7 @@ pub async fn run(
         max_waiting_tokens,
         max_batch_size,
         validation_workers,
+        max_client_batch_size,
         version: env!("CARGO_PKG_VERSION"),
         sha: option_env!("VERGEN_GIT_SHA"),
        docker_label: option_env!("DOCKER_LABEL"),
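
Usage note: with this patch, `prompt` accepts either a single string or a JSON array of strings, and arrays longer than `--max-client-batch-size` (default 32) are rejected with HTTP 422. A minimal sketch of the two accepted request shapes, mirroring the integration tests above (the base URL is an assumption, not part of the patch; adjust it for your deployment):

```python
import requests

# Assumed local router address; not part of the patch.
BASE_URL = "http://127.0.0.1:3000"

# A single string prompt yields exactly one choice, with index 0.
single = requests.post(
    f"{BASE_URL}/v1/completions",
    json={"model": "tgi", "prompt": "Say this is a test", "max_tokens": 5},
)
assert len(single.json()["choices"]) == 1

# An array of prompts yields one choice per prompt, indexed 0..n-1.
batched = requests.post(
    f"{BASE_URL}/v1/completions",
    json={"model": "tgi", "prompt": ["Say", "this", "is", "a", "test"], "max_tokens": 5},
)
assert len(batched.json()["choices"]) == 5
```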
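When `stream` is true, the chunks for all prompts arrive on a single SSE connection and each `CompletionCompleteChunk` identifies its originating prompt via `choices[0].index`. A sketch of demultiplexing that stream by index; it assumes the standard SSE `data:` framing that the router's `Sse` response produces, and the `[DONE]` guard is a defensive assumption rather than something this patch shows:

```python
import json

import requests

BASE_URL = "http://127.0.0.1:3000"  # assumed local router address

response = requests.post(
    f"{BASE_URL}/v1/completions",
    json={
        "model": "tgi",
        "prompt": ["Say", "this", "is", "a", "test"],
        "max_tokens": 5,
        "stream": True,
    },
    stream=True,
)

texts = {}
for line in response.iter_lines():
    if not line.startswith(b"data:"):
        continue  # skip blank separator lines and SSE keep-alives
    payload = line[len(b"data:"):].strip()
    if payload == b"[DONE]":  # terminal marker, if the server emits one
        break
    chunk = json.loads(payload)
    choice = chunk["choices"][0]
    # chunks from different prompts share one stream; group them by index
    texts.setdefault(choice["index"], []).append(choice["text"])

for index in sorted(texts):
    print(index, "".join(texts[index]))
```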