diff --git a/Cargo.lock b/Cargo.lock
index e240fe9f..72c3cb2c 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -183,6 +183,21 @@ dependencies = [
  "tower-service",
 ]
 
+[[package]]
+name = "axum-test-server"
+version = "2.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "57f8f8627d32fe7e2c36b33de0e87dcdee4d6ac8619b9b892e5cc299ea4eed52"
+dependencies = [
+ "anyhow",
+ "axum",
+ "hyper",
+ "portpicker",
+ "serde",
+ "serde_json",
+ "tokio",
+]
+
 [[package]]
 name = "axum-tracing-opentelemetry"
 version = "0.10.0"
@@ -1736,6 +1751,15 @@ version = "0.3.19"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "26f6a7b87c2e435a3241addceeeff740ff8b7e76b74c13bf9acb17fa454ea00b"
 
+[[package]]
+name = "portpicker"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "be97d76faf1bfab666e1375477b23fde79eccf0276e9b63b92a39d676a889ba9"
+dependencies = [
+ "rand",
+]
+
 [[package]]
 name = "ppv-lite86"
 version = "0.2.17"
@@ -2407,6 +2431,7 @@ version = "0.6.0"
 dependencies = [
  "async-stream",
  "axum",
+ "axum-test-server",
  "axum-tracing-opentelemetry",
  "clap",
  "flume",
diff --git a/router/Cargo.toml b/router/Cargo.toml
index 4fa523a5..64aa65f3 100644
--- a/router/Cargo.toml
+++ b/router/Cargo.toml
@@ -42,3 +42,6 @@ utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }
 
 [build-dependencies]
 vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] }
+
+[dev-dependencies]
+axum-test-server = "2.0.0"
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 7a1707d9..85b13cfa 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -276,3 +276,50 @@ pub(crate) struct ErrorResponse {
     pub error: String,
     pub error_type: String,
 }
