Fix input length validation (#135)

Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
This commit is contained in:
Karol Damaszke 2024-05-06 09:55:58 +02:00 committed by GitHub
parent 81182bed76
commit f82da93318
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -6,7 +6,7 @@ use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput}
use crate::{GenerateParameters, GenerateRequest, GrammarType}; use crate::{GenerateParameters, GenerateRequest, GrammarType};
use jsonschema::{Draft, JSONSchema}; use jsonschema::{Draft, JSONSchema};
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use std::env; use std::{cmp, env};
use serde_json::Value; use serde_json::Value;
use std::io::Cursor; use std::io::Cursor;
use text_generation_client::{ use text_generation_client::{
@@ -131,7 +131,10 @@ impl Validation {
let input_length = if self.skip_tokenizer_in_tgi { let input_length = if self.skip_tokenizer_in_tgi {
inputs.chars().filter(|&c| c == ',').count() + 1 inputs.chars().filter(|&c| c == ',').count() + 1
} else { } else {
cmp::max(
encoding.len(),
truncate.unwrap_or(self.max_input_length) truncate.unwrap_or(self.max_input_length)
)
}; };
// Get total tokens // Get total tokens