diff --git a/router/src/lib.rs b/router/src/lib.rs index d36f0d5c..c2ff669b 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -279,17 +279,21 @@ pub(crate) struct ErrorResponse { } #[cfg(test)] -mod tests{ +mod tests { use std::io::Write; use tokenizers::Tokenizer; - pub(crate) async fn get_tokenizer() -> Tokenizer{ - if !std::path::Path::new("tokenizer.json").exists(){ - let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json").await.unwrap().bytes().await.unwrap(); - let mut file = std::fs::File::create("tokenizer.json").unwrap(); + pub(crate) async fn get_tokenizer() -> Tokenizer { + if !std::path::Path::new("tokenizer.json").exists() { + let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json") + .await + .unwrap() + .bytes() + .await + .unwrap(); + let mut file = std::fs::File::create("tokenizer.json").unwrap(); file.write_all(&content).unwrap(); } Tokenizer::from_file("tokenizer.json").unwrap() } } - diff --git a/router/src/queue.rs b/router/src/queue.rs index d3f118d8..94851e1c 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -141,7 +141,6 @@ impl State { // Get the next batch fn next_batch(&mut self, min_size: Option, token_budget: u32) -> Option { - if self.entries.is_empty() { return None; } diff --git a/router/src/server.rs b/router/src/server.rs index 3f81edfd..1fd48963 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -741,4 +741,3 @@ impl From for Event { .unwrap() } } - diff --git a/router/src/validation.rs b/router/src/validation.rs index 65ef0038..cbb0d9cd 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -440,71 +440,94 @@ mod tests { } #[tokio::test] - async fn test_validation_best_of_sampling(){ + async fn test_validation_best_of_sampling() { let tokenizer = Some(get_tokenizer().await); let max_best_of = 2; let max_stop_sequence = 3; let max_input_length = 4; let max_total_tokens = 5; let workers = 1; - let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens); - match validation.validate(GenerateRequest{ - inputs: "Hello".to_string(), - parameters: GenerateParameters{ - best_of: Some(2), - do_sample: false, - ..default_parameters() - } - }).await{ + let validation = Validation::new( + workers, + tokenizer, + max_best_of, + max_stop_sequence, + max_input_length, + max_total_tokens, + ); + match validation + .validate(GenerateRequest { + inputs: "Hello".to_string(), + parameters: GenerateParameters { + best_of: Some(2), + do_sample: false, + ..default_parameters() + }, + }) + .await + { Err(ValidationError::BestOfSampling) => (), - _ => panic!("Unexpected not best of sampling") + _ => panic!("Unexpected not best of sampling"), } - } #[tokio::test] - async fn test_validation_top_p(){ + async fn test_validation_top_p() { let tokenizer = Some(get_tokenizer().await); let max_best_of = 2; let max_stop_sequence = 3; let max_input_length = 4; let max_total_tokens = 5; let workers = 1; - let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens); - match validation.validate(GenerateRequest{ - inputs: "Hello".to_string(), - parameters: GenerateParameters{ - top_p: Some(1.0), - ..default_parameters() - } - }).await{ + let validation = Validation::new( + workers, + tokenizer, + max_best_of, + max_stop_sequence, + max_input_length, + max_total_tokens, + ); + match validation + .validate(GenerateRequest { + inputs: "Hello".to_string(), + parameters: GenerateParameters { + top_p: Some(1.0), + ..default_parameters() + }, + }) + .await + { Err(ValidationError::TopP) => (), - _ => panic!("Unexpected top_p") + _ => panic!("Unexpected top_p"), } - match validation.validate(GenerateRequest{ - inputs: "Hello".to_string(), - parameters: GenerateParameters{ - top_p: Some(0.99), - max_new_tokens: 1, - ..default_parameters() - } - }).await{ + match validation + .validate(GenerateRequest { + inputs: "Hello".to_string(), + parameters: GenerateParameters { + top_p: Some(0.99), + max_new_tokens: 1, + ..default_parameters() + }, + }) + .await + { Ok(_) => (), - _ => panic!("Unexpected top_p error") + _ => panic!("Unexpected top_p error"), } - let valid_request = validation.validate(GenerateRequest{ - inputs: "Hello".to_string(), - parameters: GenerateParameters{ - top_p: None, - max_new_tokens: 1, - ..default_parameters() - } - }).await.unwrap(); + let valid_request = validation + .validate(GenerateRequest { + inputs: "Hello".to_string(), + parameters: GenerateParameters { + top_p: None, + max_new_tokens: 1, + ..default_parameters() + }, + }) + .await + .unwrap(); // top_p == 1.0 is invalid for users to ask for but it's the default resolved value. assert_eq!(valid_request.parameters.top_p, 1.0); - - } }