Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-19 22:02:06 +00:00
feat: auto max_new_tokens (#2803)
* feat: auto max_new_tokens
* update default
* Fixing the tests.

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
parent
6685e8fcda
commit
8c3669b287
@@ -436,6 +436,7 @@ mod tests {
             stopping_parameters: ValidStoppingParameters {
                 ignore_eos_token: false,
                 max_new_tokens: 1,
+                max_total_new_tokens: 1024,
                 stop_sequences: vec![],
             },
             top_n_tokens: 0,
@@ -573,6 +573,7 @@ mod tests {
             stopping_parameters: ValidStoppingParameters {
                 ignore_eos_token: false,
                 max_new_tokens: 1,
+                max_total_new_tokens: 1024,
                 stop_sequences: vec![],
             },
             top_n_tokens: 0,
@@ -1013,6 +1013,7 @@
          "type": "integer",
          "format": "int32",
          "description": "The maximum number of tokens that can be generated in the chat completion.",
+          "default": "1024",
          "example": "32",
          "nullable": true,
          "minimum": 0
@@ -1329,7 +1330,8 @@
          "type": "integer",
          "format": "int32",
          "description": "The maximum number of tokens that can be generated in the chat completion.",
-          "default": "32",
+          "default": "1024",
+          "example": "32",
          "nullable": true,
          "minimum": 0
        },
@@ -1591,7 +1593,7 @@
          "type": "integer",
          "format": "int32",
          "description": "Maximum number of tokens to generate.",
-          "default": "100",
+          "default": "1024",
          "example": "20",
          "nullable": true,
          "minimum": 0
@@ -111,21 +111,79 @@ impl Infer {
         })?;
 
         // Validate request
+        let mut local_request = request.clone();
         let valid_request = self.validation.validate(request).await.map_err(|err| {
             metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
             tracing::error!("{err}");
             err
         })?;
 
+        let seed = valid_request.parameters.seed;
+        local_request.parameters.seed = Some(seed);
         let input_length = valid_request.input_length;
+        let max_total_new_tokens = valid_request.stopping_parameters.max_total_new_tokens;
         let mut generation_stream = self.backend.schedule(valid_request)?;
 
         // Wrap generation stream to update the backend health if the stream contains an error
         let final_stream = stream! {
+            let mut total_generated_tokens = 0;
+            let mut first_start = None;
+            let mut first_queued = None;
+            let mut all_generated_text: Option<GeneratedText> = None;
+
             while let Some(response) = generation_stream.next().await {
-                yield response.inspect_err(|_err| {
+                let response = response.inspect_err(|_err| {
                     self.backend_health.store(false, Ordering::SeqCst);
-                })
+                })?;
+
+                match response {
+                    InferStreamResponse::Prefill(_) => yield Ok(response),
+                    InferStreamResponse::Intermediate { .. } => {
+                        total_generated_tokens += 1;
+                        yield Ok(response);
+                    }
+                    InferStreamResponse::End { token, top_tokens, generated_text, start, queued } => {
+                        total_generated_tokens += 1;
+                        first_start = first_start.or(Some(start));
+                        first_queued = first_queued.or(Some(queued));
+                        if let Some(v) = all_generated_text.as_mut() {
+                            v.text.push_str(&generated_text.text);
+                            v.generated_tokens = total_generated_tokens;
+                            v.finish_reason = generated_text.finish_reason.clone();
+                        };
+
+                        if matches!(generated_text.finish_reason, FinishReason::Length) && total_generated_tokens < max_total_new_tokens {
+                            local_request.inputs.push_str(&generated_text.text);
+                            all_generated_text = all_generated_text.or(Some(generated_text));
+
+                            let valid_request = match self.validation.validate(local_request.clone()).await {
+                                Ok(valid_request) => valid_request,
+                                Err(err) => {
+                                    tracing::debug!("Failed to continue request: {err}");
+                                    yield Ok(InferStreamResponse::End { token, top_tokens, generated_text: all_generated_text.unwrap(), start: first_start.unwrap(), queued: first_queued.unwrap() });
+                                    break;
+                                }
+                            };
+
+                            generation_stream = match self.backend.schedule(valid_request) {
+                                Ok(stream) => {
+                                    tracing::debug!("Continue request");
+                                    yield Ok(InferStreamResponse::Intermediate { token, top_tokens });
+                                    stream
+                                },
+                                Err(err) => {
+                                    tracing::debug!("Failed to continue request: {err}");
+                                    yield Ok(InferStreamResponse::End { token, top_tokens, generated_text: all_generated_text.unwrap(), start: first_start.unwrap(), queued: first_queued.unwrap() });
+                                    break;
+                                }
+                            }
+                        } else {
+                            yield Ok(InferStreamResponse::End { token, top_tokens, generated_text: all_generated_text.unwrap_or(generated_text), start: first_start.unwrap(), queued: first_queued.unwrap() });
+                            break;
+                        }
+
+                    }
+                }
             }
         };
 
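The hunk above is the core of the feature: the generation stream is wrapped so that, when a slice of generation ends with FinishReason::Length before the overall budget (max_total_new_tokens) is spent, the generated text is appended to the input and the request is re-validated and re-scheduled. A minimal, self-contained sketch of that decision rule, using a stand-in enum and a hypothetical should_reschedule helper rather than the router's own types:

// Stand-in enum for illustration; the router's FinishReason carries more context.
#[derive(Clone, Debug, PartialEq)]
enum FinishReason {
    Length,
    EndOfSequenceToken,
    StopSequence,
}

/// Continue only when the slice was cut by the per-slice length limit
/// and the overall token budget is not yet exhausted.
fn should_reschedule(reason: &FinishReason, generated_so_far: u32, max_total_new_tokens: u32) -> bool {
    *reason == FinishReason::Length && generated_so_far < max_total_new_tokens
}

fn main() {
    // A 1024-token slice out of a 4096-token budget: keep going.
    assert!(should_reschedule(&FinishReason::Length, 1024, 4096));
    // The model emitted an end-of-sequence token: stop regardless of budget.
    assert!(!should_reschedule(&FinishReason::EndOfSequenceToken, 1024, 4096));
}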
@@ -332,8 +332,8 @@ pub(crate) struct GenerateParameters {
     pub do_sample: bool,
 
     /// Maximum number of tokens to generate.
-    #[serde(default = "default_max_new_tokens")]
-    #[schema(nullable = true, default = "100", example = "20")]
+    #[serde(default)]
+    #[schema(nullable = true, default = "1024", example = "20")]
     pub max_new_tokens: Option<u32>,
 
     /// Whether to prepend the prompt to the generated text
@@ -392,10 +392,6 @@ pub(crate) struct GenerateParameters {
     pub adapter_id: Option<String>,
 }
 
-fn default_max_new_tokens() -> Option<u32> {
-    Some(100)
-}
-
 fn default_parameters() -> GenerateParameters {
     GenerateParameters {
         best_of: None,
@@ -406,7 +402,7 @@ fn default_parameters() -> GenerateParameters {
         top_p: None,
         typical_p: None,
         do_sample: true,
-        max_new_tokens: default_max_new_tokens(),
+        max_new_tokens: None,
         return_full_text: None,
         stop: Vec::new(),
         truncate: None,
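With #[serde(default)] replacing the old default_max_new_tokens helper in the three hunks above, a request that omits max_new_tokens now deserializes to None instead of Some(100), and the decision about how many tokens to generate is deferred to the validation layer. A small sketch of that deserialization behaviour, assuming serde and serde_json as dependencies and a simplified stand-in struct in place of GenerateParameters:

use serde::Deserialize;

// Simplified stand-in for the relevant part of GenerateParameters.
#[derive(Deserialize, Debug)]
struct Params {
    #[serde(default)]
    max_new_tokens: Option<u32>,
}

fn main() {
    // Field omitted: previously Some(100), now None (auto mode).
    let omitted: Params = serde_json::from_str(r#"{}"#).unwrap();
    assert_eq!(omitted.max_new_tokens, None);

    // Field set explicitly: behaves as before.
    let explicit: Params = serde_json::from_str(r#"{"max_new_tokens": 20}"#).unwrap();
    assert_eq!(explicit.max_new_tokens, Some(20));
}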
@@ -464,7 +460,7 @@ pub struct CompletionRequest {
 
     /// The maximum number of tokens that can be generated in the chat completion.
     #[serde(default)]
-    #[schema(default = "32")]
+    #[schema(default = "1024", example = "32")]
     pub max_tokens: Option<u32>,
 
     /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while
@@ -842,7 +838,7 @@ pub(crate) struct ChatRequest {
 
     /// The maximum number of tokens that can be generated in the chat completion.
     #[serde(default)]
-    #[schema(example = "32")]
+    #[schema(default = "1024", example = "32")]
     pub max_tokens: Option<u32>,
 
     /// UNUSED
@@ -937,7 +933,7 @@ impl ChatRequest {
         } = self;
 
         let repetition_penalty = presence_penalty.map(|x| x + 2.0);
-        let max_new_tokens = max_tokens.or(Some(100));
+        let max_new_tokens = max_tokens;
         let tool_prompt = tool_prompt
             .filter(|s| !s.is_empty())
             .unwrap_or_else(default_tool_prompt);
@@ -1328,7 +1324,7 @@ pub struct SimpleToken {
     stop: usize,
 }
 
-#[derive(Debug, Serialize, ToSchema)]
+#[derive(Debug, Serialize, ToSchema, Clone)]
 #[serde(rename_all(serialize = "snake_case"))]
 #[schema(example = "Length")]
 pub enum FinishReason {
@@ -714,7 +714,7 @@ pub(crate) async fn completions(
         ..
     } = req;
 
-    let max_new_tokens = max_tokens.or(Some(100));
+    let max_new_tokens = max_tokens;
    let stop = stop.unwrap_or_default();
    // enable greedy only when temperature is 0
    let (do_sample, temperature) = match temperature {
@@ -1,4 +1,3 @@
-/// Payload validation logic
 use crate::config::Config;
 use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
 use crate::{
@@ -12,6 +11,8 @@ use jsonschema::{Draft, JSONSchema};
 use outlines_core::json_schema::to_regex as json_schema_to_regex;
 use rand::{thread_rng, Rng};
 use serde_json::Value;
+/// Payload validation logic
+use std::cmp::min;
 use std::io::Cursor;
 use std::iter;
 use std::sync::Arc;
@@ -21,6 +22,8 @@ use tokio::sync::oneshot;
 use tracing::{instrument, Span};
 use {once_cell::sync::Lazy, regex::Regex};
 
+static DEFAULT_GENERATION_LENGTH: u32 = 1024;
+
 /// Validation
 #[derive(Debug, Clone)]
 pub struct Validation {
@@ -131,7 +134,7 @@ impl Validation {
         add_special_tokens: bool,
         truncate: Option<usize>,
         max_new_tokens: Option<u32>,
-    ) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32), ValidationError> {
+    ) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32, u32), ValidationError> {
         // If we have a fast tokenizer
         let (encoding, inputs) = self
             .tokenize(inputs.clone(), add_special_tokens, truncate)
@@ -144,10 +147,17 @@ impl Validation {
         };
 
         // Get total tokens
-        let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
-            max_new_tokens
+        let (max_new_tokens, max_total_new_tokens) = if let Some(max_new_tokens) = max_new_tokens {
+            (max_new_tokens, max_new_tokens)
         } else {
-            self.max_total_tokens.saturating_sub(input_length) as u32
+            // Use the maximum possible number of tokens as default
+            // However, the system will re-queue the request everytime it completes
+            // `DEFAULT_GENERATION_LENGTH` tokens.
+            let max_new_tokens = self.max_total_tokens.saturating_sub(input_length) as u32;
+            (
+                min(max_new_tokens, DEFAULT_GENERATION_LENGTH),
+                max_new_tokens,
+            )
         };
         let total_tokens = input_length + max_new_tokens as usize;
 
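When max_new_tokens is omitted, validation now derives two values: a per-slice limit capped at DEFAULT_GENERATION_LENGTH (1024) and a total budget equal to whatever still fits in the context window. A standalone sketch of that computation, with a hypothetical default_budget helper standing in for the inline logic that actually lives in Validation and uses self.max_total_tokens:

static DEFAULT_GENERATION_LENGTH: u32 = 1024;

/// Returns (max_new_tokens, max_total_new_tokens) the way the new default does:
/// generate in slices of at most DEFAULT_GENERATION_LENGTH tokens, but keep
/// re-queuing the request until the whole remaining context budget is used.
fn default_budget(max_total_tokens: usize, input_length: usize) -> (u32, u32) {
    let remaining = max_total_tokens.saturating_sub(input_length) as u32;
    (remaining.min(DEFAULT_GENERATION_LENGTH), remaining)
}

fn main() {
    // 4096-token window, 96-token prompt: 1024-token slices, 4000-token total budget.
    assert_eq!(default_budget(4096, 96), (1024, 4000));
    // Short window: the slice is capped by what actually fits.
    assert_eq!(default_budget(512, 200), (312, 312));
}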
@@ -172,7 +182,13 @@ impl Validation {
         let input_ids = ids[ids.len().saturating_sub(input_length)..].to_owned();
 
         metrics::histogram!("tgi_request_input_length").record(input_length as f64);
-        Ok((inputs, Some(input_ids), input_length, max_new_tokens))
+        Ok((
+            inputs,
+            Some(input_ids),
+            input_length,
+            max_new_tokens,
+            max_total_new_tokens,
+        ))
     }
 
     /// Validate a payload and get the number of tokens in the input
@@ -305,7 +321,7 @@ impl Validation {
             .unwrap_or(Ok(None))?;
 
         // Validate inputs
-        let (inputs, input_ids, input_length, max_new_tokens) = self
+        let (inputs, input_ids, input_length, max_new_tokens, max_total_new_tokens) = self
             .validate_input(
                 request.inputs,
                 request.add_special_tokens,
@@ -381,6 +397,7 @@ impl Validation {
         };
         let stopping_parameters = ValidStoppingParameters {
             max_new_tokens,
+            max_total_new_tokens,
             stop_sequences,
             ignore_eos_token: false,
         };
@@ -740,6 +757,8 @@ pub struct ValidParameters {
 pub struct ValidStoppingParameters {
     /// / Maximum number of generated tokens
     pub max_new_tokens: u32,
+    /// Maximum number of generated tokens before being re-queued by the system
+    pub max_total_new_tokens: u32,
     /// / Optional stopping sequences
     pub stop_sequences: Vec<String>,
     /// / Ignore end of sequence token