From 109c5af6158e2a39441bdd9496ae7b78b247ad3f Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Thu, 2 Feb 2023 18:37:07 +0100 Subject: [PATCH] finalize openAPI schemas --- launcher/tests/bloom_560m.json | 245 ++++++++++++++-------------- launcher/tests/integration_tests.rs | 23 ++- launcher/tests/mt0_base.json | 235 +++++++++++++------------- router/src/lib.rs | 41 +++-- router/src/server.rs | 73 ++++++--- 5 files changed, 331 insertions(+), 286 deletions(-) diff --git a/launcher/tests/bloom_560m.json b/launcher/tests/bloom_560m.json index 3a0a3d99..17e2571e 100644 --- a/launcher/tests/bloom_560m.json +++ b/launcher/tests/bloom_560m.json @@ -1,123 +1,122 @@ -[ - { - "details": { - "finish_reason": "length", - "generated_tokens": 20, - "prefill": [ - [ - 10264, - "Test", - null - ], - [ - 8821, - " request", - -11.895094 - ] - ], - "tokens": [ - [ - 17, - ".", - -1.8267941 - ], - [ - 1587, - "get", - -2.4674964 - ], - [ - 11, - "(", - -1.9060438 - ], - [ - 5, - "\"", - -1.2279553 - ], - [ - 4899, - "action", - -4.170306 - ], - [ - 5, - "\"", - -0.3247902 - ], - [ - 12, - ")", - -1.0773602 - ], - [ - 30, - ";", - -0.27640444 - ], - [ - 837, - "\n ", - -1.6970599 - ], - [ - 1320, - " if", - -1.4495552 - ], - [ - 375, - " (", - -0.2360998 - ], - [ - 4899, - "action", - -1.1916926 - ], - [ - 3535, - " ==", - -0.8918663 - ], - [ - 5109, - " null", - -0.39334255 - ], - [ - 12, - ")", - -0.4321134 - ], - [ - 731, - " {", - -0.17701954 - ], - [ - 1260, - "\n ", - -0.07027287 - ], - [ - 10519, - " throw", - -1.3915133 - ], - [ - 2084, - " new", - -0.042013377 - ], - [ - 150858, - " RuntimeException", - -1.7330077 - ] - ] - }, - "generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException" - } -] \ No newline at end of file +{ + "details": { + "finish_reason": "length", + "generated_tokens": 20, + "prefill": [ + { + "id": 10264, + "logprob": null, + "text": "Test" + }, + { + "id": 8821, + "logprob": -11.894989, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 17, + "logprob": -1.8267672, + "text": "." + }, + { + "id": 1587, + "logprob": -2.4674969, + "text": "get" + }, + { + "id": 11, + "logprob": -1.906001, + "text": "(" + }, + { + "id": 5, + "logprob": -1.2279545, + "text": "\"" + }, + { + "id": 4899, + "logprob": -4.170299, + "text": "action" + }, + { + "id": 5, + "logprob": -0.32478866, + "text": "\"" + }, + { + "id": 12, + "logprob": -1.0773665, + "text": ")" + }, + { + "id": 30, + "logprob": -0.27640742, + "text": ";" + }, + { + "id": 837, + "logprob": -1.6970354, + "text": "\n " + }, + { + "id": 1320, + "logprob": -1.4495516, + "text": " if" + }, + { + "id": 375, + "logprob": -0.23609057, + "text": " (" + }, + { + "id": 4899, + "logprob": -1.1916996, + "text": "action" + }, + { + "id": 3535, + "logprob": -0.8918753, + "text": " ==" + }, + { + "id": 5109, + "logprob": -0.3933342, + "text": " null" + }, + { + "id": 12, + "logprob": -0.43212673, + "text": ")" + }, + { + "id": 731, + "logprob": -0.17702064, + "text": " {" + }, + { + "id": 1260, + "logprob": -0.07027565, + "text": "\n " + }, + { + "id": 10519, + "logprob": -1.3915029, + "text": " throw" + }, + { + "id": 2084, + "logprob": -0.04201372, + "text": " new" + }, + { + "id": 150858, + "logprob": -1.7329919, + "text": " RuntimeException" + } + ] + }, + "generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException" +} \ No newline at end of file diff --git a/launcher/tests/integration_tests.rs b/launcher/tests/integration_tests.rs index 4f699f69..9336f36e 100644 --- a/launcher/tests/integration_tests.rs +++ b/launcher/tests/integration_tests.rs @@ -9,11 +9,18 @@ use std::thread::sleep; use std::time::Duration; use subprocess::{Popen, PopenConfig, Redirection}; +#[derive(Deserialize)] +pub struct Token { + id: u32, + text: String, + logprob: Option, +} + #[derive(Deserialize)] struct Details { finish_reason: String, generated_tokens: u32, - tokens: Vec<(u32, String, Option)>, + tokens: Vec, } #[derive(Deserialize)] @@ -109,8 +116,8 @@ fn read_json(name: &str) -> GeneratedText { let file = File::open(d).unwrap(); let reader = BufReader::new(file); - let mut results: Vec = serde_json::from_reader(reader).unwrap(); - results.pop().unwrap() + let result: GeneratedText = serde_json::from_reader(reader).unwrap(); + result } fn compare_results(result: GeneratedText, expected: GeneratedText) { @@ -127,13 +134,13 @@ fn compare_results(result: GeneratedText, expected: GeneratedText) { .into_iter() .zip(expected.details.tokens.into_iter()) { - assert_eq!(token.0, expected_token.0); - assert_eq!(token.1, expected_token.1); - if let Some(logprob) = token.2 { - let expected_logprob = expected_token.2.unwrap(); + assert_eq!(token.id, expected_token.id); + assert_eq!(token.text, expected_token.text); + if let Some(logprob) = token.logprob { + let expected_logprob = expected_token.logprob.unwrap(); assert_float_eq!(logprob, expected_logprob, abs <= 0.001); } else { - assert_eq!(token.2, expected_token.2); + assert_eq!(token.logprob, expected_token.logprob); } } } diff --git a/launcher/tests/mt0_base.json b/launcher/tests/mt0_base.json index 51cb8b5c..cee3bc47 100644 --- a/launcher/tests/mt0_base.json +++ b/launcher/tests/mt0_base.json @@ -1,118 +1,117 @@ -[ - { - "details": { - "finish_reason": "length", - "generated_tokens": 20, - "prefill": [ - [ - 0, - "", - null - ] - ], - "tokens": [ - [ - 259, - "", - -1.3656927 - ], - [ - 215100, - "\"\"\"", - -2.6551573 - ], - [ - 46138, - "Test", - -1.8059857 - ], - [ - 287, - "the", - -1.2102449 - ], - [ - 259, - "", - -1.6057279 - ], - [ - 49076, - "contents", - -3.6060903 - ], - [ - 304, - "of", - -0.5270343 - ], - [ - 287, - "the", - -0.62522805 - ], - [ - 259, - "", - -1.4069618 - ], - [ - 49076, - "contents", - -2.621994 - ], - [ - 304, - "of", - -1.3172221 - ], - [ - 287, - "the", - -0.3501925 - ], - [ - 259, - "", - -0.7219573 - ], - [ - 49076, - "contents", - -1.0494149 - ], - [ - 260, - ".", - -1.0803378 - ], - [ - 259, - "", - -0.32933083 - ], - [ - 215100, - "\"\"\"", - -0.11268901 - ], - [ - 2978, - "test", - -1.5846587 - ], - [ - 290, - "_", - -0.49796978 - ], - [ - 4125, - "test", - -2.0026445 - ] - ] - }, - "generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test" - } -] \ No newline at end of file +{ + "details": { + "finish_reason": "length", + "generated_tokens": 20, + "prefill": [ + { + "id": 0, + "logprob": null, + "text": "" + } + ], + "seed": null, + "tokens": [ + { + "id": 259, + "logprob": -1.3656927, + "text": "" + }, + { + "id": 215100, + "logprob": -2.6551573, + "text": "\"\"\"" + }, + { + "id": 46138, + "logprob": -1.8059857, + "text": "Test" + }, + { + "id": 287, + "logprob": -1.2102449, + "text": "the" + }, + { + "id": 259, + "logprob": -1.6057279, + "text": "" + }, + { + "id": 49076, + "logprob": -3.6060903, + "text": "contents" + }, + { + "id": 304, + "logprob": -0.5270343, + "text": "of" + }, + { + "id": 287, + "logprob": -0.62522805, + "text": "the" + }, + { + "id": 259, + "logprob": -1.4069618, + "text": "" + }, + { + "id": 49076, + "logprob": -2.621994, + "text": "contents" + }, + { + "id": 304, + "logprob": -1.3172221, + "text": "of" + }, + { + "id": 287, + "logprob": -0.3501925, + "text": "the" + }, + { + "id": 259, + "logprob": -0.7219573, + "text": "" + }, + { + "id": 49076, + "logprob": -1.0494149, + "text": "contents" + }, + { + "id": 260, + "logprob": -1.0803378, + "text": "." + }, + { + "id": 259, + "logprob": -0.32933083, + "text": "" + }, + { + "id": 215100, + "logprob": -0.11268901, + "text": "\"\"\"" + }, + { + "id": 2978, + "logprob": -1.5846587, + "text": "test" + }, + { + "id": 290, + "logprob": -0.49796978, + "text": "_" + }, + { + "id": 4125, + "logprob": -2.0026445, + "text": "test" + } + ] + }, + "generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test" +} \ No newline at end of file diff --git a/router/src/lib.rs b/router/src/lib.rs index 551d3cac..c88cae1d 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -22,7 +22,12 @@ pub(crate) struct GenerateParameters { #[schema(exclusive_minimum = 0, nullable = true, default = "null")] pub top_k: Option, #[serde(default)] - #[schema(exclusive_minimum = 0.0, maximum = 1.0, nullable = true, default = "null")] + #[schema( + exclusive_minimum = 0.0, + maximum = 1.0, + nullable = true, + default = "null" + )] pub top_p: Option, #[serde(default = "default_do_sample")] #[schema(default = "false")] @@ -31,7 +36,7 @@ pub(crate) struct GenerateParameters { #[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")] pub max_new_tokens: u32, #[serde(default)] - #[schema(max_items = 4, default = "null")] + #[schema(inline, max_items = 4, example = json!(["photographer"]))] pub stop: Vec, #[serde(default)] #[schema(default = "true")] @@ -72,22 +77,33 @@ pub(crate) struct GenerateRequest { #[derive(Debug, Serialize, ToSchema)] pub struct Token { + #[schema(example = 0)] id: u32, + #[schema(example = "test")] text: String, + #[schema(nullable = true, example = -0.34)] logprob: f32, } #[derive(Serialize, ToSchema)] +#[serde(rename_all(serialize = "snake_case"))] pub(crate) enum FinishReason { + #[schema(rename = "length")] Length, + #[serde(rename = "eos_token")] + #[schema(rename = "eos_token")] EndOfSequenceToken, + #[schema(rename = "stop_sequence")] StopSequence, } #[derive(Serialize, ToSchema)] pub(crate) struct Details { + #[schema(example = "length")] pub finish_reason: FinishReason, + #[schema(example = 1)] pub generated_tokens: u32, + #[schema(example = 42)] pub seed: Option, pub prefill: Option>, pub tokens: Option>, @@ -95,6 +111,7 @@ pub(crate) struct Details { #[derive(Serialize, ToSchema)] pub(crate) struct GenerateResponse { + #[schema(example = "test")] pub generated_text: String, #[serde(skip_serializing_if = "Option::is_none")] pub details: Option
, @@ -102,31 +119,25 @@ pub(crate) struct GenerateResponse { #[derive(Serialize, ToSchema)] pub(crate) struct StreamDetails { + #[schema(example = "length")] pub finish_reason: FinishReason, + #[schema(example = 1)] pub generated_tokens: u32, + #[schema(example = 42)] pub seed: Option, } #[derive(Serialize, ToSchema)] pub(crate) struct StreamResponse { pub token: Token, + #[schema(nullable = true, default = "null", example = "test")] pub generated_text: Option, + #[schema(nullable = true, default = "null")] pub details: Option, } -#[derive(Serialize, ToSchema)] -pub(crate) enum ErrorType { - #[schema(example = "Request failed during generation")] - GenerationError(String), - #[schema(example = "Model is overloaded")] - Overloaded(String), - #[schema(example = "Input validation error")] - ValidationError(String), - #[schema(example = "Incomplete generation")] - IncompleteGeneration(String), -} - #[derive(Serialize, ToSchema)] pub(crate) struct ErrorResponse { - pub error: ErrorType, + #[schema(inline)] + pub error: String, } diff --git a/router/src/server.rs b/router/src/server.rs index fd77a802..8094df1f 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1,7 +1,7 @@ /// HTTP Server logic use crate::infer::{InferError, InferStreamResponse}; use crate::{ - Details, ErrorResponse, ErrorType, FinishReason, GenerateParameters, GenerateRequest, + Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, Infer, StreamDetails, StreamResponse, Token, Validation, }; use axum::extract::Extension; @@ -49,8 +49,24 @@ async fn health(infer: Extension) -> Result<(), (StatusCode, Json for FinishReason { } } -impl From for ErrorResponse { - fn from(err: InferError) -> Self { - let err_string = err.to_string(); - let error = match err { - InferError::GenerationError(_) => ErrorType::GenerationError(err_string), - InferError::Overloaded(_) => ErrorType::Overloaded(err_string), - InferError::ValidationError(_) => ErrorType::ValidationError(err_string), - InferError::IncompleteGeneration => ErrorType::IncompleteGeneration(err_string), - }; - ErrorResponse { error } - } -} - /// Convert to Axum supported formats impl From for (StatusCode, Json) { fn from(err: InferError) -> Self { @@ -384,14 +413,14 @@ impl From for (StatusCode, Json) { InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR, }; - (status_code, Json(ErrorResponse::from(err))) + (status_code, Json(ErrorResponse{error: err.to_string()})) } } impl From for Event { fn from(err: InferError) -> Self { Event::default() - .json_data(ErrorResponse::from(err)) + .json_data(ErrorResponse{error: err.to_string()}) .unwrap() } }