diff --git a/router/src/validation.rs b/router/src/validation.rs index ba6f4f6d..2029c7e0 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -161,17 +161,18 @@ impl Validation { } else { return Err(ValidationError::UnsetMaxNewTokens); }; - let input_length = truncate.unwrap_or(self.max_input_length); + let mut input_length = truncate.unwrap_or(self.max_input_length); // We don't have a tokenizer, therefore we have no idea how long is the query, let // them through and hope for the best. // Validate MaxNewTokens - // if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 { - // return Err(ValidationError::MaxNewTokens( - // self.max_total_tokens - self.max_input_length, - // max_new_tokens, - // )); - // } + if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 { + input_length = input_length.saturating_sub(max_new_tokens as usize); + // return Err(ValidationError::MaxNewTokens( + // self.max_total_tokens - self.max_input_length, + // max_new_tokens, + // )); + } Ok((inputs, input_length, max_new_tokens)) } @@ -666,8 +667,9 @@ mod tests { .validate_input("Hello".to_string(), None, Some(max_new_tokens)) .await { - Err(ValidationError::MaxNewTokens(1, 10)) => (), - _ => panic!("Unexpected not max new tokens"), + // Err(ValidationError::MaxNewTokens(1, 10)) => (), + Ok((_s, 0, 10)) => (), + r => panic!("Unexpected not max new tokens: {r:?}"), } }