finalize openAPI schemas

This commit is contained in:
OlivierDehaene 2023-02-02 18:37:07 +01:00
parent 2878c43cc5
commit 109c5af615
5 changed files with 331 additions and 286 deletions

View File

@ -1,123 +1,122 @@
[ {
{ "details": {
"details": { "finish_reason": "length",
"finish_reason": "length", "generated_tokens": 20,
"generated_tokens": 20, "prefill": [
"prefill": [ {
[ "id": 10264,
10264, "logprob": null,
"Test", "text": "Test"
null },
], {
[ "id": 8821,
8821, "logprob": -11.894989,
" request", "text": " request"
-11.895094 }
] ],
], "seed": null,
"tokens": [ "tokens": [
[ {
17, "id": 17,
".", "logprob": -1.8267672,
-1.8267941 "text": "."
], },
[ {
1587, "id": 1587,
"get", "logprob": -2.4674969,
-2.4674964 "text": "get"
], },
[ {
11, "id": 11,
"(", "logprob": -1.906001,
-1.9060438 "text": "("
], },
[ {
5, "id": 5,
"\"", "logprob": -1.2279545,
-1.2279553 "text": "\""
], },
[ {
4899, "id": 4899,
"action", "logprob": -4.170299,
-4.170306 "text": "action"
], },
[ {
5, "id": 5,
"\"", "logprob": -0.32478866,
-0.3247902 "text": "\""
], },
[ {
12, "id": 12,
")", "logprob": -1.0773665,
-1.0773602 "text": ")"
], },
[ {
30, "id": 30,
";", "logprob": -0.27640742,
-0.27640444 "text": ";"
], },
[ {
837, "id": 837,
"\n ", "logprob": -1.6970354,
-1.6970599 "text": "\n "
], },
[ {
1320, "id": 1320,
" if", "logprob": -1.4495516,
-1.4495552 "text": " if"
], },
[ {
375, "id": 375,
" (", "logprob": -0.23609057,
-0.2360998 "text": " ("
], },
[ {
4899, "id": 4899,
"action", "logprob": -1.1916996,
-1.1916926 "text": "action"
], },
[ {
3535, "id": 3535,
" ==", "logprob": -0.8918753,
-0.8918663 "text": " =="
], },
[ {
5109, "id": 5109,
" null", "logprob": -0.3933342,
-0.39334255 "text": " null"
], },
[ {
12, "id": 12,
")", "logprob": -0.43212673,
-0.4321134 "text": ")"
], },
[ {
731, "id": 731,
" {", "logprob": -0.17702064,
-0.17701954 "text": " {"
], },
[ {
1260, "id": 1260,
"\n ", "logprob": -0.07027565,
-0.07027287 "text": "\n "
], },
[ {
10519, "id": 10519,
" throw", "logprob": -1.3915029,
-1.3915133 "text": " throw"
], },
[ {
2084, "id": 2084,
" new", "logprob": -0.04201372,
-0.042013377 "text": " new"
], },
[ {
150858, "id": 150858,
" RuntimeException", "logprob": -1.7329919,
-1.7330077 "text": " RuntimeException"
] }
] ]
}, },
"generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException" "generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException"
} }
]

View File

