Modify the default for max_new_tokens.

This commit is contained in:
Nicolas Patry 2023-10-04 16:03:01 +02:00
parent 8ec1b87f16
commit c05cabc730
2 changed files with 22 additions and 17 deletions

View File

@ -107,8 +107,8 @@ pub(crate) struct GenerateParameters {
#[schema(default = "false", example = true)] #[schema(default = "false", example = true)]
pub do_sample: bool, pub do_sample: bool,
#[serde(default = "default_max_new_tokens")] #[serde(default = "default_max_new_tokens")]
#[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")] #[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "null")]
pub max_new_tokens: u32, pub max_new_tokens: Option<u32>,
#[serde(default)] #[serde(default)]
#[schema(nullable = true, default = "null", example = false)] #[schema(nullable = true, default = "null", example = false)]
pub return_full_text: Option<bool>, pub return_full_text: Option<bool>,
@ -140,8 +140,8 @@ pub(crate) struct GenerateParameters {
pub top_n_tokens: Option<u32>, pub top_n_tokens: Option<u32>,
} }
fn default_max_new_tokens() -> u32 { fn default_max_new_tokens() -> Option<u32> {
20 None
} }
fn default_parameters() -> GenerateParameters { fn default_parameters() -> GenerateParameters {

View File

@ -67,8 +67,8 @@ impl Validation {
&self, &self,
inputs: String, inputs: String,
truncate: Option<usize>, truncate: Option<usize>,
max_new_tokens: u32, max_new_tokens: Option<u32>,
) -> Result<(String, usize), ValidationError> { ) -> Result<(String, usize, u32), ValidationError> {
// If we have a fast tokenizer // If we have a fast tokenizer
if let Some(sender) = &self.sender { if let Some(sender) = &self.sender {
// Create response channel // Create response channel
@ -84,6 +84,11 @@ impl Validation {
let (inputs, input_length) = response_receiver.await.unwrap()?; let (inputs, input_length) = response_receiver.await.unwrap()?;
// Get total tokens // Get total tokens
let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens{
max_new_tokens
}else{
self.max_total_tokens.saturating_sub(input_length) as u32
};
let total_tokens = input_length + max_new_tokens as usize; let total_tokens = input_length + max_new_tokens as usize;
// Validate MaxTotalTokens // Validate MaxTotalTokens
@ -104,7 +109,7 @@ impl Validation {
} }
metrics::histogram!("tgi_request_input_length", input_length as f64); metrics::histogram!("tgi_request_input_length", input_length as f64);
Ok((inputs, input_length)) Ok((inputs, input_length, max_new_tokens))
} }
// Return inputs without validation // Return inputs without validation
else { else {
@ -112,6 +117,11 @@ impl Validation {
// However, the inputs will be truncated by the python servers // However, the inputs will be truncated by the python servers
// We make sure that truncate + max_new_tokens <= self.max_total_tokens // We make sure that truncate + max_new_tokens <= self.max_total_tokens
let input_length = truncate.unwrap_or(self.max_input_length); let input_length = truncate.unwrap_or(self.max_input_length);
let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens{
max_new_tokens
}else{
self.max_total_tokens.saturating_sub(input_length) as u32
};
// Validate MaxNewTokens // Validate MaxNewTokens
if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 { if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
@ -121,7 +131,7 @@ impl Validation {
)); ));
} }
Ok((inputs, input_length)) Ok((inputs, input_length, max_new_tokens))
} }
} }
@ -200,7 +210,7 @@ impl Validation {
}) })
.unwrap_or(Ok(0))?; .unwrap_or(Ok(0))?;
if max_new_tokens == 0 { if max_new_tokens == Some(0) {
return Err(ValidationError::NegativeMaxNewTokens); return Err(ValidationError::NegativeMaxNewTokens);
} }
@ -247,7 +257,7 @@ impl Validation {
.unwrap_or(Ok(None))?; .unwrap_or(Ok(None))?;
// Validate inputs // Validate inputs
let (inputs, input_length) = self let (inputs, input_length, max_new_tokens) = self
.validate_input(request.inputs, truncate, max_new_tokens) .validate_input(request.inputs, truncate, max_new_tokens)
.await?; .await?;
@ -426,7 +436,7 @@ mod tests {
let max_new_tokens = 10; let max_new_tokens = 10;
match validation match validation
.validate_input("Hello".to_string(), None, max_new_tokens) .validate_input("Hello".to_string(), None, Some(max_new_tokens))
.await .await
{ {
Err(ValidationError::MaxNewTokens(1, 10)) => (), Err(ValidationError::MaxNewTokens(1, 10)) => (),
@ -455,7 +465,7 @@ mod tests {
let max_new_tokens = 10; let max_new_tokens = 10;
match validation match validation
.validate_input("Hello".to_string(), None, max_new_tokens) .validate_input("Hello".to_string(), None, Some(max_new_tokens))
.await .await
{ {
Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (), Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
@ -534,7 +544,6 @@ mod tests {
inputs: "Hello".to_string(), inputs: "Hello".to_string(),
parameters: GenerateParameters { parameters: GenerateParameters {
top_p: Some(0.99), top_p: Some(0.99),
max_new_tokens: 1,
..default_parameters() ..default_parameters()
}, },
}) })
@ -549,7 +558,6 @@ mod tests {
inputs: "Hello".to_string(), inputs: "Hello".to_string(),
parameters: GenerateParameters { parameters: GenerateParameters {
top_p: None, top_p: None,
max_new_tokens: 1,
..default_parameters() ..default_parameters()
}, },
}) })
@ -596,7 +604,6 @@ mod tests {
inputs: "Hello".to_string(), inputs: "Hello".to_string(),
parameters: GenerateParameters { parameters: GenerateParameters {
top_n_tokens: Some(4), top_n_tokens: Some(4),
max_new_tokens: 1,
..default_parameters() ..default_parameters()
}, },
}) })
@ -608,7 +615,6 @@ mod tests {
inputs: "Hello".to_string(), inputs: "Hello".to_string(),
parameters: GenerateParameters { parameters: GenerateParameters {
top_n_tokens: Some(0), top_n_tokens: Some(0),
max_new_tokens: 1,
..default_parameters() ..default_parameters()
}, },
}) })
@ -620,7 +626,6 @@ mod tests {
inputs: "Hello".to_string(), inputs: "Hello".to_string(),
parameters: GenerateParameters { parameters: GenerateParameters {
top_n_tokens: None, top_n_tokens: None,
max_new_tokens: 1,
..default_parameters() ..default_parameters()
}, },
}) })