diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index e82e8b20..a2c2b7fb 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -67,6 +67,9 @@ jobs: run: | pip install pytest HF_HUB_ENABLE_HF_TRANSFER=1 pytest -sv server/tests + - name: Run Clippy + run: | + cargo clippy - name: Run Rust tests run: | cargo test diff --git a/Cargo.lock b/Cargo.lock index 72c3cb2c..e240fe9f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -183,21 +183,6 @@ dependencies = [ "tower-service", ] -[[package]] -name = "axum-test-server" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57f8f8627d32fe7e2c36b33de0e87dcdee4d6ac8619b9b892e5cc299ea4eed52" -dependencies = [ - "anyhow", - "axum", - "hyper", - "portpicker", - "serde", - "serde_json", - "tokio", -] - [[package]] name = "axum-tracing-opentelemetry" version = "0.10.0" @@ -1751,15 +1736,6 @@ version = "0.3.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26f6a7b87c2e435a3241addceeeff740ff8b7e76b74c13bf9acb17fa454ea00b" -[[package]] -name = "portpicker" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be97d76faf1bfab666e1375477b23fde79eccf0276e9b63b92a39d676a889ba9" -dependencies = [ - "rand", -] - [[package]] name = "ppv-lite86" version = "0.2.17" @@ -2431,7 +2407,6 @@ version = "0.6.0" dependencies = [ "async-stream", "axum", - "axum-test-server", "axum-tracing-opentelemetry", "clap", "flume", diff --git a/router/Cargo.toml b/router/Cargo.toml index 64aa65f3..4fa523a5 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -42,6 +42,3 @@ utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] } [build-dependencies] vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] } - -[dev-dependencies] -axum-test-server = "2.0.0" diff --git a/router/src/queue.rs b/router/src/queue.rs index d970ebf1..aee84129 100644 --- a/router/src/queue.rs +++ 
b/router/src/queue.rs @@ -430,7 +430,17 @@ mod tests { let (entry3, _guard3) = default_entry(); queue.append(entry3); + // Not enough requests pending assert!(queue.next_batch(Some(2), 2).await.is_none()); + // Not enough token budget + assert!(queue.next_batch(Some(1), 0).await.is_none()); + // Ok + let (entries2, batch2, _) = queue.next_batch(Some(1), 2).await.unwrap(); + assert_eq!(entries2.len(), 1); + assert!(entries2.contains_key(&2)); + assert!(entries2.get(&2).unwrap().batch_time.is_some()); + assert_eq!(batch2.id, 1); + assert_eq!(batch2.size, 1); } #[tokio::test] diff --git a/router/src/server.rs index 356e8025..09b5c3ba 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -742,62 +742,3 @@ impl From for Event { } } -#[cfg(test)] -mod tests{ - use super::*; - use crate::tests::get_tokenizer; - use axum_test_server::TestServer; - use crate::default_parameters; - - #[tokio::test] - async fn test_health(){ - let tokenizer = Some(get_tokenizer().await); - let max_best_of = 2; - let max_stop_sequence = 3; - let max_input_length = 4; - let max_total_tokens = 5; - let workers = 1; - let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens); - match validation.validate(GenerateRequest{ - inputs: "Hello".to_string(), - parameters: GenerateParameters{ - best_of: Some(2), - do_sample: false, - ..default_parameters() - } - }).await{ - Err(ValidationError::BestOfSampling) => (), - _ => panic!("Unexpected not best of sampling") - } - - let client = ShardedClient::connect_uds("/tmp/text-generation-test".to_string()).await.unwrap(); - let waiting_served_ratio = 1.2; - let max_batch_total_tokens = 100; - let 
max_waiting_tokens = 10; - let max_concurrent_requests = 10; - let requires_padding = false; - let infer = Infer::new( - client, - validation, - waiting_served_ratio, - max_batch_total_tokens, - max_waiting_tokens, - max_concurrent_requests, - requires_padding, - ); - let app = Router::new() - .route("/health", get(health)) - .layer(Extension(infer)) - .into_make_service(); - - // Run the server on a random address. - let server = TestServer::new(app); - - // Get the request. - let response = server - .get("/health") - .await; - - assert_eq!(response.contents, "pong!"); - } -} diff --git a/router/src/validation.rs b/router/src/validation.rs index f87d4d34..ff2fe89d 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -383,7 +383,7 @@ pub enum ValidationError { mod tests{ use super::*; use crate::default_parameters; - use std::io::Write; + use crate::tests::get_tokenizer; #[tokio::test] async fn test_validation_max_new_tokens(){ @@ -402,15 +402,6 @@ mod tests{ } } - async fn get_tokenizer() -> Tokenizer{ - if !std::path::Path::new("tokenizer.json").exists(){ - let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json").await.unwrap().bytes().await.unwrap(); - let mut file = std::fs::File::create("tokenizer.json").unwrap(); - file.write_all(&content).unwrap(); - } - Tokenizer::from_file("tokenizer.json").unwrap() - } - #[tokio::test] async fn test_validation_input_length(){ let tokenizer = Some(get_tokenizer().await);