add expected values

This commit is contained in:
OlivierDehaene 2022-12-16 11:04:37 +01:00
parent 7ab7c9a01f
commit 9978147004
5 changed files with 314 additions and 13 deletions

20
Cargo.lock generated
View File

@ -543,6 +543,12 @@ dependencies = [
"miniz_oxide",
]
[[package]]
name = "float_eq"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28a80e3145d8ad11ba0995949bbcf48b9df2be62772b3d351ef017dff6ecb853"
[[package]]
name = "fnv"
version = "1.0.7"
@ -1587,18 +1593,18 @@ dependencies = [
[[package]]
name = "serde"
version = "1.0.147"
version = "1.0.150"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d193d69bae983fc11a79df82342761dfbf28a99fc8d203dca4c3c1b590948965"
checksum = "e326c9ec8042f1b5da33252c8a37e9ffbd2c9bef0155215b6e6c80c790e05f91"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.147"
version = "1.0.150"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4f1d362ca8fc9c3e3a7484440752472d68a6caa98f1ab81d99b5dfe517cec852"
checksum = "42a3df25b0713732468deadad63ab9da1f1fd75a48a15024b50363f128db627e"
dependencies = [
"proc-macro2",
"quote",
@ -1724,9 +1730,9 @@ dependencies = [
[[package]]
name = "syn"
version = "1.0.103"
version = "1.0.105"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a864042229133ada95abf3b54fdc62ef5ccabe9515b64717bcb9a1919e59445d"
checksum = "60b9b43d45702de4c839cb9b51d9f529c5dd26a4aff255b42b1ebc03e88ee908"
dependencies = [
"proc-macro2",
"quote",
@ -1804,7 +1810,9 @@ version = "0.1.0"
dependencies = [
"clap 4.0.22",
"ctrlc",
"float_eq",
"reqwest",
"serde",
"serde_json",
"subprocess",
"tracing",

View File

@ -13,5 +13,7 @@ tracing = "0.1.37"
tracing-subscriber = { version = "0.3.16", features = ["json"] }
[dev-dependencies]
float_eq = "1.0.1"
reqwest = { version = "0.11.13", features = ["blocking", "json"] }
serde = "1.0.150"
serde_json = "1.0.89"

View File

@ -0,0 +1,121 @@
[
{
"details": {
"finish_reason": "length",
"generated_tokens": 20,
"tokens": [
[
10264,
"Test",
null
],
[
8821,
" request",
-11.895094
],
[
17,
".",
-1.8267941
],
[
1587,
"get",
-2.4674964
],
[
11,
"(",
-1.9060438
],
[
5,
"\"",
-1.2279553
],
[
4899,
"action",
-4.170306
],
[
5,
"\"",
-0.3247902
],
[
12,
")",
-1.0773602
],
[
30,
";",
-0.27640444
],
[
837,
"\n ",
-1.6970599
],
[
1320,
" if",
-1.4495552
],
[
375,
" (",
-0.2360998
],
[
4899,
"action",
-1.1916926
],
[
3535,
" ==",
-0.8918663
],
[
5109,
" null",
-0.39334255
],
[
12,
")",
-0.4321134
],
[
731,
" {",
-0.17701954
],
[
1260,
"\n ",
-0.07027287
],
[
10519,
" throw",
-1.3915133
],
[
2084,
" new",
-0.042013377
],
[
150858,
" RuntimeException",
-1.7330077
]
]
},
"generated_text": "Test request.get(\"action\");\n if (action == null) {\n throw new RuntimeException"
}
]

View File

@ -1,9 +1,27 @@
use std::fs::File;
use serde_json::Value;
use std::io::{BufRead, BufReader};
use std::path::PathBuf;
use std::thread;
use std::thread::sleep;
use std::time::Duration;
use float_eq::assert_float_eq;
use subprocess::{Popen, PopenConfig, Redirection};
use serde::Deserialize;
#[derive(Deserialize)]
struct Details {
finish_reason: String,
generated_tokens: u32,
tokens: Vec<(u32, String, Option<f32>)>,
}
#[derive(Deserialize)]
struct GeneratedText {
generated_text: String,
details: Details,
}
fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port: usize) -> Popen {
let argv = vec![
@ -28,7 +46,7 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port
..Default::default()
},
)
.expect("Could not start launcher");
.expect("Could not start launcher");
// Redirect STDOUT and STDERR to the console
let launcher_stdout = launcher.stdout.take().unwrap();
@ -58,7 +76,7 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port
panic!("failed to launch {}", model_name)
}
fn test_model(model_name: String, num_shard: usize, port: usize, master_port: usize) -> Value {
fn test_model(model_name: String, num_shard: usize, port: usize, master_port: usize) -> GeneratedText {
let mut launcher = start_launcher(model_name, num_shard, port, master_port);
let data = r#"
@ -79,24 +97,60 @@ fn test_model(model_name: String, num_shard: usize, port: usize, master_port: us
launcher.terminate().unwrap();
launcher.wait().unwrap();
let result: Value = res.unwrap().json().unwrap();
result
let mut results: Vec<GeneratedText> = res.unwrap().json().unwrap();
results.pop().unwrap()
}
fn read_json(name: &str) -> GeneratedText {
let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
d.push("tests/");
d.push(name);
let file = File::open(d).unwrap();
let reader = BufReader::new(file);
let mut results: Vec<GeneratedText> = serde_json::from_reader(reader).unwrap();
results.pop().unwrap()
}
fn compare_results(result: GeneratedText, expected: GeneratedText) {
assert_eq!(result.generated_text, expected.generated_text);
assert_eq!(result.details.finish_reason, expected.details.finish_reason);
assert_eq!(result.details.generated_tokens, expected.details.generated_tokens);
for (token, expected_token) in result.details.tokens.into_iter().zip(expected.details.tokens.into_iter()) {
assert_eq!(token.0, expected_token.0);
assert_eq!(token.1, expected_token.1);
if let Some(logprob) = token.2 {
let expected_logprob = expected_token.2.unwrap();
assert_float_eq!(logprob, expected_logprob, abs <= 0.001);
} else {
assert_eq!(token.2, expected_token.2);
}
}
}
#[test]
fn test_bloom_560m() {
let expected = read_json("bloom_560m.json");
let result = test_model("bigscience/bloom-560m".to_string(), 1, 3000, 29500);
println!("{}", result);
compare_results(result, expected);
}
#[test]
fn test_bloom_560m_distributed() {
let expected = read_json("bloom_560m.json");
let result = test_model("bigscience/bloom-560m".to_string(), 2, 3001, 29501);
println!("{}", result);
compare_results(result, expected);
}
#[test]
fn test_mt0_base() {
let expected = read_json("mt0_base.json");
let result = test_model("bigscience/mt0-base".to_string(), 1, 3002, 29502);
println!("{}", result);
compare_results(result, expected);
}

View File

@ -0,0 +1,116 @@
[
{
"details": {
"finish_reason": "length",
"generated_tokens": 20,
"tokens": [
[
0,
"<pad>",
null
],
[
259,
"",
-1.3656927
],
[
215100,
"\"\"\"",
-2.6551573
],
[
46138,
"Test",
-1.8059857
],
[
287,
"the",
-1.2102449
],
[
259,
"",
-1.6057279
],
[
49076,
"contents",
-3.6060903
],
[
304,
"of",
-0.5270343
],
[
287,
"the",
-0.62522805
],
[
259,
"",
-1.4069618
],
[
49076,
"contents",
-2.621994
],
[
304,
"of",
-1.3172221
],
[
287,
"the",
-0.3501925
],
[
259,
"",
-0.7219573
],
[
49076,
"contents",
-1.0494149
],
[
260,
".",
-1.0803378
],
[
259,
"",
-0.32933083
],
[
215100,
"\"\"\"",
-0.11268901
],
[
2978,
"test",
-1.5846587
],
[
290,
"_",
-0.49796978
],
[
4125,
"test",
-2.0026445
]
]
},
"generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test"
}
]