Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
feat: auto max_new_tokens

parent e0db633396
commit 124eea2d0e
@@ -111,21 +111,79 @@ impl Infer {
        })?;

        // Validate request
        let mut local_request = request.clone();
        let valid_request = self.validation.validate(request).await.map_err(|err| {
            metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
            tracing::error!("{err}");
            err
        })?;

        let seed = valid_request.parameters.seed;
        local_request.parameters.seed = Some(seed);
        let input_length = valid_request.input_length;
        let max_total_new_tokens = valid_request.stopping_parameters.max_total_new_tokens;
        let mut generation_stream = self.backend.schedule(valid_request)?;

        // Wrap generation stream to update the backend health if the stream contains an error
        let final_stream = stream! {
            let mut total_generated_tokens = 0;
            let mut first_start = None;
            let mut first_queued = None;
            let mut all_generated_text: Option<GeneratedText> = None;

            while let Some(response) = generation_stream.next().await {
                yield response.inspect_err(|_err| {
                let response = response.inspect_err(|_err| {
                    self.backend_health.store(false, Ordering::SeqCst);
                })
                })?;

                match response {
                    InferStreamResponse::Prefill(_) => yield Ok(response),
                    InferStreamResponse::Intermediate { .. } => {
                        total_generated_tokens += 1;
                        yield Ok(response);
                    }
                    InferStreamResponse::End { token, top_tokens, generated_text, start, queued } => {
                        total_generated_tokens += 1;
                        first_start = first_start.or(Some(start));
                        first_queued = first_queued.or(Some(queued));
                        if let Some(v) = all_generated_text.as_mut() {
                            v.text.push_str(&generated_text.text);
                            v.generated_tokens = total_generated_tokens;
                            v.finish_reason = generated_text.finish_reason.clone();
                        };

                        if matches!(generated_text.finish_reason, FinishReason::Length) && total_generated_tokens < max_total_new_tokens {
                            local_request.inputs.push_str(&generated_text.text);
                            all_generated_text = all_generated_text.or(Some(generated_text));

                            let valid_request = match self.validation.validate(local_request.clone()).await {
                                Ok(valid_request) => valid_request,
                                Err(err) => {
                                    tracing::debug!("Failed to continue request: {err}");
                                    yield Ok(InferStreamResponse::End { token, top_tokens, generated_text: all_generated_text.unwrap(), start: first_start.unwrap(), queued: first_queued.unwrap() });
                                    break;
                                }
                            };

                            generation_stream = match self.backend.schedule(valid_request) {
                                Ok(stream) => {
                                    tracing::debug!("Continue request");
                                    yield Ok(InferStreamResponse::Intermediate { token, top_tokens });
                                    stream
                                },
                                Err(err) => {
                                    tracing::debug!("Failed to continue request: {err}");
                                    yield Ok(InferStreamResponse::End { token, top_tokens, generated_text: all_generated_text.unwrap(), start: first_start.unwrap(), queued: first_queued.unwrap() });
                                    break;
                                }
                            }
                        } else {
                            yield Ok(InferStreamResponse::End { token, top_tokens, generated_text: all_generated_text.unwrap_or(generated_text), start: first_start.unwrap(), queued: first_queued.unwrap() });
                            break;
                        }

                    }
                }
            }
        };

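The hunk above is the core of the feature: when a stream ends with `FinishReason::Length` but fewer than `max_total_new_tokens` tokens have been produced, the router appends the generated text to the original inputs, re-validates the request, and schedules it again, so the client keeps receiving one continuous stream. The following self-contained sketch models only that control flow; the `schedule` helper and the types are simplified stand-ins, not the actual TGI API.

```rust
// Simplified model of the re-queue loop above (hypothetical types, not TGI's).
#[derive(PartialEq)]
enum FinishReason {
    Length,        // stopped because the scheduled budget ran out
    EndOfSequence, // the model decided to stop
}

struct Generation {
    text: String,
    generated_tokens: u32,
    finish_reason: FinishReason,
}

// Stand-in for `self.backend.schedule(valid_request)`: the fake model wants to
// produce `*remaining_model` more tokens but is capped at `budget` per schedule.
fn schedule(remaining_model: &mut u32, budget: u32) -> Generation {
    let tokens = budget.min(*remaining_model);
    *remaining_model -= tokens;
    Generation {
        text: " tok".repeat(tokens as usize),
        generated_tokens: tokens,
        finish_reason: if tokens == budget {
            FinishReason::Length
        } else {
            FinishReason::EndOfSequence
        },
    }
}

fn main() {
    let chunk: u32 = 10;                // plays the role of DEFAULT_GENERATION_LENGTH
    let max_total_new_tokens: u32 = 25; // overall budget, as in ValidStoppingParameters
    let mut remaining_model: u32 = 23;  // the model would stop on its own after 23 tokens
    let mut inputs = String::from("Hello");
    let mut total_generated_tokens: u32 = 0;

    loop {
        let budget = chunk.min(max_total_new_tokens - total_generated_tokens);
        let generation = schedule(&mut remaining_model, budget);
        total_generated_tokens += generation.generated_tokens;
        // Like `local_request.inputs.push_str(&generated_text.text)` above.
        inputs.push_str(&generation.text);

        if generation.finish_reason != FinishReason::Length
            || total_generated_tokens >= max_total_new_tokens
        {
            break; // the real code emits the final InferStreamResponse::End here
        }
        // Otherwise: re-validate and re-schedule, which is exactly what the diff adds.
    }
    println!("generated {total_generated_tokens} tokens over several schedules");
}
```
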
@@ -332,8 +332,8 @@ pub(crate) struct GenerateParameters {
    pub do_sample: bool,

    /// Maximum number of tokens to generate.
    #[serde(default = "default_max_new_tokens")]
    #[schema(nullable = true, default = "100", example = "20")]
    #[serde(default)]
    #[schema(nullable = true, default = "256", example = "20")]
    pub max_new_tokens: Option<u32>,

    /// Whether to prepend the prompt to the generated text

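For context on the attribute swap above: with plain `#[serde(default)]` on an `Option<u32>` field, a request that omits `max_new_tokens` now deserializes to `None` instead of `Some(100)`, which is what lets the validation layer pick the automatic budget. A minimal check of that serde behavior (illustrative struct name, assumes the `serde` and `serde_json` crates):

```rust
use serde::Deserialize;

#[derive(Deserialize)]
struct Params {
    // Same shape as the attribute in the hunk above.
    #[serde(default)]
    max_new_tokens: Option<u32>,
}

fn main() -> Result<(), serde_json::Error> {
    let explicit: Params = serde_json::from_str(r#"{"max_new_tokens": 20}"#)?;
    let omitted: Params = serde_json::from_str("{}")?;
    assert_eq!(explicit.max_new_tokens, Some(20));
    assert_eq!(omitted.max_new_tokens, None); // previously defaulted to Some(100)
    Ok(())
}
```
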
@@ -392,10 +392,6 @@ pub(crate) struct GenerateParameters {
    pub adapter_id: Option<String>,
}

fn default_max_new_tokens() -> Option<u32> {
    Some(100)
}

fn default_parameters() -> GenerateParameters {
    GenerateParameters {
        best_of: None,

@@ -406,7 +402,7 @@ fn default_parameters() -> GenerateParameters {
        top_p: None,
        typical_p: None,
        do_sample: true,
        max_new_tokens: default_max_new_tokens(),
        max_new_tokens: None,
        return_full_text: None,
        stop: Vec::new(),
        truncate: None,

@@ -937,7 +933,7 @@ impl ChatRequest {
        } = self;

        let repetition_penalty = presence_penalty.map(|x| x + 2.0);
        let max_new_tokens = max_tokens.or(Some(100));
        let max_new_tokens = max_tokens;
        let tool_prompt = tool_prompt
            .filter(|s| !s.is_empty())
            .unwrap_or_else(default_tool_prompt);

@@ -1328,7 +1324,7 @@ pub struct SimpleToken {
    stop: usize,
}

#[derive(Debug, Serialize, ToSchema)]
#[derive(Debug, Serialize, ToSchema, Clone)]
#[serde(rename_all(serialize = "snake_case"))]
#[schema(example = "Length")]
pub enum FinishReason {

@@ -714,7 +714,7 @@ pub(crate) async fn completions(
        ..
    } = req;

    let max_new_tokens = max_tokens.or(Some(100));
    let max_new_tokens = max_tokens;
    let stop = stop.unwrap_or_default();
    // enable greedy only when temperature is 0
    let (do_sample, temperature) = match temperature {

@@ -1,4 +1,3 @@
/// Payload validation logic
use crate::config::Config;
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
use crate::{

@@ -12,6 +11,8 @@ use jsonschema::{Draft, JSONSchema};
use outlines_core::json_schema::to_regex as json_schema_to_regex;
use rand::{thread_rng, Rng};
use serde_json::Value;
/// Payload validation logic
use std::cmp::min;
use std::io::Cursor;
use std::iter;
use std::sync::Arc;

@@ -21,6 +22,8 @@ use tokio::sync::oneshot;
use tracing::{instrument, Span};
use {once_cell::sync::Lazy, regex::Regex};

static DEFAULT_GENERATION_LENGTH: u32 = 10;

/// Validation
#[derive(Debug, Clone)]
pub struct Validation {

@@ -131,7 +134,7 @@ impl Validation {
        add_special_tokens: bool,
        truncate: Option<usize>,
        max_new_tokens: Option<u32>,
    ) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32), ValidationError> {
    ) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32, u32), ValidationError> {
        // If we have a fast tokenizer
        let (encoding, inputs) = self
            .tokenize(inputs.clone(), add_special_tokens, truncate)

@@ -144,10 +147,17 @@ impl Validation {
        };

        // Get total tokens
        let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
            max_new_tokens
        let (max_new_tokens, max_total_new_tokens) = if let Some(max_new_tokens) = max_new_tokens {
            (max_new_tokens, max_new_tokens)
        } else {
            self.max_total_tokens.saturating_sub(input_length) as u32
            // Use the maximum possible number of tokens as default
            // However, the system will re-queue the request everytime it completes
            // `DEFAULT_GENERATION_LENGTH` tokens.
            let max_new_tokens = self.max_total_tokens.saturating_sub(input_length) as u32;
            (
                min(max_new_tokens, DEFAULT_GENERATION_LENGTH),
                max_new_tokens,
            )
        };
        let total_tokens = input_length + max_new_tokens as usize;

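The new branch above is where the automatic budget is computed: an explicit `max_new_tokens` is used as-is for both values, while a missing one schedules the request in `DEFAULT_GENERATION_LENGTH`-token slices up to whatever fits in the remaining context. A small free-standing sketch of that computation (the `budgets` function is illustrative, not the actual `validate_input` signature):

```rust
use std::cmp::min;

const DEFAULT_GENERATION_LENGTH: u32 = 10;

/// Returns (max_new_tokens, max_total_new_tokens), mirroring the branch above.
fn budgets(requested: Option<u32>, max_total_tokens: usize, input_length: usize) -> (u32, u32) {
    match requested {
        Some(n) => (n, n),
        None => {
            let remaining = max_total_tokens.saturating_sub(input_length) as u32;
            (min(remaining, DEFAULT_GENERATION_LENGTH), remaining)
        }
    }
}

fn main() {
    // 4096-token window, 96-token prompt, no explicit max_new_tokens:
    // schedule 10 tokens at a time, but allow up to 4000 in total.
    assert_eq!(budgets(None, 4096, 96), (10, 4000));
    // An explicit request is used verbatim for both budgets.
    assert_eq!(budgets(Some(20), 4096, 96), (20, 20));
}
```
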
@@ -172,7 +182,13 @@ impl Validation {
        let input_ids = ids[ids.len().saturating_sub(input_length)..].to_owned();

        metrics::histogram!("tgi_request_input_length").record(input_length as f64);
        Ok((inputs, Some(input_ids), input_length, max_new_tokens))
        Ok((
            inputs,
            Some(input_ids),
            input_length,
            max_new_tokens,
            max_total_new_tokens,
        ))
    }

    /// Validate a payload and get the number of tokens in the input

@@ -305,7 +321,7 @@ impl Validation {
            .unwrap_or(Ok(None))?;

        // Validate inputs
        let (inputs, input_ids, input_length, max_new_tokens) = self
        let (inputs, input_ids, input_length, max_new_tokens, max_total_new_tokens) = self
            .validate_input(
                request.inputs,
                request.add_special_tokens,

@@ -381,6 +397,7 @@ impl Validation {
        };
        let stopping_parameters = ValidStoppingParameters {
            max_new_tokens,
            max_total_new_tokens,
            stop_sequences,
            ignore_eos_token: false,
        };

@@ -740,6 +757,8 @@ pub struct ValidParameters {
pub struct ValidStoppingParameters {
    /// / Maximum number of generated tokens
    pub max_new_tokens: u32,
    /// Maximum number of generated tokens before being re-queued by the system
    pub max_total_new_tokens: u32,
    /// / Optional stopping sequences
    pub stop_sequences: Vec<String>,
    /// / Ignore end of sequence token