@ -9,11 +9,18 @@ use std::thread::sleep;
use std::time::Duration; use std::time::Duration;
use subprocess::{Popen, PopenConfig, Redirection}; use subprocess::{Popen, PopenConfig, Redirection};
#[derive(Deserialize)]
pub struct Token {
id: u32,
text: String,
logprob: Option<f32>,
}
#[derive(Deserialize)] #[derive(Deserialize)]
struct Details { struct Details {
finish_reason: String, finish_reason: String,
generated_tokens: u32, generated_tokens: u32,
tokens: Vec<(u32, String, Option<f32>)>, tokens: Vec<Token>,
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -109,8 +116,8 @@ fn read_json(name: &str) -> GeneratedText {
let file = File::open(d).unwrap(); let file = File::open(d).unwrap();
let reader = BufReader::new(file); let reader = BufReader::new(file);
let mut results: Vec<GeneratedText> = serde_json::from_reader(reader).unwrap(); let result: GeneratedText = serde_json::from_reader(reader).unwrap();
results.pop().unwrap() result
} }
fn compare_results(result: GeneratedText, expected: GeneratedText) { fn compare_results(result: GeneratedText, expected: GeneratedText) {
@ -127,13 +134,13 @@ fn compare_results(result: GeneratedText, expected: GeneratedText) {
.into_iter() .into_iter()
.zip(expected.details.tokens.into_iter()) .zip(expected.details.tokens.into_iter())
{ {
assert_eq!(token.0, expected_token.0); assert_eq!(token.id, expected_token.id);
assert_eq!(token.1, expected_token.1); assert_eq!(token.text, expected_token.text);
if let Some(logprob) = token.2 { if let Some(logprob) = token.logprob {
let expected_logprob = expected_token.2.unwrap(); let expected_logprob = expected_token.logprob.unwrap();
assert_float_eq!(logprob, expected_logprob, abs <= 0.001); assert_float_eq!(logprob, expected_logprob, abs <= 0.001);
} else { } else {
assert_eq!(token.2, expected_token.2); assert_eq!(token.logprob, expected_token.logprob);
} }
} }
} }

View File

@ -1,118 +1,117 @@
[ {
{ "details": {
"details": { "finish_reason": "length",
"finish_reason": "length", "generated_tokens": 20,
"generated_tokens": 20, "prefill": [
"prefill": [ {
[ "id": 0,
0, "logprob": null,
"<pad>", "text": "<pad>"
null }
] ],
], "seed": null,
"tokens": [ "tokens": [
[ {
259, "id": 259,
"", "logprob": -1.3656927,
-1.3656927 "text": ""
], },
[ {
215100, "id": 215100,
"\"\"\"", "logprob": -2.6551573,
-2.6551573 "text": "\"\"\""
], },
[ {
46138, "id": 46138,
"Test", "logprob": -1.8059857,
-1.8059857 "text": "Test"
], },
[ {
287, "id": 287,
"the", "logprob": -1.2102449,
-1.2102449 "text": "the"
], },
[ {
259, "id": 259,
"", "logprob": -1.6057279,
-1.6057279 "text": ""
], },
[ {
49076, "id": 49076,
"contents", "logprob": -3.6060903,
-3.6060903 "text": "contents"
], },
[ {
304, "id": 304,
"of", "logprob": -0.5270343,
-0.5270343 "text": "of"
], },
[ {
287, "id": 287,
"the", "logprob": -0.62522805,
-0.62522805 "text": "the"
], },
[ {
259, "id": 259,
"", "logprob": -1.4069618,
-1.4069618 "text": ""
], },
[ {
49076, "id": 49076,
"contents", "logprob": -2.621994,
-2.621994 "text": "contents"
], },
[ {
304, "id": 304,
"of", "logprob": -1.3172221,
-1.3172221 "text": "of"
], },
[ {
287, "id": 287,
"the", "logprob": -0.3501925,
-0.3501925 "text": "the"
], },
[ {
259, "id": 259,
"", "logprob": -0.7219573,
-0.7219573 "text": ""
], },
[ {
49076, "id": 49076,
"contents", "logprob": -1.0494149,
-1.0494149 "text": "contents"
], },
[ {
260, "id": 260,
".", "logprob": -1.0803378,
-1.0803378 "text": "."
], },
[ {
259, "id": 259,
"", "logprob": -0.32933083,
-0.32933083 "text": ""
], },
[ {
215100, "id": 215100,
"\"\"\"", "logprob": -0.11268901,
-0.11268901 "text": "\"\"\""
], },
[ {
2978, "id": 2978,
"test", "logprob": -1.5846587,
-1.5846587 "text": "test"
], },
[ {
290, "id": 290,
"_", "logprob": -0.49796978,
-0.49796978 "text": "_"
], },
[ {
4125, "id": 4125,
"test", "logprob": -2.0026445,
-2.0026445 "text": "test"
] }
] ]
}, },
"generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test" "generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test"
} }
]

View File

@ -22,7 +22,12 @@ pub(crate) struct GenerateParameters {
#[schema(exclusive_minimum = 0, nullable = true, default = "null")] #[schema(exclusive_minimum = 0, nullable = true, default = "null")]
pub top_k: Option<i32>, pub top_k: Option<i32>,
#[serde(default)] #[serde(default)]
#[schema(exclusive_minimum = 0.0, maximum = 1.0, nullable = true, default = "null")] #[schema(
exclusive_minimum = 0.0,
maximum = 1.0,
nullable = true,
default = "null"
)]
pub top_p: Option<f32>, pub top_p: Option<f32>,
#[serde(default = "default_do_sample")] #[serde(default = "default_do_sample")]
#[schema(default = "false")] #[schema(default = "false")]
@ -31,7 +36,7 @@ pub(crate) struct GenerateParameters {
#[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")] #[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
pub max_new_tokens: u32, pub max_new_tokens: u32,
#[serde(default)] #[serde(default)]
#[schema(max_items = 4, default = "null")] #[schema(inline, max_items = 4, example = json!(["photographer"]))]
pub stop: Vec<String>, pub stop: Vec<String>,
#[serde(default)] #[serde(default)]
#[schema(default = "true")] #[schema(default = "true")]
@ -72,22 +77,33 @@ pub(crate) struct GenerateRequest {
#[derive(Debug, Serialize, ToSchema)] #[derive(Debug, Serialize, ToSchema)]
pub struct Token { pub struct Token {
#[schema(example = 0)]
id: u32, id: u32,
#[schema(example = "test")]
text: String, text: String,
#[schema(nullable = true, example = -0.34)]
logprob: f32, logprob: f32,
} }
#[derive(Serialize, ToSchema)] #[derive(Serialize, ToSchema)]
#[serde(rename_all(serialize = "snake_case"))]
pub(crate) enum FinishReason { pub(crate) enum FinishReason {
#[schema(rename = "length")]
Length, Length,
#[serde(rename = "eos_token")]
#[schema(rename = "eos_token")]
EndOfSequenceToken, EndOfSequenceToken,
#[schema(rename = "stop_sequence")]
StopSequence, StopSequence,
} }
#[derive(Serialize, ToSchema)] #[derive(Serialize, ToSchema)]
pub(crate) struct Details { pub(crate) struct Details {
#[schema(example = "length")]
pub finish_reason: FinishReason, pub finish_reason: FinishReason,
#[schema(example = 1)]
pub generated_tokens: u32, pub generated_tokens: u32,
#[schema(example = 42)]
pub seed: Option<u64>, pub seed: Option<u64>,
pub prefill: Option<Vec<Token>>, pub prefill: Option<Vec<Token>>,
pub tokens: Option<Vec<Token>>, pub tokens: Option<Vec<Token>>,
@ -95,6 +111,7 @@ pub(crate) struct Details {
#[derive(Serialize, ToSchema)] #[derive(Serialize, ToSchema)]
pub(crate) struct GenerateResponse { pub(crate) struct GenerateResponse {
#[schema(example = "test")]
pub generated_text: String, pub generated_text: String,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub details: Option<Details>, pub details: Option<Details>,
@ -102,31 +119,25 @@ pub(crate) struct GenerateResponse {
#[derive(Serialize, ToSchema)] #[derive(Serialize, ToSchema)]
pub(crate) struct StreamDetails { pub(crate) struct StreamDetails {
#[schema(example = "length")]
pub finish_reason: FinishReason, pub finish_reason: FinishReason,
#[schema(example = 1)]
pub generated_tokens: u32, pub generated_tokens: u32,
#[schema(example = 42)]
pub seed: Option<u64>, pub seed: Option<u64>,
} }
#[derive(Serialize, ToSchema)] #[derive(Serialize, ToSchema)]
pub(crate) struct StreamResponse { pub(crate) struct StreamResponse {
pub token: Token, pub token: Token,
#[schema(nullable = true, default = "null", example = "test")]
pub generated_text: Option<String>, pub generated_text: Option<String>,
#[schema(nullable = true, default = "null")]
pub details: Option<StreamDetails>, pub details: Option<StreamDetails>,
} }
#[derive(Serialize, ToSchema)]
pub(crate) enum ErrorType {
#[schema(example = "Request failed during generation")]
GenerationError(String),
#[schema(example = "Model is overloaded")]
Overloaded(String),
#[schema(example = "Input validation error")]
ValidationError(String),
#[schema(example = "Incomplete generation")]
IncompleteGeneration(String),
}
#[derive(Serialize, ToSchema)] #[derive(Serialize, ToSchema)]
pub(crate) struct ErrorResponse { pub(crate) struct ErrorResponse {
pub error: ErrorType, #[schema(inline)]
pub error: String,
} }

View File

@ -1,7 +1,7 @@
/// HTTP Server logic /// HTTP Server logic
use crate::infer::{InferError, InferStreamResponse}; use crate::infer::{InferError, InferStreamResponse};
use crate::{ use crate::{
Details, ErrorResponse, ErrorType, FinishReason, GenerateParameters, GenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
GenerateResponse, Infer, StreamDetails, StreamResponse, Token, Validation, GenerateResponse, Infer, StreamDetails, StreamResponse, Token, Validation,
}; };
use axum::extract::Extension; use axum::extract::Extension;
@ -49,8 +49,24 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
Ok(()) Ok(())
} }
/// Generate method /// Generate tokens
#[utoipa::path(post, path = "/generate", request_body = GenerateRequest)] #[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/generate",
request_body = GenerateRequest,
responses(
(status = 200, description = "Generated Text", body = [GenerateResponse]),
(status = 424, description = "Generation Error", body = [ErrorResponse],
example = json!({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = [ErrorResponse],
example = json!({"error": "Model is overloaded"})),
(status = 422, description = "Input validation error", body = [ErrorResponse],
example = json!({"error": "Input validation error"})),
(status = 500, description = "Incomplete generation", body = [ErrorResponse],
example = json!({"error": "Incomplete generation"})),
)
)]
#[instrument( #[instrument(
skip(infer), skip(infer),
fields( fields(
@ -135,9 +151,29 @@ async fn generate(
Ok((headers, Json(response))) Ok((headers, Json(response)))
} }
/// Generate stream method /// Generate a stream of token using Server Side Events
#[utoipa::path(post, path = "/generate_stream")] #[utoipa::path(
#[instrument( post,
tag = "Text Generation Inference",
path = "/generate_stream",
request_body = GenerateRequest,
responses(
(status = 200, description = "Generated Text", body = [StreamResponse],
content_type="text/event-stream "),
(status = 424, description = "Generation Error", body = [ErrorResponse],
example = json!({"error": "Request failed during generation"}),
content_type="text/event-stream "),
(status = 429, description = "Model is overloaded", body = [ErrorResponse],
example = json!({"error": "Model is overloaded"}),
content_type="text/event-stream "),
(status = 422, description = "Input validation error", body = [ErrorResponse],
example = json!({"error": "Input validation error"}),
content_type="text/event-stream "),
(status = 500, description = "Incomplete generation", body = [ErrorResponse],
example = json!({"error": "Incomplete generation"}),
content_type="text/event-stream "),
)
)]#[instrument(
skip(infer), skip(infer),
fields( fields(
total_time, total_time,
@ -285,11 +321,17 @@ pub async fn run(
StreamResponse, StreamResponse,
StreamDetails, StreamDetails,
ErrorResponse, ErrorResponse,
ErrorType
) )
), ),
tags( tags(
(name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API") (name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
),
info(
title = "Text Generation Inference",
license(
name = "Apache 2.0",
url = "https://www.apache.org/licenses/LICENSE-2.0"
)
) )
)] )]
struct ApiDoc; struct ApiDoc;
@ -361,19 +403,6 @@ impl From<i32> for FinishReason {
} }
} }
impl From<InferError> for ErrorResponse {
fn from(err: InferError) -> Self {
let err_string = err.to_string();
let error = match err {
InferError::GenerationError(_) => ErrorType::GenerationError(err_string),
InferError::Overloaded(_) => ErrorType::Overloaded(err_string),
InferError::ValidationError(_) => ErrorType::ValidationError(err_string),
InferError::IncompleteGeneration => ErrorType::IncompleteGeneration(err_string),
};
ErrorResponse { error }
}
}
/// Convert to Axum supported formats /// Convert to Axum supported formats
impl From<InferError> for (StatusCode, Json<ErrorResponse>) { impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
fn from(err: InferError) -> Self { fn from(err: InferError) -> Self {
@ -384,14 +413,14 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR, InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
}; };
(status_code, Json(ErrorResponse::from(err))) (status_code, Json(ErrorResponse{error: err.to_string()}))
} }
} }
impl From<InferError> for Event { impl From<InferError> for Event {
fn from(err: InferError) -> Self { fn from(err: InferError) -> Self {
Event::default() Event::default()
.json_data(ErrorResponse::from(err)) .json_data(ErrorResponse{error: err.to_string()})
.unwrap() .unwrap()
} }
} }