This commit is contained in:
OlivierDehaene 2023-04-26 19:42:49 +02:00
parent 20e0117e7c
commit 6d8d5b6d1d
4 changed files with 74 additions and 49 deletions

View File

@ -279,17 +279,21 @@ pub(crate) struct ErrorResponse {
} }
#[cfg(test)] #[cfg(test)]
mod tests{ mod tests {
use std::io::Write; use std::io::Write;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
pub(crate) async fn get_tokenizer() -> Tokenizer{ pub(crate) async fn get_tokenizer() -> Tokenizer {
if !std::path::Path::new("tokenizer.json").exists(){ if !std::path::Path::new("tokenizer.json").exists() {
let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json").await.unwrap().bytes().await.unwrap(); let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json")
let mut file = std::fs::File::create("tokenizer.json").unwrap(); .await
.unwrap()
.bytes()
.await
.unwrap();
let mut file = std::fs::File::create("tokenizer.json").unwrap();
file.write_all(&content).unwrap(); file.write_all(&content).unwrap();
} }
Tokenizer::from_file("tokenizer.json").unwrap() Tokenizer::from_file("tokenizer.json").unwrap()
} }
} }

View File

@ -141,7 +141,6 @@ impl State {
// Get the next batch // Get the next batch
fn next_batch(&mut self, min_size: Option<usize>, token_budget: u32) -> Option<NextBatch> { fn next_batch(&mut self, min_size: Option<usize>, token_budget: u32) -> Option<NextBatch> {
if self.entries.is_empty() { if self.entries.is_empty() {
return None; return None;
} }

View File

@ -741,4 +741,3 @@ impl From<InferError> for Event {
.unwrap() .unwrap()
} }
} }

View File

@ -440,71 +440,94 @@ mod tests {
} }
#[tokio::test] #[tokio::test]
async fn test_validation_best_of_sampling(){ async fn test_validation_best_of_sampling() {
let tokenizer = Some(get_tokenizer().await); let tokenizer = Some(get_tokenizer().await);
let max_best_of = 2; let max_best_of = 2;
let max_stop_sequence = 3; let max_stop_sequence = 3;
let max_input_length = 4; let max_input_length = 4;
let max_total_tokens = 5; let max_total_tokens = 5;
let workers = 1; let workers = 1;
let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens); let validation = Validation::new(
match validation.validate(GenerateRequest{ workers,
inputs: "Hello".to_string(), tokenizer,
parameters: GenerateParameters{ max_best_of,
best_of: Some(2), max_stop_sequence,
do_sample: false, max_input_length,
..default_parameters() max_total_tokens,
} );
}).await{ match validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
parameters: GenerateParameters {
best_of: Some(2),
do_sample: false,
..default_parameters()
},
})
.await
{
Err(ValidationError::BestOfSampling) => (), Err(ValidationError::BestOfSampling) => (),
_ => panic!("Unexpected not best of sampling") _ => panic!("Unexpected not best of sampling"),
} }
} }
#[tokio::test] #[tokio::test]
async fn test_validation_top_p(){ async fn test_validation_top_p() {
let tokenizer = Some(get_tokenizer().await); let tokenizer = Some(get_tokenizer().await);
let max_best_of = 2; let max_best_of = 2;
let max_stop_sequence = 3; let max_stop_sequence = 3;
let max_input_length = 4; let max_input_length = 4;
let max_total_tokens = 5; let max_total_tokens = 5;
let workers = 1; let workers = 1;
let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens); let validation = Validation::new(
match validation.validate(GenerateRequest{ workers,
inputs: "Hello".to_string(), tokenizer,
parameters: GenerateParameters{ max_best_of,
top_p: Some(1.0), max_stop_sequence,
..default_parameters() max_input_length,
} max_total_tokens,
}).await{ );
match validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
parameters: GenerateParameters {
top_p: Some(1.0),
..default_parameters()
},
})
.await
{
Err(ValidationError::TopP) => (), Err(ValidationError::TopP) => (),
_ => panic!("Unexpected top_p") _ => panic!("Unexpected top_p"),
} }
match validation.validate(GenerateRequest{ match validation
inputs: "Hello".to_string(), .validate(GenerateRequest {
parameters: GenerateParameters{ inputs: "Hello".to_string(),
top_p: Some(0.99), parameters: GenerateParameters {
max_new_tokens: 1, top_p: Some(0.99),
..default_parameters() max_new_tokens: 1,
} ..default_parameters()
}).await{ },
})
.await
{
Ok(_) => (), Ok(_) => (),
_ => panic!("Unexpected top_p error") _ => panic!("Unexpected top_p error"),
} }
let valid_request = validation.validate(GenerateRequest{ let valid_request = validation
inputs: "Hello".to_string(), .validate(GenerateRequest {
parameters: GenerateParameters{ inputs: "Hello".to_string(),
top_p: None, parameters: GenerateParameters {
max_new_tokens: 1, top_p: None,
..default_parameters() max_new_tokens: 1,
} ..default_parameters()
}).await.unwrap(); },
})
.await
.unwrap();
// top_p == 1.0 is invalid for users to ask for but it's the default resolved value. // top_p == 1.0 is invalid for users to ask for but it's the default resolved value.
assert_eq!(valid_request.parameters.top_p, 1.0); assert_eq!(valid_request.parameters.top_p, 1.0);
} }
} }