mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
Modify the default for max_new_tokens.
This commit is contained in:
parent
8ec1b87f16
commit
c05cabc730
@ -107,8 +107,8 @@ pub(crate) struct GenerateParameters {
|
||||
#[schema(default = "false", example = true)]
|
||||
pub do_sample: bool,
|
||||
#[serde(default = "default_max_new_tokens")]
|
||||
#[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
|
||||
pub max_new_tokens: u32,
|
||||
#[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "null")]
|
||||
pub max_new_tokens: Option<u32>,
|
||||
#[serde(default)]
|
||||
#[schema(nullable = true, default = "null", example = false)]
|
||||
pub return_full_text: Option<bool>,
|
||||
@ -140,8 +140,8 @@ pub(crate) struct GenerateParameters {
|
||||
pub top_n_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
fn default_max_new_tokens() -> u32 {
|
||||
20
|
||||
fn default_max_new_tokens() -> Option<u32> {
|
||||
None
|
||||
}
|
||||
|
||||
fn default_parameters() -> GenerateParameters {
|
||||
|
@ -67,8 +67,8 @@ impl Validation {
|
||||
&self,
|
||||
inputs: String,
|
||||
truncate: Option<usize>,
|
||||
max_new_tokens: u32,
|
||||
) -> Result<(String, usize), ValidationError> {
|
||||
max_new_tokens: Option<u32>,
|
||||
) -> Result<(String, usize, u32), ValidationError> {
|
||||
// If we have a fast tokenizer
|
||||
if let Some(sender) = &self.sender {
|
||||
// Create response channel
|
||||
@ -84,6 +84,11 @@ impl Validation {
|
||||
let (inputs, input_length) = response_receiver.await.unwrap()?;
|
||||
|
||||
// Get total tokens
|
||||
let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens{
|
||||
max_new_tokens
|
||||
}else{
|
||||
self.max_total_tokens.saturating_sub(input_length) as u32
|
||||
};
|
||||
let total_tokens = input_length + max_new_tokens as usize;
|
||||
|
||||
// Validate MaxTotalTokens
|
||||
@ -104,7 +109,7 @@ impl Validation {
|
||||
}
|
||||
|
||||
metrics::histogram!("tgi_request_input_length", input_length as f64);
|
||||
Ok((inputs, input_length))
|
||||
Ok((inputs, input_length, max_new_tokens))
|
||||
}
|
||||
// Return inputs without validation
|
||||
else {
|
||||
@ -112,6 +117,11 @@ impl Validation {
|
||||
// However, the inputs will be truncated by the python servers
|
||||
// We make sure that truncate + max_new_tokens <= self.max_total_tokens
|
||||
let input_length = truncate.unwrap_or(self.max_input_length);
|
||||
let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens{
|
||||
max_new_tokens
|
||||
}else{
|
||||
self.max_total_tokens.saturating_sub(input_length) as u32
|
||||
};
|
||||
|
||||
// Validate MaxNewTokens
|
||||
if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
|
||||
@ -121,7 +131,7 @@ impl Validation {
|
||||
));
|
||||
}
|
||||
|
||||
Ok((inputs, input_length))
|
||||
Ok((inputs, input_length, max_new_tokens))
|
||||
}
|
||||
}
|
||||
|
||||
@ -200,7 +210,7 @@ impl Validation {
|
||||
})
|
||||
.unwrap_or(Ok(0))?;
|
||||
|
||||
if max_new_tokens == 0 {
|
||||
if max_new_tokens == Some(0) {
|
||||
return Err(ValidationError::NegativeMaxNewTokens);
|
||||
}
|
||||
|
||||
@ -247,7 +257,7 @@ impl Validation {
|
||||
.unwrap_or(Ok(None))?;
|
||||
|
||||
// Validate inputs
|
||||
let (inputs, input_length) = self
|
||||
let (inputs, input_length, max_new_tokens) = self
|
||||
.validate_input(request.inputs, truncate, max_new_tokens)
|
||||
.await?;
|
||||
|
||||
@ -426,7 +436,7 @@ mod tests {
|
||||
|
||||
let max_new_tokens = 10;
|
||||
match validation
|
||||
.validate_input("Hello".to_string(), None, max_new_tokens)
|
||||
.validate_input("Hello".to_string(), None, Some(max_new_tokens))
|
||||
.await
|
||||
{
|
||||
Err(ValidationError::MaxNewTokens(1, 10)) => (),
|
||||
@ -455,7 +465,7 @@ mod tests {
|
||||
|
||||
let max_new_tokens = 10;
|
||||
match validation
|
||||
.validate_input("Hello".to_string(), None, max_new_tokens)
|
||||
.validate_input("Hello".to_string(), None, Some(max_new_tokens))
|
||||
.await
|
||||
{
|
||||
Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
|
||||
@ -534,7 +544,6 @@ mod tests {
|
||||
inputs: "Hello".to_string(),
|
||||
parameters: GenerateParameters {
|
||||
top_p: Some(0.99),
|
||||
max_new_tokens: 1,
|
||||
..default_parameters()
|
||||
},
|
||||
})
|
||||
@ -549,7 +558,6 @@ mod tests {
|
||||
inputs: "Hello".to_string(),
|
||||
parameters: GenerateParameters {
|
||||
top_p: None,
|
||||
max_new_tokens: 1,
|
||||
..default_parameters()
|
||||
},
|
||||
})
|
||||
@ -596,7 +604,6 @@ mod tests {
|
||||
inputs: "Hello".to_string(),
|
||||
parameters: GenerateParameters {
|
||||
top_n_tokens: Some(4),
|
||||
max_new_tokens: 1,
|
||||
..default_parameters()
|
||||
},
|
||||
})
|
||||
@ -608,7 +615,6 @@ mod tests {
|
||||
inputs: "Hello".to_string(),
|
||||
parameters: GenerateParameters {
|
||||
top_n_tokens: Some(0),
|
||||
max_new_tokens: 1,
|
||||
..default_parameters()
|
||||
},
|
||||
})
|
||||
@ -620,7 +626,6 @@ mod tests {
|
||||
inputs: "Hello".to_string(),
|
||||
parameters: GenerateParameters {
|
||||
top_n_tokens: None,
|
||||
max_new_tokens: 1,
|
||||
..default_parameters()
|
||||
},
|
||||
})
|
||||
|
Loading…
Reference in New Issue
Block a user