From 20ee71dcf55098bbef617e63ba1869ed6f206b48 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 11 Oct 2023 10:46:40 +0200 Subject: [PATCH] fix: force one of max_new_tokens or truncate with slow tokenizer --- router/src/validation.rs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/router/src/validation.rs b/router/src/validation.rs index 9adedc5b..d0ea137d 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -116,12 +116,16 @@ impl Validation { // In this case, we don't know the real length in tokens of the inputs // However, the inputs will be truncated by the python servers // We make sure that truncate + max_new_tokens <= self.max_total_tokens - let input_length = truncate.unwrap_or(self.max_input_length); let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens { max_new_tokens } else { - self.max_total_tokens.saturating_sub(input_length) as u32 + if let Some(truncate) = truncate { + self.max_total_tokens.saturating_sub(truncate) as u32 + } else { + return Err(ValidationError::UnsetMaxNewTokens) + } }; + let input_length = truncate.unwrap_or(self.max_input_length); // Validate MaxNewTokens if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 { @@ -393,6 +397,8 @@ pub enum ValidationError { Truncate(usize, usize), #[error("`typical_p` must be > 0.0 and < 1.0")] TypicalP, + #[error("one of `max_new_tokens` or `truncate` must be set if a fast tokenizer is not in use")] + UnsetMaxNewTokens, #[error("`max_new_tokens` must be strictly positive")] NegativeMaxNewTokens, #[error("`max_new_tokens` must be <= {0}. Given: {1}")]