+
+#[cfg(test)]
+mod tests {
+    use std::sync::atomic::{AtomicUsize, Ordering};
+    use tokenizers::Tokenizer;
+
+    /// Path of the cached tokenizer file, relative to the directory the tests run in.
+    const TOKENIZER_PATH: &str = "tokenizer.json";
+
+    /// Gives every download made by this process its own temporary file name.
+    static DOWNLOAD_ID: AtomicUsize = AtomicUsize::new(0);
+
+    /// Loads the tokenizer shared by the router tests.
+    ///
+    /// `tokenizer.json` is downloaded from huggingface.co when it does not exist yet and is
+    /// then parsed with [`Tokenizer::from_file`].
+    ///
+    /// # Panics
+    ///
+    /// Panics if the download fails, or if the file cannot be written or parsed.
+    pub(crate) async fn get_tokenizer() -> Tokenizer {
+        if !std::path::Path::new(TOKENIZER_PATH).exists() {
+            let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json")
+                .await
+                .unwrap()
+                .bytes()
+                .await
+                .unwrap();
+            // The test harness runs tests concurrently by default, so several of them can get
+            // here at once. Download into a private file and publish it with a rename so that
+            // no test ever parses a partially written `tokenizer.json`.
+            let tmp_path = format!(
+                "{}.{}.{}.tmp",
+                TOKENIZER_PATH,
+                std::process::id(),
+                DOWNLOAD_ID.fetch_add(1, Ordering::Relaxed)
+            );
+            std::fs::write(&tmp_path, &content).unwrap();
+            if std::fs::rename(&tmp_path, TOKENIZER_PATH).is_err() {
+                // The cache could not be replaced (e.g. another test published it first and
+                // has it open); discard our copy and let `from_file` below report real failures.
+                let _ = std::fs::remove_file(&tmp_path);
+            }
+        }
+        Tokenizer::from_file(TOKENIZER_PATH).unwrap()
+    }
+}
diff --git a/router/src/server.rs b/router/src/server.rs
index 9540ba18..356e8025 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -741,3 +741,77 @@ impl From for Event {
             .unwrap()
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::default_parameters;
+    use crate::tests::get_tokenizer;
+    use axum_test_server::TestServer;
+
+    /// Serves the `/health` route through a [`TestServer`] and expects the body `pong!`.
+    #[tokio::test]
+    async fn test_health() {
+        let tokenizer = Some(get_tokenizer().await);
+        let max_best_of = 2;
+        let max_stop_sequence = 3;
+        let max_input_length = 4;
+        let max_total_tokens = 5;
+        let workers = 1;
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequence,
+            max_input_length,
+            max_total_tokens,
+        );
+        // Sanity-check the validation layer before it is handed over to `Infer`.
+        match validation
+            .validate(GenerateRequest {
+                inputs: "Hello".to_string(),
+                parameters: GenerateParameters {
+                    best_of: Some(2),
+                    do_sample: false,
+                    ..default_parameters()
+                },
+            })
+            .await
+        {
+            Err(ValidationError::BestOfSampling) => (),
+            _ => panic!("Unexpected not best of sampling"),
+        }
+
+        // NOTE(review): the connection result is unwrapped, so this test presumably needs a
+        // shard already listening on this Unix socket -- confirm how CI provides one.
+        let client = ShardedClient::connect_uds("/tmp/text-generation-test".to_string())
+            .await
+            .unwrap();
+        let waiting_served_ratio = 1.2;
+        let max_batch_total_tokens = 100;
+        let max_waiting_tokens = 10;
+        let max_concurrent_requests = 10;
+        let requires_padding = false;
+        let infer = Infer::new(
+            client,
+            validation,
+            waiting_served_ratio,
+            max_batch_total_tokens,
+            max_waiting_tokens,
+            max_concurrent_requests,
+            requires_padding,
+        );
+        let app = Router::new()
+            .route("/health", get(health))
+            .layer(Extension(infer))
+            .into_make_service();
+
+        // Run the server on a random address.
+        let server = TestServer::new(app);
+
+        // Send the request.
+        let response = server.get("/health").await;
+
+        assert_eq!(response.contents, "pong!");
+    }
+}
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 983c2612..f87d4d34 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -382,6 +382,7 @@ pub enum ValidationError {
 #[cfg(test)]
 mod tests{
     use super::*;
+    use crate::default_parameters;
     use std::io::Write;
 
     #[tokio::test]
@@ -426,4 +427,99 @@ mod tests{
             _ => panic!("Unexpected not max new tokens")
         }
     }
+
+    /// Expects `ValidationError::BestOfSampling` for `best_of = 2` with sampling disabled.
+    #[tokio::test]
+    async fn test_validation_best_of_sampling() {
+        let tokenizer = Some(get_tokenizer().await);
+        let max_best_of = 2;
+        let max_stop_sequence = 3;
+        let max_input_length = 4;
+        let max_total_tokens = 5;
+        let workers = 1;
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequence,
+            max_input_length,
+            max_total_tokens,
+        );
+        match validation
+            .validate(GenerateRequest {
+                inputs: "Hello".to_string(),
+                parameters: GenerateParameters {
+                    best_of: Some(2),
+                    do_sample: false,
+                    ..default_parameters()
+                },
+            })
+            .await
+        {
+            Err(ValidationError::BestOfSampling) => (),
+            _ => panic!("Unexpected not best of sampling"),
+        }
+    }
+
+    /// Expects `top_p = 1.0` to be rejected, `0.99` to be accepted, and an omitted `top_p` to
+    /// resolve to `1.0`.
+    #[tokio::test]
+    async fn test_validation_top_p() {
+        let tokenizer = Some(get_tokenizer().await);
+        let max_best_of = 2;
+        let max_stop_sequence = 3;
+        let max_input_length = 4;
+        let max_total_tokens = 5;
+        let workers = 1;
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequence,
+            max_input_length,
+            max_total_tokens,
+        );
+        match validation
+            .validate(GenerateRequest {
+                inputs: "Hello".to_string(),
+                parameters: GenerateParameters {
+                    top_p: Some(1.0),
+                    ..default_parameters()
+                },
+            })
+            .await
+        {
+            Err(ValidationError::TopP) => (),
+            _ => panic!("Unexpected top_p"),
+        }
+
+        match validation
+            .validate(GenerateRequest {
+                inputs: "Hello".to_string(),
+                parameters: GenerateParameters {
+                    top_p: Some(0.99),
+                    max_new_tokens: 1,
+                    ..default_parameters()
+                },
+            })
+            .await
+        {
+            Ok(_) => (),
+            _ => panic!("Unexpected top_p error"),
+        }
+
+        let valid_request = validation
+            .validate(GenerateRequest {
+                inputs: "Hello".to_string(),
+                parameters: GenerateParameters {
+                    top_p: None,
+                    max_new_tokens: 1,
+                    ..default_parameters()
+                },
+            })
+            .await
+            .unwrap();
+        // top_p == 1.0 is invalid for users to ask for but it's the default resolved value.
+        assert_eq!(valid_request.parameters.top_p, 1.0);
+    }
 }