diff --git a/.github/workflows/server-tests.yaml b/.github/workflows/tests.yaml similarity index 63% rename from .github/workflows/server-tests.yaml rename to .github/workflows/tests.yaml index 5bb4653a..a20a03f5 100644 --- a/.github/workflows/server-tests.yaml +++ b/.github/workflows/tests.yaml @@ -15,16 +15,25 @@ jobs: uses: actions/setup-python@v1 with: python-version: 3.9 + - name: Install Rust + uses: actions-rs/toolchain@v1 + with: + toolchain: 1.65.0 + override: true + components: rustfmt, clippy - name: Loading cache. uses: actions/cache@v2 id: model_cache with: path: ~/.cache/huggingface/ key: models - - name: Install server dependencies + - name: Install run: | - make install-server - - name: Run tests + make install + - name: Run server tests run: | pip install pytest pytest -sv server/tests + - name: Run Rust tests + run: | + cargo test diff --git a/Cargo.lock b/Cargo.lock index 0fd5c4bf..ec013c02 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1505,9 +1505,9 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.11.12" +version = "0.11.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "431949c384f4e2ae07605ccaa56d1d9d2ecdb5cadd4f9577ccfab29f2e5149fc" +checksum = "68cc60575865c7831548863cc02356512e3f1dc2f3f82cb837d7fc4cc8f3c97c" dependencies = [ "base64", "bytes", @@ -1607,9 +1607,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.87" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ce777b7b150d76b9cf60d28b55f5847135a003f7d7350c6be7a773508ce7d45" +checksum = "020ff22c755c2ed3f8cf162dbb41a7268d934702f3ed3631656ea597e08fc3db" dependencies = [ "itoa", "ryu", @@ -1804,6 +1804,8 @@ version = "0.1.0" dependencies = [ "clap 4.0.22", "ctrlc", + "reqwest", + "serde_json", "subprocess", "tracing", "tracing-subscriber", diff --git a/launcher/Cargo.toml b/launcher/Cargo.toml index 1779c051..cf86aa00 100644 --- a/launcher/Cargo.toml +++ b/launcher/Cargo.toml @@ -7,7 +7,11 @@ description = "Text Generation Launcher" [dependencies] clap = { version = "4.0.15", features = ["derive", "env"] } -ctrlc = "3.2.3" +ctrlc = { version = "3.2.3", features = ["termination"] } subprocess = "0.2.9" tracing = "0.1.37" tracing-subscriber = { version = "0.3.16", features = ["json"] } + +[dev-dependencies] +reqwest = { version = "0.11.13", features = ["blocking", "json"] } +serde_json = "1.0.89" diff --git a/launcher/tests/integration_tests.rs b/launcher/tests/integration_tests.rs new file mode 100644 index 00000000..c0758491 --- /dev/null +++ b/launcher/tests/integration_tests.rs @@ -0,0 +1,102 @@ +use serde_json::Value; +use std::io::{BufRead, BufReader}; +use std::thread; +use std::thread::sleep; +use std::time::Duration; +use subprocess::{Popen, PopenConfig, Redirection}; + +fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port: usize) -> Popen { + let argv = vec![ + "text-generation-launcher".to_string(), + "--model-name".to_string(), + model_name.clone(), + "--num-shard".to_string(), + num_shard.to_string(), + "--port".to_string(), + port.to_string(), + "--master-port".to_string(), + master_port.to_string(), + "--shard-uds-path".to_string(), + format!("/tmp/test-{}-{}-{}", num_shard, port, master_port), + ]; + + let mut launcher = Popen::create( + &argv, + PopenConfig { + stdout: Redirection::Pipe, + stderr: Redirection::Pipe, + ..Default::default() + }, + ) + .expect("Could not start launcher"); + + // Redirect STDOUT and STDERR to the console + let launcher_stdout = launcher.stdout.take().unwrap(); + let launcher_stderr = launcher.stderr.take().unwrap(); + + thread::spawn(move || { + let stdout = BufReader::new(launcher_stdout); + let stderr = BufReader::new(launcher_stderr); + for line in stdout.lines() { + println!("{}", line.unwrap()); + } + for line in stderr.lines() { + println!("{}", line.unwrap()); + } + }); + + for _ in 0..30 { + let health = reqwest::blocking::get(format!("http://localhost:{}/health", port)); + if health.is_ok() { + return launcher; + } + sleep(Duration::from_secs(2)); + } + + launcher.terminate().unwrap(); + launcher.wait().unwrap(); + panic!("failed to launch {}", model_name) +} + +fn test_model(model_name: String, num_shard: usize, port: usize, master_port: usize) -> Value { + let mut launcher = start_launcher(model_name, num_shard, port, master_port); + + let data = r#" + { + "inputs": "Test request", + "parameters": { + "details": true + } + }"#; + let req: Value = serde_json::from_str(data).unwrap(); + + let client = reqwest::blocking::Client::new(); + let res = client + .post(format!("http://localhost:{}/generate", port)) + .json(&req) + .send(); + + launcher.terminate().unwrap(); + launcher.wait().unwrap(); + + let result: Value = res.unwrap().json().unwrap(); + result +} + +#[test] +fn test_bloom_560m() { + let result = test_model("bigscience/bloom-560m".to_string(), 1, 3000, 29500); + println!("{}", result); +} + +#[test] +fn test_bloom_560m_distributed() { + let result = test_model("bigscience/bloom-560m".to_string(), 2, 3001, 29501); + println!("{}", result); +} + +#[test] +fn test_mt0_base() { + let result = test_model("bigscience/mt0-base".to_string(), 1, 3002, 29502); + println!("{}", result); +}