mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
fix validation error
This commit is contained in:
parent
82464709d3
commit
4267378b1f
@ -2,7 +2,6 @@ use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput}
|
|||||||
/// Payload validation logic
|
/// Payload validation logic
|
||||||
use crate::{GenerateParameters, GenerateRequest};
|
use crate::{GenerateParameters, GenerateRequest};
|
||||||
use rand::{thread_rng, Rng};
|
use rand::{thread_rng, Rng};
|
||||||
use std::cmp::max;
|
|
||||||
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
|
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokenizers::tokenizer::Tokenizer;
|
use tokenizers::tokenizer::Tokenizer;
|
||||||
@ -115,7 +114,9 @@ impl Validation {
|
|||||||
// We make sure that truncate + max_new_tokens <= self.max_total_tokens
|
// We make sure that truncate + max_new_tokens <= self.max_total_tokens
|
||||||
|
|
||||||
// Validate MaxNewTokens
|
// Validate MaxNewTokens
|
||||||
if (truncate + max_new_tokens) > self.max_total_tokens {
|
if (truncate.unwrap_or(self.max_input_length) as u32 + max_new_tokens)
|
||||||
|
> self.max_total_tokens as u32
|
||||||
|
{
|
||||||
return Err(ValidationError::MaxNewTokens(
|
return Err(ValidationError::MaxNewTokens(
|
||||||
self.max_total_tokens - self.max_input_length,
|
self.max_total_tokens - self.max_input_length,
|
||||||
max_new_tokens,
|
max_new_tokens,
|
||||||
|
Loading…
Reference in New Issue
Block a user