mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
feat(launcher): Add integration tests
This commit is contained in:
parent
32a253063d
commit
75ab65fb31
@ -15,16 +15,25 @@ jobs:
|
|||||||
uses: actions/setup-python@v1
|
uses: actions/setup-python@v1
|
||||||
with:
|
with:
|
||||||
python-version: 3.9
|
python-version: 3.9
|
||||||
|
- name: Install Rust
|
||||||
|
uses: actions-rs/toolchain@v1
|
||||||
|
with:
|
||||||
|
toolchain: 1.65.0
|
||||||
|
override: true
|
||||||
|
components: rustfmt, clippy
|
||||||
- name: Loading cache.
|
- name: Loading cache.
|
||||||
uses: actions/cache@v2
|
uses: actions/cache@v2
|
||||||
id: model_cache
|
id: model_cache
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/huggingface/
|
path: ~/.cache/huggingface/
|
||||||
key: models
|
key: models
|
||||||
- name: Install server dependencies
|
- name: Install
|
||||||
run: |
|
run: |
|
||||||
make install-server
|
make install
|
||||||
- name: Run tests
|
- name: Run server tests
|
||||||
run: |
|
run: |
|
||||||
pip install pytest
|
pip install pytest
|
||||||
pytest -sv server/tests
|
pytest -sv server/tests
|
||||||
|
- name: Run Rust tests
|
||||||
|
run: |
|
||||||
|
cargo test
|
10
Cargo.lock
generated
10
Cargo.lock
generated
@ -1505,9 +1505,9 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "reqwest"
|
name = "reqwest"
|
||||||
version = "0.11.12"
|
version = "0.11.13"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "431949c384f4e2ae07605ccaa56d1d9d2ecdb5cadd4f9577ccfab29f2e5149fc"
|
checksum = "68cc60575865c7831548863cc02356512e3f1dc2f3f82cb837d7fc4cc8f3c97c"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"base64",
|
"base64",
|
||||||
"bytes",
|
"bytes",
|
||||||
@ -1607,9 +1607,9 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde_json"
|
name = "serde_json"
|
||||||
version = "1.0.87"
|
version = "1.0.89"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "6ce777b7b150d76b9cf60d28b55f5847135a003f7d7350c6be7a773508ce7d45"
|
checksum = "020ff22c755c2ed3f8cf162dbb41a7268d934702f3ed3631656ea597e08fc3db"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"itoa",
|
"itoa",
|
||||||
"ryu",
|
"ryu",
|
||||||
@ -1804,6 +1804,8 @@ version = "0.1.0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"clap 4.0.22",
|
"clap 4.0.22",
|
||||||
"ctrlc",
|
"ctrlc",
|
||||||
|
"reqwest",
|
||||||
|
"serde_json",
|
||||||
"subprocess",
|
"subprocess",
|
||||||
"tracing",
|
"tracing",
|
||||||
"tracing-subscriber",
|
"tracing-subscriber",
|
||||||
|
@ -7,7 +7,11 @@ description = "Text Generation Launcher"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
clap = { version = "4.0.15", features = ["derive", "env"] }
|
clap = { version = "4.0.15", features = ["derive", "env"] }
|
||||||
ctrlc = "3.2.3"
|
ctrlc = { version = "3.2.3", features = ["termination"] }
|
||||||
subprocess = "0.2.9"
|
subprocess = "0.2.9"
|
||||||
tracing = "0.1.37"
|
tracing = "0.1.37"
|
||||||
tracing-subscriber = { version = "0.3.16", features = ["json"] }
|
tracing-subscriber = { version = "0.3.16", features = ["json"] }
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
reqwest = { version = "0.11.13", features = ["blocking", "json"] }
|
||||||
|
serde_json = "1.0.89"
|
||||||
|
102
launcher/tests/integration_tests.rs
Normal file
102
launcher/tests/integration_tests.rs
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
use serde_json::Value;
|
||||||
|
use std::io::{BufRead, BufReader};
|
||||||
|
use std::thread;
|
||||||
|
use std::thread::sleep;
|
||||||
|
use std::time::Duration;
|
||||||
|
use subprocess::{Popen, PopenConfig, Redirection};
|
||||||
|
|
||||||
|
fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port: usize) -> Popen {
|
||||||
|
let argv = vec![
|
||||||
|
"text-generation-launcher".to_string(),
|
||||||
|
"--model-name".to_string(),
|
||||||
|
model_name.clone(),
|
||||||
|
"--num-shard".to_string(),
|
||||||
|
num_shard.to_string(),
|
||||||
|
"--port".to_string(),
|
||||||
|
port.to_string(),
|
||||||
|
"--master-port".to_string(),
|
||||||
|
master_port.to_string(),
|
||||||
|
"--shard-uds-path".to_string(),
|
||||||
|
format!("/tmp/test-{}-{}-{}", num_shard, port, master_port),
|
||||||
|
];
|
||||||
|
|
||||||
|
let mut launcher = Popen::create(
|
||||||
|
&argv,
|
||||||
|
PopenConfig {
|
||||||
|
stdout: Redirection::Pipe,
|
||||||
|
stderr: Redirection::Pipe,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.expect("Could not start launcher");
|
||||||
|
|
||||||
|
// Redirect STDOUT and STDERR to the console
|
||||||
|
let launcher_stdout = launcher.stdout.take().unwrap();
|
||||||
|
let launcher_stderr = launcher.stderr.take().unwrap();
|
||||||
|
|
||||||
|
thread::spawn(move || {
|
||||||
|
let stdout = BufReader::new(launcher_stdout);
|
||||||
|
let stderr = BufReader::new(launcher_stderr);
|
||||||
|
for line in stdout.lines() {
|
||||||
|
println!("{}", line.unwrap());
|
||||||
|
}
|
||||||
|
for line in stderr.lines() {
|
||||||
|
println!("{}", line.unwrap());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
for _ in 0..30 {
|
||||||
|
let health = reqwest::blocking::get(format!("http://localhost:{}/health", port));
|
||||||
|
if health.is_ok() {
|
||||||
|
return launcher;
|
||||||
|
}
|
||||||
|
sleep(Duration::from_secs(2));
|
||||||
|
}
|
||||||
|
|
||||||
|
launcher.terminate().unwrap();
|
||||||
|
launcher.wait().unwrap();
|
||||||
|
panic!("failed to launch {}", model_name)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn test_model(model_name: String, num_shard: usize, port: usize, master_port: usize) -> Value {
|
||||||
|
let mut launcher = start_launcher(model_name, num_shard, port, master_port);
|
||||||
|
|
||||||
|
let data = r#"
|
||||||
|
{
|
||||||
|
"inputs": "Test request",
|
||||||
|
"parameters": {
|
||||||
|
"details": true
|
||||||
|
}
|
||||||
|
}"#;
|
||||||
|
let req: Value = serde_json::from_str(data).unwrap();
|
||||||
|
|
||||||
|
let client = reqwest::blocking::Client::new();
|
||||||
|
let res = client
|
||||||
|
.post(format!("http://localhost:{}/generate", port))
|
||||||
|
.json(&req)
|
||||||
|
.send();
|
||||||
|
|
||||||
|
launcher.terminate().unwrap();
|
||||||
|
launcher.wait().unwrap();
|
||||||
|
|
||||||
|
let result: Value = res.unwrap().json().unwrap();
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Runs one generation request against `bigscience/bloom-560m` with a single
/// shard and prints the JSON response.
#[test]
fn test_bloom_560m() {
    let model = String::from("bigscience/bloom-560m");
    let output = test_model(model, 1, 3000, 29500);
    println!("{}", output);
}
|
||||||
|
|
||||||
|
/// Runs one generation request against `bigscience/bloom-560m` with two
/// shards and prints the JSON response.
#[test]
fn test_bloom_560m_distributed() {
    let model = String::from("bigscience/bloom-560m");
    let output = test_model(model, 2, 3001, 29501);
    println!("{}", output);
}
|
||||||
|
|
||||||
|
/// Runs one generation request against `bigscience/mt0-base` with a single
/// shard and prints the JSON response.
#[test]
fn test_mt0_base() {
    let model = String::from("bigscience/mt0-base");
    let output = test_model(model, 1, 3002, 29502);
    println!("{}", output);
}
|
Loading…
Reference in New Issue
Block a user