fix validation error

This commit is contained in:
OlivierDehaene 2023-04-09 10:06:53 +02:00
parent 82464709d3
commit 4267378b1f

View File

@@ -2,7 +2,6 @@ use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput}
 /// Payload validation logic
 use crate::{GenerateParameters, GenerateRequest};
 use rand::{thread_rng, Rng};
-use std::cmp::max;
 use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
@@ -115,7 +114,9 @@ impl Validation {
 // We make sure that truncate + max_new_tokens <= self.max_total_tokens
 // Validate MaxNewTokens
-if (truncate + max_new_tokens) > self.max_total_tokens {
+if (truncate.unwrap_or(self.max_input_length) as u32 + max_new_tokens)
+> self.max_total_tokens as u32
+{
 return Err(ValidationError::MaxNewTokens(
 self.max_total_tokens - self.max_input_length,
 max_new_tokens,