From 99781470044b1182c55dee89dadfef8d2126e46a Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Fri, 16 Dec 2022 11:04:37 +0100 Subject: [PATCH] add expected values --- Cargo.lock | 20 +++-- launcher/Cargo.toml | 2 + launcher/tests/bloom_560m.json | 121 ++++++++++++++++++++++++++++ launcher/tests/integration_tests.rs | 68 ++++++++++++++-- launcher/tests/mt0_base.json | 116 ++++++++++++++++++++++++++ 5 files changed, 314 insertions(+), 13 deletions(-) create mode 100644 launcher/tests/bloom_560m.json create mode 100644 launcher/tests/mt0_base.json diff --git a/Cargo.lock b/Cargo.lock index ec013c02..752c4886 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -543,6 +543,12 @@ dependencies = [ "miniz_oxide", ] +[[package]] +name = "float_eq" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28a80e3145d8ad11ba0995949bbcf48b9df2be62772b3d351ef017dff6ecb853" + [[package]] name = "fnv" version = "1.0.7" @@ -1587,18 +1593,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.147" +version = "1.0.150" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d193d69bae983fc11a79df82342761dfbf28a99fc8d203dca4c3c1b590948965" +checksum = "e326c9ec8042f1b5da33252c8a37e9ffbd2c9bef0155215b6e6c80c790e05f91" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.147" +version = "1.0.150" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f1d362ca8fc9c3e3a7484440752472d68a6caa98f1ab81d99b5dfe517cec852" +checksum = "42a3df25b0713732468deadad63ab9da1f1fd75a48a15024b50363f128db627e" dependencies = [ "proc-macro2", "quote", @@ -1724,9 +1730,9 @@ dependencies = [ [[package]] name = "syn" -version = "1.0.103" +version = "1.0.105" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a864042229133ada95abf3b54fdc62ef5ccabe9515b64717bcb9a1919e59445d" +checksum = "60b9b43d45702de4c839cb9b51d9f529c5dd26a4aff255b42b1ebc03e88ee908" dependencies = [ "proc-macro2", "quote", @@ -1804,7 +1810,9 @@ version = "0.1.0" dependencies = [ "clap 4.0.22", "ctrlc", + "float_eq", "reqwest", + "serde", "serde_json", "subprocess", "tracing", diff --git a/launcher/Cargo.toml b/launcher/Cargo.toml index cf86aa00..ecdef831 100644 --- a/launcher/Cargo.toml +++ b/launcher/Cargo.toml @@ -13,5 +13,7 @@ tracing = "0.1.37" tracing-subscriber = { version = "0.3.16", features = ["json"] } [dev-dependencies] +float_eq = "1.0.1" reqwest = { version = "0.11.13", features = ["blocking", "json"] } +serde = "1.0.150" serde_json = "1.0.89" diff --git a/launcher/tests/bloom_560m.json b/launcher/tests/bloom_560m.json new file mode 100644 index 00000000..d17f1ed4 --- /dev/null +++ b/launcher/tests/bloom_560m.json @@ -0,0 +1,121 @@ +[ + { + "details": { + "finish_reason": "length", + "generated_tokens": 20, + "tokens": [ + [ + 10264, + "Test", + null + ], + [ + 8821, + " request", + -11.895094 + ], + [ + 17, + ".", + -1.8267941 + ], + [ + 1587, + "get", + -2.4674964 + ], + [ + 11, + "(", + -1.9060438 + ], + [ + 5, + "\"", + -1.2279553 + ], + [ + 4899, + "action", + -4.170306 + ], + [ + 5, + "\"", + -0.3247902 + ], + [ + 12, + ")", + -1.0773602 + ], + [ + 30, + ";", + -0.27640444 + ], + [ + 837, + "\n ", + -1.6970599 + ], + [ + 1320, + " if", + -1.4495552 + ], + [ + 375, + " (", + -0.2360998 + ], + [ + 4899, + "action", + -1.1916926 + ], + [ + 3535, + " ==", + -0.8918663 + ], + [ + 5109, + " null", + -0.39334255 + ], + [ + 12, + ")", + -0.4321134 + ], + [ + 731, + " {", + -0.17701954 + ], + [ + 1260, + "\n ", + -0.07027287 + ], + [ + 10519, + " throw", + -1.3915133 + ], + [ + 2084, + " new", + -0.042013377 + ], + [ + 150858, + " RuntimeException", + -1.7330077 + ] + ] + }, + "generated_text": "Test request.get(\"action\");\n if (action == null) {\n throw new RuntimeException" + } +] \ No newline at end of file diff --git a/launcher/tests/integration_tests.rs b/launcher/tests/integration_tests.rs index c0758491..3e68f6be 100644 --- a/launcher/tests/integration_tests.rs +++ b/launcher/tests/integration_tests.rs @@ -1,9 +1,27 @@ +use std::fs::File; use serde_json::Value; use std::io::{BufRead, BufReader}; +use std::path::PathBuf; use std::thread; use std::thread::sleep; use std::time::Duration; +use float_eq::assert_float_eq; use subprocess::{Popen, PopenConfig, Redirection}; +use serde::Deserialize; + +#[derive(Deserialize)] +struct Details { + finish_reason: String, + generated_tokens: u32, + tokens: Vec<(u32, String, Option)>, +} + +#[derive(Deserialize)] +struct GeneratedText { + generated_text: String, + details: Details, +} + fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port: usize) -> Popen { let argv = vec![ @@ -28,7 +46,7 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port ..Default::default() }, ) - .expect("Could not start launcher"); + .expect("Could not start launcher"); // Redirect STDOUT and STDERR to the console let launcher_stdout = launcher.stdout.take().unwrap(); @@ -58,7 +76,7 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port panic!("failed to launch {}", model_name) } -fn test_model(model_name: String, num_shard: usize, port: usize, master_port: usize) -> Value { +fn test_model(model_name: String, num_shard: usize, port: usize, master_port: usize) -> GeneratedText { let mut launcher = start_launcher(model_name, num_shard, port, master_port); let data = r#" @@ -79,24 +97,60 @@ fn test_model(model_name: String, num_shard: usize, port: usize, master_port: us launcher.terminate().unwrap(); launcher.wait().unwrap(); - let result: Value = res.unwrap().json().unwrap(); - result + let mut results: Vec = res.unwrap().json().unwrap(); + results.pop().unwrap() +} + + +fn read_json(name: &str) -> GeneratedText { + let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + d.push("tests/"); + d.push(name); + + let file = File::open(d).unwrap(); + let reader = BufReader::new(file); + + let mut results: Vec = serde_json::from_reader(reader).unwrap(); + results.pop().unwrap() +} + +fn compare_results(result: GeneratedText, expected: GeneratedText) { + assert_eq!(result.generated_text, expected.generated_text); + assert_eq!(result.details.finish_reason, expected.details.finish_reason); + assert_eq!(result.details.generated_tokens, expected.details.generated_tokens); + + for (token, expected_token) in result.details.tokens.into_iter().zip(expected.details.tokens.into_iter()) { + assert_eq!(token.0, expected_token.0); + assert_eq!(token.1, expected_token.1); + if let Some(logprob) = token.2 { + let expected_logprob = expected_token.2.unwrap(); + assert_float_eq!(logprob, expected_logprob, abs <= 0.001); + } else { + assert_eq!(token.2, expected_token.2); + } + } } #[test] fn test_bloom_560m() { + let expected = read_json("bloom_560m.json"); + let result = test_model("bigscience/bloom-560m".to_string(), 1, 3000, 29500); - println!("{}", result); + compare_results(result, expected); } #[test] fn test_bloom_560m_distributed() { + let expected = read_json("bloom_560m.json"); + let result = test_model("bigscience/bloom-560m".to_string(), 2, 3001, 29501); - println!("{}", result); + compare_results(result, expected); } #[test] fn test_mt0_base() { + let expected = read_json("mt0_base.json"); + let result = test_model("bigscience/mt0-base".to_string(), 1, 3002, 29502); - println!("{}", result); + compare_results(result, expected); } diff --git a/launcher/tests/mt0_base.json b/launcher/tests/mt0_base.json new file mode 100644 index 00000000..1b772282 --- /dev/null +++ b/launcher/tests/mt0_base.json @@ -0,0 +1,116 @@ +[ + { + "details": { + "finish_reason": "length", + "generated_tokens": 20, + "tokens": [ + [ + 0, + "", + null + ], + [ + 259, + "", + -1.3656927 + ], + [ + 215100, + "\"\"\"", + -2.6551573 + ], + [ + 46138, + "Test", + -1.8059857 + ], + [ + 287, + "the", + -1.2102449 + ], + [ + 259, + "", + -1.6057279 + ], + [ + 49076, + "contents", + -3.6060903 + ], + [ + 304, + "of", + -0.5270343 + ], + [ + 287, + "the", + -0.62522805 + ], + [ + 259, + "", + -1.4069618 + ], + [ + 49076, + "contents", + -2.621994 + ], + [ + 304, + "of", + -1.3172221 + ], + [ + 287, + "the", + -0.3501925 + ], + [ + 259, + "", + -0.7219573 + ], + [ + 49076, + "contents", + -1.0494149 + ], + [ + 260, + ".", + -1.0803378 + ], + [ + 259, + "", + -0.32933083 + ], + [ + 215100, + "\"\"\"", + -0.11268901 + ], + [ + 2978, + "test", + -1.5846587 + ], + [ + 290, + "_", + -0.49796978 + ], + [ + 4125, + "test", + -2.0026445 + ] + ] + }, + "generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test" + } +] \ No newline at end of file