diff --git a/router/src/validation.rs b/router/src/validation.rs index 268c9c7a..4a9d0c23 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -6,7 +6,7 @@ use thiserror::Error; use tokenizers::tokenizer::Tokenizer; use tokio::sync::{mpsc, oneshot}; -const MAX_MAX_NEW_TOKENS: usize = 512; +const MAX_MAX_NEW_TOKENS: u32 = 512; const MAX_STOP_SEQUENCES: usize = 4; /// Validation @@ -112,7 +112,7 @@ fn validate( if request.parameters.top_k < 0 { return Err(ValidationError::TopK); } - if request.parameters.max_new_tokens as usize > MAX_MAX_NEW_TOKENS { + if request.parameters.max_new_tokens > MAX_MAX_NEW_TOKENS { return Err(ValidationError::MaxNewTokens(MAX_MAX_NEW_TOKENS)); } if request.parameters.stop.len() > MAX_STOP_SEQUENCES { @@ -153,7 +153,7 @@ pub enum ValidationError { #[error("top_k must be strictly positive")] TopK, #[error("max_new_tokens must be <= {0}")] - MaxNewTokens(usize), + MaxNewTokens(u32), #[error("inputs must have less than {1} tokens. Given: {0}")] InputLength(usize, usize), #[error("stop supports up to {0} stop sequences. Given: {1}")]