diff --git a/router/src/lib.rs b/router/src/lib.rs
index efc4d3ae..9fcc5085 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -63,7 +63,7 @@ pub(crate) struct GenerateParameters {
     pub stop: Vec<String>,
     #[serde(default)]
     #[schema(default = "null", example = "null")]
-    pub truncate: Option<u32>,
+    pub truncate: Option<usize>,
     #[serde(default)]
     #[schema(default = "false", example = true)]
     pub watermark: bool,
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 1a247caa..42af0169 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -6,6 +6,7 @@ use rand::Rng;
 use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
+use tokenizers::TruncationDirection;
 use tokio::sync::{mpsc, oneshot};
 use tracing::{instrument, Span};
 
@@ -157,6 +158,7 @@ fn validate(
         do_sample,
         max_new_tokens,
         stop: stop_sequences,
+        truncate,
         seed,
         watermark,
         ..
@@ -223,52 +225,70 @@ fn validate(
         return Err(EmptyInput);
     }
 
-    // Get the number of tokens in the input
-    match tokenizer.encode(request.inputs.clone(), true) {
-        Ok(mut encoding) => {
-            encoding.truncate()
-
-            let input_length = encoding.len();
-            let total_tokens = input_length + max_new_tokens as usize;
-
-            if input_length > max_input_length {
-                Err(ValidationError::InputLength(max_input_length, input_length))
-            } else if total_tokens > max_total_tokens {
-                Err(ValidationError::MaxTotalTokens(
-                    max_total_tokens,
-                    input_length,
-                    max_new_tokens,
-                ))
-            } else {
-                // Return ValidGenerateRequest
-                let parameters = NextTokenChooserParameters {
-                    temperature,
-                    repetition_penalty,
-                    top_k,
-                    top_p,
-                    typical_p,
-                    do_sample,
-                    seed,
-                    watermark,
-                };
-                let stopping_parameters = StoppingCriteriaParameters {
-                    max_new_tokens,
-                    stop_sequences,
-                };
-
-                metrics::histogram!("tgi_request_input_length", input_length as f64);
-                metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);
-
-                Ok(ValidGenerateRequest {
-                    inputs: request.inputs,
-                    input_length: input_length as u32,
-                    parameters,
-                    stopping_parameters,
-                })
+    // Check if truncate is strictly positive and less than max_input_length
+    let truncate = truncate
+        .map(|value| {
+            if value == 0 || value > max_input_length {
+                return Err(ValidationError::Truncate(max_input_length, value));
             }
-        }
-        Err(err) => Err(ValidationError::Tokenizer(err.to_string())),
+            Ok(Some(value))
+        })
+        .unwrap_or(Ok(None))?;
+
+    // Get the number of tokens in the input
+    let mut encoding = tokenizer
+        .encode(request.inputs.clone(), true)
+        .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
+
+    let (inputs, input_length) = if let Some(truncate) = truncate {
+        // truncate encoding and decode new inputs
+        encoding.truncate(truncate, 0, TruncationDirection::Left);
+        let inputs = tokenizer
+            .decode(Vec::from(encoding.get_ids()), false)
+            .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
+        (inputs, encoding.len())
+    } else {
+        (request.inputs, encoding.len())
+    };
+
+    if input_length > max_input_length {
+        return Err(ValidationError::InputLength(max_input_length, input_length));
     }
+
+    let total_tokens = input_length + max_new_tokens as usize;
+    if total_tokens > max_total_tokens {
+        return Err(ValidationError::MaxTotalTokens(
+            max_total_tokens,
+            input_length,
+            max_new_tokens,
+        ));
+    }
+
+    // Return ValidGenerateRequest
+    let parameters = NextTokenChooserParameters {
+        temperature,
+        repetition_penalty,
+        top_k,
+        top_p,
+        typical_p,
+        do_sample,
+        seed,
+        watermark,
+    };
+    let stopping_parameters = StoppingCriteriaParameters {
+        max_new_tokens,
+        stop_sequences,
+    };
+
+    metrics::histogram!("tgi_request_input_length", input_length as f64);
+    metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);
+
+    Ok(ValidGenerateRequest {
+        inputs,
+        input_length: input_length as u32,
+        parameters,
+        stopping_parameters,
+    })
 }
 
 type ValidationRequest = (
@@ -295,6 +315,8 @@ pub enum ValidationError {
     TopP,
     #[error("`top_k` must be strictly positive")]
     TopK,
+    #[error("`truncate` must be strictly positive and less than {0}. Given: {1}")]
+    Truncate(usize, usize),
     #[error("`typical_p` must be > 0.0 and < 1.0")]
     TypicalP,
     #[error("`max_new_tokens` must be strictly positive")]