mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
add expected values
This commit is contained in:
parent
7ab7c9a01f
commit
9978147004
20
Cargo.lock
generated
20
Cargo.lock
generated
@ -543,6 +543,12 @@ dependencies = [
|
|||||||
"miniz_oxide",
|
"miniz_oxide",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "float_eq"
|
||||||
|
version = "1.0.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "28a80e3145d8ad11ba0995949bbcf48b9df2be62772b3d351ef017dff6ecb853"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fnv"
|
name = "fnv"
|
||||||
version = "1.0.7"
|
version = "1.0.7"
|
||||||
@ -1587,18 +1593,18 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde"
|
name = "serde"
|
||||||
version = "1.0.147"
|
version = "1.0.150"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "d193d69bae983fc11a79df82342761dfbf28a99fc8d203dca4c3c1b590948965"
|
checksum = "e326c9ec8042f1b5da33252c8a37e9ffbd2c9bef0155215b6e6c80c790e05f91"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"serde_derive",
|
"serde_derive",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde_derive"
|
name = "serde_derive"
|
||||||
version = "1.0.147"
|
version = "1.0.150"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "4f1d362ca8fc9c3e3a7484440752472d68a6caa98f1ab81d99b5dfe517cec852"
|
checksum = "42a3df25b0713732468deadad63ab9da1f1fd75a48a15024b50363f128db627e"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
@ -1724,9 +1730,9 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "syn"
|
name = "syn"
|
||||||
version = "1.0.103"
|
version = "1.0.105"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "a864042229133ada95abf3b54fdc62ef5ccabe9515b64717bcb9a1919e59445d"
|
checksum = "60b9b43d45702de4c839cb9b51d9f529c5dd26a4aff255b42b1ebc03e88ee908"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
@ -1804,7 +1810,9 @@ version = "0.1.0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"clap 4.0.22",
|
"clap 4.0.22",
|
||||||
"ctrlc",
|
"ctrlc",
|
||||||
|
"float_eq",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"subprocess",
|
"subprocess",
|
||||||
"tracing",
|
"tracing",
|
||||||
|
@ -13,5 +13,7 @@ tracing = "0.1.37"
|
|||||||
tracing-subscriber = { version = "0.3.16", features = ["json"] }
|
tracing-subscriber = { version = "0.3.16", features = ["json"] }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
float_eq = "1.0.1"
|
||||||
reqwest = { version = "0.11.13", features = ["blocking", "json"] }
|
reqwest = { version = "0.11.13", features = ["blocking", "json"] }
|
||||||
|
serde = "1.0.150"
|
||||||
serde_json = "1.0.89"
|
serde_json = "1.0.89"
|
||||||
|
121
launcher/tests/bloom_560m.json
Normal file
121
launcher/tests/bloom_560m.json
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 20,
|
||||||
|
"tokens": [
|
||||||
|
[
|
||||||
|
10264,
|
||||||
|
"Test",
|
||||||
|
null
|
||||||
|
],
|
||||||
|
[
|
||||||
|
8821,
|
||||||
|
" request",
|
||||||
|
-11.895094
|
||||||
|
],
|
||||||
|
[
|
||||||
|
17,
|
||||||
|
".",
|
||||||
|
-1.8267941
|
||||||
|
],
|
||||||
|
[
|
||||||
|
1587,
|
||||||
|
"get",
|
||||||
|
-2.4674964
|
||||||
|
],
|
||||||
|
[
|
||||||
|
11,
|
||||||
|
"(",
|
||||||
|
-1.9060438
|
||||||
|
],
|
||||||
|
[
|
||||||
|
5,
|
||||||
|
"\"",
|
||||||
|
-1.2279553
|
||||||
|
],
|
||||||
|
[
|
||||||
|
4899,
|
||||||
|
"action",
|
||||||
|
-4.170306
|
||||||
|
],
|
||||||
|
[
|
||||||
|
5,
|
||||||
|
"\"",
|
||||||
|
-0.3247902
|
||||||
|
],
|
||||||
|
[
|
||||||
|
12,
|
||||||
|
")",
|
||||||
|
-1.0773602
|
||||||
|
],
|
||||||
|
[
|
||||||
|
30,
|
||||||
|
";",
|
||||||
|
-0.27640444
|
||||||
|
],
|
||||||
|
[
|
||||||
|
837,
|
||||||
|
"\n ",
|
||||||
|
-1.6970599
|
||||||
|
],
|
||||||
|
[
|
||||||
|
1320,
|
||||||
|
" if",
|
||||||
|
-1.4495552
|
||||||
|
],
|
||||||
|
[
|
||||||
|
375,
|
||||||
|
" (",
|
||||||
|
-0.2360998
|
||||||
|
],
|
||||||
|
[
|
||||||
|
4899,
|
||||||
|
"action",
|
||||||
|
-1.1916926
|
||||||
|
],
|
||||||
|
[
|
||||||
|
3535,
|
||||||
|
" ==",
|
||||||
|
-0.8918663
|
||||||
|
],
|
||||||
|
[
|
||||||
|
5109,
|
||||||
|
" null",
|
||||||
|
-0.39334255
|
||||||
|
],
|
||||||
|
[
|
||||||
|
12,
|
||||||
|
")",
|
||||||
|
-0.4321134
|
||||||
|
],
|
||||||
|
[
|
||||||
|
731,
|
||||||
|
" {",
|
||||||
|
-0.17701954
|
||||||
|
],
|
||||||
|
[
|
||||||
|
1260,
|
||||||
|
"\n ",
|
||||||
|
-0.07027287
|
||||||
|
],
|
||||||
|
[
|
||||||
|
10519,
|
||||||
|
" throw",
|
||||||
|
-1.3915133
|
||||||
|
],
|
||||||
|
[
|
||||||
|
2084,
|
||||||
|
" new",
|
||||||
|
-0.042013377
|
||||||
|
],
|
||||||
|
[
|
||||||
|
150858,
|
||||||
|
" RuntimeException",
|
||||||
|
-1.7330077
|
||||||
|
]
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"generated_text": "Test request.get(\"action\");\n if (action == null) {\n throw new RuntimeException"
|
||||||
|
}
|
||||||
|
]
|
@ -1,9 +1,27 @@
|
|||||||
|
use std::fs::File;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::io::{BufRead, BufReader};
|
use std::io::{BufRead, BufReader};
|
||||||
|
use std::path::PathBuf;
|
||||||
use std::thread;
|
use std::thread;
|
||||||
use std::thread::sleep;
|
use std::thread::sleep;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
use float_eq::assert_float_eq;
|
||||||
use subprocess::{Popen, PopenConfig, Redirection};
|
use subprocess::{Popen, PopenConfig, Redirection};
|
||||||
|
use serde::Deserialize;
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct Details {
|
||||||
|
finish_reason: String,
|
||||||
|
generated_tokens: u32,
|
||||||
|
tokens: Vec<(u32, String, Option<f32>)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct GeneratedText {
|
||||||
|
generated_text: String,
|
||||||
|
details: Details,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port: usize) -> Popen {
|
fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port: usize) -> Popen {
|
||||||
let argv = vec![
|
let argv = vec![
|
||||||
@ -58,7 +76,7 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port
|
|||||||
panic!("failed to launch {}", model_name)
|
panic!("failed to launch {}", model_name)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn test_model(model_name: String, num_shard: usize, port: usize, master_port: usize) -> Value {
|
fn test_model(model_name: String, num_shard: usize, port: usize, master_port: usize) -> GeneratedText {
|
||||||
let mut launcher = start_launcher(model_name, num_shard, port, master_port);
|
let mut launcher = start_launcher(model_name, num_shard, port, master_port);
|
||||||
|
|
||||||
let data = r#"
|
let data = r#"
|
||||||
@ -79,24 +97,60 @@ fn test_model(model_name: String, num_shard: usize, port: usize, master_port: us
|
|||||||
launcher.terminate().unwrap();
|
launcher.terminate().unwrap();
|
||||||
launcher.wait().unwrap();
|
launcher.wait().unwrap();
|
||||||
|
|
||||||
let result: Value = res.unwrap().json().unwrap();
|
let mut results: Vec<GeneratedText> = res.unwrap().json().unwrap();
|
||||||
result
|
results.pop().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
fn read_json(name: &str) -> GeneratedText {
|
||||||
|
let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||||
|
d.push("tests/");
|
||||||
|
d.push(name);
|
||||||
|
|
||||||
|
let file = File::open(d).unwrap();
|
||||||
|
let reader = BufReader::new(file);
|
||||||
|
|
||||||
|
let mut results: Vec<GeneratedText> = serde_json::from_reader(reader).unwrap();
|
||||||
|
results.pop().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn compare_results(result: GeneratedText, expected: GeneratedText) {
|
||||||
|
assert_eq!(result.generated_text, expected.generated_text);
|
||||||
|
assert_eq!(result.details.finish_reason, expected.details.finish_reason);
|
||||||
|
assert_eq!(result.details.generated_tokens, expected.details.generated_tokens);
|
||||||
|
|
||||||
|
for (token, expected_token) in result.details.tokens.into_iter().zip(expected.details.tokens.into_iter()) {
|
||||||
|
assert_eq!(token.0, expected_token.0);
|
||||||
|
assert_eq!(token.1, expected_token.1);
|
||||||
|
if let Some(logprob) = token.2 {
|
||||||
|
let expected_logprob = expected_token.2.unwrap();
|
||||||
|
assert_float_eq!(logprob, expected_logprob, abs <= 0.001);
|
||||||
|
} else {
|
||||||
|
assert_eq!(token.2, expected_token.2);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_bloom_560m() {
|
fn test_bloom_560m() {
|
||||||
|
let expected = read_json("bloom_560m.json");
|
||||||
|
|
||||||
let result = test_model("bigscience/bloom-560m".to_string(), 1, 3000, 29500);
|
let result = test_model("bigscience/bloom-560m".to_string(), 1, 3000, 29500);
|
||||||
println!("{}", result);
|
compare_results(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_bloom_560m_distributed() {
|
fn test_bloom_560m_distributed() {
|
||||||
|
let expected = read_json("bloom_560m.json");
|
||||||
|
|
||||||
let result = test_model("bigscience/bloom-560m".to_string(), 2, 3001, 29501);
|
let result = test_model("bigscience/bloom-560m".to_string(), 2, 3001, 29501);
|
||||||
println!("{}", result);
|
compare_results(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_mt0_base() {
|
fn test_mt0_base() {
|
||||||
|
let expected = read_json("mt0_base.json");
|
||||||
|
|
||||||
let result = test_model("bigscience/mt0-base".to_string(), 1, 3002, 29502);
|
let result = test_model("bigscience/mt0-base".to_string(), 1, 3002, 29502);
|
||||||
println!("{}", result);
|
compare_results(result, expected);
|
||||||
}
|
}
|
||||||
|
116
launcher/tests/mt0_base.json
Normal file
116
launcher/tests/mt0_base.json
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 20,
|
||||||
|
"tokens": [
|
||||||
|
[
|
||||||
|
0,
|
||||||
|
"<pad>",
|
||||||
|
null
|
||||||
|
],
|
||||||
|
[
|
||||||
|
259,
|
||||||
|
"",
|
||||||
|
-1.3656927
|
||||||
|
],
|
||||||
|
[
|
||||||
|
215100,
|
||||||
|
"\"\"\"",
|
||||||
|
-2.6551573
|
||||||
|
],
|
||||||
|
[
|
||||||
|
46138,
|
||||||
|
"Test",
|
||||||
|
-1.8059857
|
||||||
|
],
|
||||||
|
[
|
||||||
|
287,
|
||||||
|
"the",
|
||||||
|
-1.2102449
|
||||||
|
],
|
||||||
|
[
|
||||||
|
259,
|
||||||
|
"",
|
||||||
|
-1.6057279
|
||||||
|
],
|
||||||
|
[
|
||||||
|
49076,
|
||||||
|
"contents",
|
||||||
|
-3.6060903
|
||||||
|
],
|
||||||
|
[
|
||||||
|
304,
|
||||||
|
"of",
|
||||||
|
-0.5270343
|
||||||
|
],
|
||||||
|
[
|
||||||
|
287,
|
||||||
|
"the",
|
||||||
|
-0.62522805
|
||||||
|
],
|
||||||
|
[
|
||||||
|
259,
|
||||||
|
"",
|
||||||
|
-1.4069618
|
||||||
|
],
|
||||||
|
[
|
||||||
|
49076,
|
||||||
|
"contents",
|
||||||
|
-2.621994
|
||||||
|
],
|
||||||
|
[
|
||||||
|
304,
|
||||||
|
"of",
|
||||||
|
-1.3172221
|
||||||
|
],
|
||||||
|
[
|
||||||
|
287,
|
||||||
|
"the",
|
||||||
|
-0.3501925
|
||||||
|
],
|
||||||
|
[
|
||||||
|
259,
|
||||||
|
"",
|
||||||
|
-0.7219573
|
||||||
|
],
|
||||||
|
[
|
||||||
|
49076,
|
||||||
|
"contents",
|
||||||
|
-1.0494149
|
||||||
|
],
|
||||||
|
[
|
||||||
|
260,
|
||||||
|
".",
|
||||||
|
-1.0803378
|
||||||
|
],
|
||||||
|
[
|
||||||
|
259,
|
||||||
|
"",
|
||||||
|
-0.32933083
|
||||||
|
],
|
||||||
|
[
|
||||||
|
215100,
|
||||||
|
"\"\"\"",
|
||||||
|
-0.11268901
|
||||||
|
],
|
||||||
|
[
|
||||||
|
2978,
|
||||||
|
"test",
|
||||||
|
-1.5846587
|
||||||
|
],
|
||||||
|
[
|
||||||
|
290,
|
||||||
|
"_",
|
||||||
|
-0.49796978
|
||||||
|
],
|
||||||
|
[
|
||||||
|
4125,
|
||||||
|
"test",
|
||||||
|
-2.0026445
|
||||||
|
]
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test"
|
||||||
|
}
|
||||||
|
]
|
Loading…
Reference in New Issue
Block a user