feat(launcher): Add integration tests

This commit is contained in:
OlivierDehaene 2022-12-15 19:28:55 +01:00
parent 32a253063d
commit 75ab65fb31
4 changed files with 125 additions and 8 deletions

View File

@ -15,16 +15,25 @@ jobs:
uses: actions/setup-python@v1 uses: actions/setup-python@v1
with: with:
python-version: 3.9 python-version: 3.9
- name: Install Rust
uses: actions-rs/toolchain@v1
with:
toolchain: 1.65.0
override: true
components: rustfmt, clippy
- name: Loading cache. - name: Loading cache.
uses: actions/cache@v2 uses: actions/cache@v2
id: model_cache id: model_cache
with: with:
path: ~/.cache/huggingface/ path: ~/.cache/huggingface/
key: models key: models
- name: Install server dependencies - name: Install
run: | run: |
make install-server make install
- name: Run tests - name: Run server tests
run: | run: |
pip install pytest pip install pytest
pytest -sv server/tests pytest -sv server/tests
- name: Run Rust tests
run: |
cargo test

10
Cargo.lock generated
View File

@ -1505,9 +1505,9 @@ dependencies = [
[[package]] [[package]]
name = "reqwest" name = "reqwest"
version = "0.11.12" version = "0.11.13"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "431949c384f4e2ae07605ccaa56d1d9d2ecdb5cadd4f9577ccfab29f2e5149fc" checksum = "68cc60575865c7831548863cc02356512e3f1dc2f3f82cb837d7fc4cc8f3c97c"
dependencies = [ dependencies = [
"base64", "base64",
"bytes", "bytes",
@ -1607,9 +1607,9 @@ dependencies = [
[[package]] [[package]]
name = "serde_json" name = "serde_json"
version = "1.0.87" version = "1.0.89"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ce777b7b150d76b9cf60d28b55f5847135a003f7d7350c6be7a773508ce7d45" checksum = "020ff22c755c2ed3f8cf162dbb41a7268d934702f3ed3631656ea597e08fc3db"
dependencies = [ dependencies = [
"itoa", "itoa",
"ryu", "ryu",
@ -1804,6 +1804,8 @@ version = "0.1.0"
dependencies = [ dependencies = [
"clap 4.0.22", "clap 4.0.22",
"ctrlc", "ctrlc",
"reqwest",
"serde_json",
"subprocess", "subprocess",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",

View File

@ -7,7 +7,11 @@ description = "Text Generation Launcher"
[dependencies] [dependencies]
clap = { version = "4.0.15", features = ["derive", "env"] } clap = { version = "4.0.15", features = ["derive", "env"] }
ctrlc = "3.2.3" ctrlc = { version = "3.2.3", features = ["termination"] }
subprocess = "0.2.9" subprocess = "0.2.9"
tracing = "0.1.37" tracing = "0.1.37"
tracing-subscriber = { version = "0.3.16", features = ["json"] } tracing-subscriber = { version = "0.3.16", features = ["json"] }
[dev-dependencies]
reqwest = { version = "0.11.13", features = ["blocking", "json"] }
serde_json = "1.0.89"

View File

@ -0,0 +1,102 @@
use serde_json::Value;
use std::io::{BufRead, BufReader};
use std::thread;
use std::thread::sleep;
use std::time::Duration;
use subprocess::{Popen, PopenConfig, Redirection};
fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port: usize) -> Popen {
let argv = vec![
"text-generation-launcher".to_string(),
"--model-name".to_string(),
model_name.clone(),
"--num-shard".to_string(),
num_shard.to_string(),
"--port".to_string(),
port.to_string(),
"--master-port".to_string(),
master_port.to_string(),
"--shard-uds-path".to_string(),
format!("/tmp/test-{}-{}-{}", num_shard, port, master_port),
];
let mut launcher = Popen::create(
&argv,
PopenConfig {
stdout: Redirection::Pipe,
stderr: Redirection::Pipe,
..Default::default()
},
)
.expect("Could not start launcher");
// Redirect STDOUT and STDERR to the console
let launcher_stdout = launcher.stdout.take().unwrap();
let launcher_stderr = launcher.stderr.take().unwrap();
thread::spawn(move || {
let stdout = BufReader::new(launcher_stdout);
let stderr = BufReader::new(launcher_stderr);
for line in stdout.lines() {
println!("{}", line.unwrap());
}
for line in stderr.lines() {
println!("{}", line.unwrap());
}
});
for _ in 0..30 {
let health = reqwest::blocking::get(format!("http://localhost:{}/health", port));
if health.is_ok() {
return launcher;
}
sleep(Duration::from_secs(2));
}
launcher.terminate().unwrap();
launcher.wait().unwrap();
panic!("failed to launch {}", model_name)
}
fn test_model(model_name: String, num_shard: usize, port: usize, master_port: usize) -> Value {
let mut launcher = start_launcher(model_name, num_shard, port, master_port);
let data = r#"
{
"inputs": "Test request",
"parameters": {
"details": true
}
}"#;
let req: Value = serde_json::from_str(data).unwrap();
let client = reqwest::blocking::Client::new();
let res = client
.post(format!("http://localhost:{}/generate", port))
.json(&req)
.send();
launcher.terminate().unwrap();
launcher.wait().unwrap();
let result: Value = res.unwrap().json().unwrap();
result
}
#[test]
fn test_bloom_560m() {
let result = test_model("bigscience/bloom-560m".to_string(), 1, 3000, 29500);
println!("{}", result);
}
#[test]
fn test_bloom_560m_distributed() {
let result = test_model("bigscience/bloom-560m".to_string(), 2, 3001, 29501);
println!("{}", result);
}
#[test]
fn test_mt0_base() {
let result = test_model("bigscience/mt0-base".to_string(), 1, 3002, 29502);
println!("{}", result);
}