diff --git a/router/src/lib.rs b/router/src/lib.rs
index 76e70bb7..6f3a130b 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -107,8 +107,8 @@ pub(crate) struct GenerateParameters {
     #[schema(default = "false", example = true)]
     pub do_sample: bool,
     #[serde(default = "default_max_new_tokens")]
-    #[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
-    pub max_new_tokens: u32,
+    #[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "null")]
+    pub max_new_tokens: Option<u32>,
     #[serde(default)]
     #[schema(nullable = true, default = "null", example = false)]
     pub return_full_text: Option<bool>,
@@ -140,8 +140,8 @@ pub(crate) struct GenerateParameters {
     pub top_n_tokens: Option<u32>,
 }
 
-fn default_max_new_tokens() -> u32 {
-    20
+fn default_max_new_tokens() -> Option<u32> {
+    None
 }
 
 fn default_parameters() -> GenerateParameters {
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 36cbfb9b..96d8b6d2 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -67,8 +67,8 @@ impl Validation {
         &self,
         inputs: String,
         truncate: Option<usize>,
-        max_new_tokens: u32,
-    ) -> Result<(String, usize), ValidationError> {
+        max_new_tokens: Option<u32>,
+    ) -> Result<(String, usize, u32), ValidationError> {
         // If we have a fast tokenizer
         if let Some(sender) = &self.sender {
             // Create response channel
@@ -84,6 +84,11 @@ impl Validation {
             let (inputs, input_length) = response_receiver.await.unwrap()?;
 
             // Get total tokens
+            let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens{
+                max_new_tokens
+            }else{
+                self.max_total_tokens.saturating_sub(input_length) as u32
+            };
             let total_tokens = input_length + max_new_tokens as usize;
 
             // Validate MaxTotalTokens
@@ -104,7 +109,7 @@ impl Validation {
             }
 
             metrics::histogram!("tgi_request_input_length", input_length as f64);
-            Ok((inputs, input_length))
+            Ok((inputs, input_length, max_new_tokens))
         }
         // Return inputs without validation
         else {
@@ -112,6 +117,11 @@ impl Validation {
             // However, the inputs will be truncated by the python servers
             // We make sure that truncate + max_new_tokens <= self.max_total_tokens
             let input_length = truncate.unwrap_or(self.max_input_length);
+            let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens{
+                max_new_tokens
+            }else{
+                self.max_total_tokens.saturating_sub(input_length) as u32
+            };
 
             // Validate MaxNewTokens
             if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
@@ -121,7 +131,7 @@ impl Validation {
                 ));
             }
 
-            Ok((inputs, input_length))
+            Ok((inputs, input_length, max_new_tokens))
         }
     }
 
@@ -200,7 +210,7 @@ impl Validation {
             })
             .unwrap_or(Ok(0))?;
 
-        if max_new_tokens == 0 {
+        if max_new_tokens == Some(0) {
             return Err(ValidationError::NegativeMaxNewTokens);
         }
 
@@ -247,7 +257,7 @@ impl Validation {
             .unwrap_or(Ok(None))?;
 
         // Validate inputs
-        let (inputs, input_length) = self
+        let (inputs, input_length, max_new_tokens) = self
             .validate_input(request.inputs, truncate, max_new_tokens)
             .await?;
 
@@ -426,7 +436,7 @@ mod tests {
 
         let max_new_tokens = 10;
         match validation
-            .validate_input("Hello".to_string(), None, max_new_tokens)
+            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
             .await
         {
             Err(ValidationError::MaxNewTokens(1, 10)) => (),
@@ -455,7 +465,7 @@ mod tests {
 
         let max_new_tokens = 10;
         match validation
-            .validate_input("Hello".to_string(), None, max_new_tokens)
+            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
             .await
         {
             Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
@@ -534,7 +544,6 @@ mod tests {
                 inputs: "Hello".to_string(),
                 parameters: GenerateParameters {
                     top_p: Some(0.99),
-                    max_new_tokens: 1,
                     ..default_parameters()
                 },
             })
@@ -549,7 +558,6 @@ mod tests {
                 inputs: "Hello".to_string(),
                 parameters: GenerateParameters {
                     top_p: None,
-                    max_new_tokens: 1,
                     ..default_parameters()
                 },
             })
@@ -596,7 +604,6 @@ mod tests {
                 inputs: "Hello".to_string(),
                 parameters: GenerateParameters {
                     top_n_tokens: Some(4),
-                    max_new_tokens: 1,
                     ..default_parameters()
                 },
             })
@@ -608,7 +615,6 @@ mod tests {
                 inputs: "Hello".to_string(),
                 parameters: GenerateParameters {
                     top_n_tokens: Some(0),
-                    max_new_tokens: 1,
                     ..default_parameters()
                 },
             })
@@ -620,7 +626,6 @@ mod tests {
                 inputs: "Hello".to_string(),
                 parameters: GenerateParameters {
                     top_n_tokens: None,
-                    max_new_tokens: 1,
                     ..default_parameters()
                 },
             })