From 8c3669b287a1c651cb07049e67f1ce5967828167 Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Fri, 6 Dec 2024 05:50:35 +0100
Subject: [PATCH] feat: auto max_new_tokens (#2803)

* feat: auto max_new_tokens

* update default

* Fixing the tests.

---------

Co-authored-by: Nicolas Patry
---
 backends/v2/src/queue.rs |  1 +
 backends/v3/src/queue.rs |  1 +
 docs/openapi.json        |  6 ++--
 router/src/infer/mod.rs  | 62 ++++++++++++++++++++++++++++++++++++++--
 router/src/lib.rs        | 18 +++++-------
 router/src/server.rs     |  2 +-
 router/src/validation.rs | 33 ++++++++++++++++-----
 7 files changed, 100 insertions(+), 23 deletions(-)

diff --git a/backends/v2/src/queue.rs b/backends/v2/src/queue.rs
index bf52900f4..61a3eebc9 100644
--- a/backends/v2/src/queue.rs
+++ b/backends/v2/src/queue.rs
@@ -436,6 +436,7 @@ mod tests {
             stopping_parameters: ValidStoppingParameters {
                 ignore_eos_token: false,
                 max_new_tokens: 1,
+                max_total_new_tokens: 1024,
                 stop_sequences: vec![],
             },
             top_n_tokens: 0,
diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs
index 6662b8de1..dd27806f9 100644
--- a/backends/v3/src/queue.rs
+++ b/backends/v3/src/queue.rs
@@ -573,6 +573,7 @@ mod tests {
             stopping_parameters: ValidStoppingParameters {
                 ignore_eos_token: false,
                 max_new_tokens: 1,
+                max_total_new_tokens: 1024,
                 stop_sequences: vec![],
             },
             top_n_tokens: 0,
diff --git a/docs/openapi.json b/docs/openapi.json
index 44691e4bb..f552ee08e 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -1013,6 +1013,7 @@
           "type": "integer",
           "format": "int32",
           "description": "The maximum number of tokens that can be generated in the chat completion.",
+          "default": "1024",
           "example": "32",
           "nullable": true,
           "minimum": 0
         },
@@ -1329,7 +1330,8 @@
           "type": "integer",
           "format": "int32",
           "description": "The maximum number of tokens that can be generated in the chat completion.",
-          "default": "32",
+          "default": "1024",
+          "example": "32",
           "nullable": true,
           "minimum": 0
         },
@@ -1591,7 +1593,7 @@
           "type": "integer",
           "format": "int32",
           "description": "Maximum number of tokens to generate.",
-          "default": "100",
+          "default": "1024",
           "example": "20",
           "nullable": true,
           "minimum": 0
diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs
index 1351b87e2..866066438 100644
--- a/router/src/infer/mod.rs
+++ b/router/src/infer/mod.rs
@@ -111,21 +111,79 @@ impl Infer {
         })?;
 
         // Validate request
+        let mut local_request = request.clone();
         let valid_request = self.validation.validate(request).await.map_err(|err| {
             metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
             tracing::error!("{err}");
             err
         })?;
 
+        let seed = valid_request.parameters.seed;
+        local_request.parameters.seed = Some(seed);
         let input_length = valid_request.input_length;
+        let max_total_new_tokens = valid_request.stopping_parameters.max_total_new_tokens;
         let mut generation_stream = self.backend.schedule(valid_request)?;
 
         // Wrap generation stream to update the backend health if the stream contains an error
         let final_stream = stream! {
+            let mut total_generated_tokens = 0;
+            let mut first_start = None;
+            let mut first_queued = None;
+            let mut all_generated_text: Option<GeneratedText> = None;
+
             while let Some(response) = generation_stream.next().await {
-                yield response.inspect_err(|_err| {
+                let response = response.inspect_err(|_err| {
                     self.backend_health.store(false, Ordering::SeqCst);
-                })
+                })?;
+
+                match response {
+                    InferStreamResponse::Prefill(_) => yield Ok(response),
+                    InferStreamResponse::Intermediate { .. } => {
+                        total_generated_tokens += 1;
+                        yield Ok(response);
+                    }
+                    InferStreamResponse::End { token, top_tokens, generated_text, start, queued } => {
+                        total_generated_tokens += 1;
+                        first_start = first_start.or(Some(start));
+                        first_queued = first_queued.or(Some(queued));
+                        if let Some(v) = all_generated_text.as_mut() {
+                            v.text.push_str(&generated_text.text);
+                            v.generated_tokens = total_generated_tokens;
+                            v.finish_reason = generated_text.finish_reason.clone();
+                        };
+
+                        if matches!(generated_text.finish_reason, FinishReason::Length) && total_generated_tokens < max_total_new_tokens {
+                            local_request.inputs.push_str(&generated_text.text);
+                            all_generated_text = all_generated_text.or(Some(generated_text));
+
+                            let valid_request = match self.validation.validate(local_request.clone()).await {
+                                Ok(valid_request) => valid_request,
+                                Err(err) => {
+                                    tracing::debug!("Failed to continue request: {err}");
+                                    yield Ok(InferStreamResponse::End { token, top_tokens, generated_text: all_generated_text.unwrap(), start: first_start.unwrap(), queued: first_queued.unwrap() });
+                                    break;
+                                }
+                            };
+
+                            generation_stream = match self.backend.schedule(valid_request) {
+                                Ok(stream) => {
+                                    tracing::debug!("Continue request");
+                                    yield Ok(InferStreamResponse::Intermediate { token, top_tokens });
+                                    stream
+                                },
+                                Err(err) => {
+                                    tracing::debug!("Failed to continue request: {err}");
+                                    yield Ok(InferStreamResponse::End { token, top_tokens, generated_text: all_generated_text.unwrap(), start: first_start.unwrap(), queued: first_queued.unwrap() });
+                                    break;
+                                }
+                            }
+                        } else {
+                            yield Ok(InferStreamResponse::End { token, top_tokens, generated_text: all_generated_text.unwrap_or(generated_text), start: first_start.unwrap(), queued: first_queued.unwrap() });
+                            break;
+                        }
+
+                    }
+                }
             }
         };
diff --git a/router/src/lib.rs b/router/src/lib.rs
index ea697c3a1..40076564a 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -332,8 +332,8 @@ pub(crate) struct GenerateParameters {
     pub do_sample: bool,
 
     /// Maximum number of tokens to generate.
-    #[serde(default = "default_max_new_tokens")]
-    #[schema(nullable = true, default = "100", example = "20")]
+    #[serde(default)]
+    #[schema(nullable = true, default = "1024", example = "20")]
     pub max_new_tokens: Option<u32>,
 
     /// Whether to prepend the prompt to the generated text
@@ -392,10 +392,6 @@ pub(crate) struct GenerateParameters {
     pub adapter_id: Option<String>,
 }
 
-fn default_max_new_tokens() -> Option<u32> {
-    Some(100)
-}
-
 fn default_parameters() -> GenerateParameters {
     GenerateParameters {
         best_of: None,
@@ -406,7 +402,7 @@
         top_p: None,
         typical_p: None,
         do_sample: true,
-        max_new_tokens: default_max_new_tokens(),
+        max_new_tokens: None,
         return_full_text: None,
         stop: Vec::new(),
         truncate: None,
@@ -464,7 +460,7 @@ pub struct CompletionRequest {
 
     /// The maximum number of tokens that can be generated in the chat completion.
     #[serde(default)]
-    #[schema(default = "32")]
+    #[schema(default = "1024", example = "32")]
     pub max_tokens: Option<u32>,
 
     /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while
@@ -842,7 +838,7 @@ pub(crate) struct ChatRequest {
 
     /// The maximum number of tokens that can be generated in the chat completion.
     #[serde(default)]
-    #[schema(example = "32")]
+    #[schema(default = "1024", example = "32")]
     pub max_tokens: Option<u32>,
 
     /// UNUSED
@@ -937,7 +933,7 @@ impl ChatRequest {
         } = self;
 
         let repetition_penalty = presence_penalty.map(|x| x + 2.0);
-        let max_new_tokens = max_tokens.or(Some(100));
+        let max_new_tokens = max_tokens;
         let tool_prompt = tool_prompt
             .filter(|s| !s.is_empty())
             .unwrap_or_else(default_tool_prompt);
@@ -1328,7 +1324,7 @@ pub struct SimpleToken {
     stop: usize,
 }
 
-#[derive(Debug, Serialize, ToSchema)]
+#[derive(Debug, Serialize, ToSchema, Clone)]
 #[serde(rename_all(serialize = "snake_case"))]
 #[schema(example = "Length")]
 pub enum FinishReason {
diff --git a/router/src/server.rs b/router/src/server.rs
index f253cb633..5a76e0797 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -714,7 +714,7 @@ pub(crate) async fn completions(
         ..
     } = req;
 
-    let max_new_tokens = max_tokens.or(Some(100));
+    let max_new_tokens = max_tokens;
     let stop = stop.unwrap_or_default();
     // enable greedy only when temperature is 0
     let (do_sample, temperature) = match temperature {
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 032638ab6..8137ac58d 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -1,4 +1,3 @@
-/// Payload validation logic
 use crate::config::Config;
 use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
 use crate::{
@@ -12,6 +11,8 @@ use jsonschema::{Draft, JSONSchema};
 use outlines_core::json_schema::to_regex as json_schema_to_regex;
 use rand::{thread_rng, Rng};
 use serde_json::Value;
+/// Payload validation logic
+use std::cmp::min;
 use std::io::Cursor;
 use std::iter;
 use std::sync::Arc;
@@ -21,6 +22,8 @@ use tokio::sync::oneshot;
 use tracing::{instrument, Span};
 use {once_cell::sync::Lazy, regex::Regex};
 
+static DEFAULT_GENERATION_LENGTH: u32 = 1024;
+
 /// Validation
 #[derive(Debug, Clone)]
 pub struct Validation {
@@ -131,7 +134,7 @@ impl Validation {
         add_special_tokens: bool,
         truncate: Option<usize>,
         max_new_tokens: Option<u32>,
-    ) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32), ValidationError> {
+    ) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32, u32), ValidationError> {
         // If we have a fast tokenizer
         let (encoding, inputs) = self
             .tokenize(inputs.clone(), add_special_tokens, truncate)
@@ -144,10 +147,17 @@ impl Validation {
         };
 
         // Get total tokens
-        let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
-            max_new_tokens
+        let (max_new_tokens, max_total_new_tokens) = if let Some(max_new_tokens) = max_new_tokens {
+            (max_new_tokens, max_new_tokens)
         } else {
-            self.max_total_tokens.saturating_sub(input_length) as u32
+            // Use the maximum possible number of tokens as default
+            // However, the system will re-queue the request every time it completes
+            // `DEFAULT_GENERATION_LENGTH` tokens.
+            let max_new_tokens = self.max_total_tokens.saturating_sub(input_length) as u32;
+            (
+                min(max_new_tokens, DEFAULT_GENERATION_LENGTH),
+                max_new_tokens,
+            )
         };
         let total_tokens = input_length + max_new_tokens as usize;
@@ -172,7 +182,13 @@ impl Validation {
         let input_ids = ids[ids.len().saturating_sub(input_length)..].to_owned();
 
         metrics::histogram!("tgi_request_input_length").record(input_length as f64);
-        Ok((inputs, Some(input_ids), input_length, max_new_tokens))
+        Ok((
+            inputs,
+            Some(input_ids),
+            input_length,
+            max_new_tokens,
+            max_total_new_tokens,
+        ))
     }
 
     /// Validate a payload and get the number of tokens in the input
@@ -305,7 +321,7 @@ impl Validation {
             .unwrap_or(Ok(None))?;
 
         // Validate inputs
-        let (inputs, input_ids, input_length, max_new_tokens) = self
+        let (inputs, input_ids, input_length, max_new_tokens, max_total_new_tokens) = self
             .validate_input(
                 request.inputs,
                 request.add_special_tokens,
@@ -381,6 +397,7 @@
         };
         let stopping_parameters = ValidStoppingParameters {
             max_new_tokens,
+            max_total_new_tokens,
             stop_sequences,
             ignore_eos_token: false,
         };
@@ -740,6 +757,8 @@ pub struct ValidParameters {
 pub struct ValidStoppingParameters {
     /// / Maximum number of generated tokens
     pub max_new_tokens: u32,
+    /// Maximum number of generated tokens before being re-queued by the system
+    pub max_total_new_tokens: u32,
     /// / Optional stopping sequences
    pub stop_sequences: Vec<String>,
     /// / Ignore end of sequence token
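
Behavior sketch (illustrative only, not part of the patch): with this change, a request that omits `max_new_tokens` is budgeted the whole remaining context window as `max_total_new_tokens`, but each scheduling pass is capped at `DEFAULT_GENERATION_LENGTH` (1024) tokens, and `Infer::generate_stream` re-validates and re-schedules the request whenever a pass ends with `FinishReason::Length` while the total budget is not yet exhausted. The standalone Rust sketch below mirrors that split; `resolve_max_new_tokens` and `schedule_pass` are made-up helper names, and the budget arithmetic in the toy loop is a simplification of the real code, which re-runs `validate` on the grown prompt.

const DEFAULT_GENERATION_LENGTH: u32 = 1024;

#[derive(Debug, Clone, Copy, PartialEq)]
enum FinishReason {
    Length,             // per-pass budget exhausted: candidate for re-queueing
    EndOfSequenceToken, // the model stopped on its own: the stream ends
}

/// Mirrors the default computed in `Validation::validate_input`: an explicit
/// `max_new_tokens` caps both values; otherwise the request gets the whole
/// remaining context window as `max_total_new_tokens` but is scheduled
/// `DEFAULT_GENERATION_LENGTH` tokens at a time.
fn resolve_max_new_tokens(
    requested: Option<u32>,
    max_total_tokens: usize,
    input_length: usize,
) -> (u32, u32) {
    match requested {
        Some(n) => (n, n),
        None => {
            let remaining = max_total_tokens.saturating_sub(input_length) as u32;
            (remaining.min(DEFAULT_GENERATION_LENGTH), remaining)
        }
    }
}

/// One toy scheduling pass: the "model" wants to emit `wanted` more tokens,
/// but the pass is capped at `budget` tokens.
fn schedule_pass(wanted: u32, budget: u32) -> (u32, FinishReason) {
    if wanted > budget {
        (budget, FinishReason::Length)
    } else {
        (wanted, FinishReason::EndOfSequenceToken)
    }
}

fn main() {
    // 4096-token window, 512-token prompt, no explicit `max_new_tokens`:
    // schedule 1024 tokens per pass, up to 3584 tokens in total.
    let (per_pass, max_total_new_tokens) = resolve_max_new_tokens(None, 4096, 512);
    assert_eq!((per_pass, max_total_new_tokens), (1024, 3584));
    // An explicit value behaves exactly as before.
    assert_eq!(resolve_max_new_tokens(Some(32), 4096, 512), (32, 32));

    // Toy version of the re-queue loop in `Infer::generate_stream`: keep
    // scheduling while a pass ends with `FinishReason::Length` and the total
    // budget is not exhausted.
    let mut wanted = 2500; // the model would stop on its own after 2500 tokens
    let (mut total, mut passes) = (0u32, 0u32);
    loop {
        let budget = per_pass.min(max_total_new_tokens - total);
        let (generated, reason) = schedule_pass(wanted, budget);
        total += generated;
        wanted -= generated;
        passes += 1;
        if reason != FinishReason::Length || total >= max_total_new_tokens {
            break;
        }
    }
    assert_eq!((total, passes), (2500, 3)); // generated as 1024 + 1024 + 452
}

In the patched router, each continuation appends the text generated so far to `local_request.inputs` before calling `validate` and `schedule` again, and the `End` event of an intermediate pass is turned into an `Intermediate` event so the client sees one uninterrupted stream; an explicit `max_new_tokens`, an end-of-sequence token, or a stop sequence still finishes the request without any re-queueing.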