finalize openAPI schemas

This commit is contained in:
OlivierDehaene 2023-02-02 18:37:07 +01:00
parent 2878c43cc5
commit 109c5af615
5 changed files with 331 additions and 286 deletions

View File

@ -1,123 +1,122 @@
[
{
"details": {
"finish_reason": "length",
"generated_tokens": 20,
"prefill": [
[
10264,
"Test",
null
],
[
8821,
" request",
-11.895094
]
{
"id": 10264,
"logprob": null,
"text": "Test"
},
{
"id": 8821,
"logprob": -11.894989,
"text": " request"
}
],
"seed": null,
"tokens": [
[
17,
".",
-1.8267941
],
[
1587,
"get",
-2.4674964
],
[
11,
"(",
-1.9060438
],
[
5,
"\"",
-1.2279553
],
[
4899,
"action",
-4.170306
],
[
5,
"\"",
-0.3247902
],
[
12,
")",
-1.0773602
],
[
30,
";",
-0.27640444
],
[
837,
"\n ",
-1.6970599
],
[
1320,
" if",
-1.4495552
],
[
375,
" (",
-0.2360998
],
[
4899,
"action",
-1.1916926
],
[
3535,
" ==",
-0.8918663
],
[
5109,
" null",
-0.39334255
],
[
12,
")",
-0.4321134
],
[
731,
" {",
-0.17701954
],
[
1260,
"\n ",
-0.07027287
],
[
10519,
" throw",
-1.3915133
],
[
2084,
" new",
-0.042013377
],
[
150858,
" RuntimeException",
-1.7330077
]
{
"id": 17,
"logprob": -1.8267672,
"text": "."
},
{
"id": 1587,
"logprob": -2.4674969,
"text": "get"
},
{
"id": 11,
"logprob": -1.906001,
"text": "("
},
{
"id": 5,
"logprob": -1.2279545,
"text": "\""
},
{
"id": 4899,
"logprob": -4.170299,
"text": "action"
},
{
"id": 5,
"logprob": -0.32478866,
"text": "\""
},
{
"id": 12,
"logprob": -1.0773665,
"text": ")"
},
{
"id": 30,
"logprob": -0.27640742,
"text": ";"
},
{
"id": 837,
"logprob": -1.6970354,
"text": "\n "
},
{
"id": 1320,
"logprob": -1.4495516,
"text": " if"
},
{
"id": 375,
"logprob": -0.23609057,
"text": " ("
},
{
"id": 4899,
"logprob": -1.1916996,
"text": "action"
},
{
"id": 3535,
"logprob": -0.8918753,
"text": " =="
},
{
"id": 5109,
"logprob": -0.3933342,
"text": " null"
},
{
"id": 12,
"logprob": -0.43212673,
"text": ")"
},
{
"id": 731,
"logprob": -0.17702064,
"text": " {"
},
{
"id": 1260,
"logprob": -0.07027565,
"text": "\n "
},
{
"id": 10519,
"logprob": -1.3915029,
"text": " throw"
},
{
"id": 2084,
"logprob": -0.04201372,
"text": " new"
},
{
"id": 150858,
"logprob": -1.7329919,
"text": " RuntimeException"
}
]
},
"generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException"
}
]

View File

@ -9,11 +9,18 @@ use std::thread::sleep;
use std::time::Duration;
use subprocess::{Popen, PopenConfig, Redirection};
#[derive(Deserialize)]
pub struct Token {
id: u32,
text: String,
logprob: Option<f32>,
}
#[derive(Deserialize)]
struct Details {
finish_reason: String,
generated_tokens: u32,
tokens: Vec<(u32, String, Option<f32>)>,
tokens: Vec<Token>,
}
#[derive(Deserialize)]
@ -109,8 +116,8 @@ fn read_json(name: &str) -> GeneratedText {
let file = File::open(d).unwrap();
let reader = BufReader::new(file);
let mut results: Vec<GeneratedText> = serde_json::from_reader(reader).unwrap();
results.pop().unwrap()
let result: GeneratedText = serde_json::from_reader(reader).unwrap();
result
}
fn compare_results(result: GeneratedText, expected: GeneratedText) {
@ -127,13 +134,13 @@ fn compare_results(result: GeneratedText, expected: GeneratedText) {
.into_iter()
.zip(expected.details.tokens.into_iter())
{
assert_eq!(token.0, expected_token.0);
assert_eq!(token.1, expected_token.1);
if let Some(logprob) = token.2 {
let expected_logprob = expected_token.2.unwrap();
assert_eq!(token.id, expected_token.id);
assert_eq!(token.text, expected_token.text);
if let Some(logprob) = token.logprob {
let expected_logprob = expected_token.logprob.unwrap();
assert_float_eq!(logprob, expected_logprob, abs <= 0.001);
} else {
assert_eq!(token.2, expected_token.2);
assert_eq!(token.logprob, expected_token.logprob);
}
}
}

View File

@ -1,118 +1,117 @@
[
{
"details": {
"finish_reason": "length",
"generated_tokens": 20,
"prefill": [
[
0,
"<pad>",
null
]
{
"id": 0,
"logprob": null,
"text": "<pad>"
}
],
"seed": null,
"tokens": [
[
259,
"",
-1.3656927
],
[
215100,
"\"\"\"",
-2.6551573
],
[
46138,
"Test",
-1.8059857
],
[
287,
"the",
-1.2102449
],
[
259,
"",
-1.6057279
],
[
49076,
"contents",
-3.6060903
],
[
304,
"of",
-0.5270343
],
[
287,
"the",
-0.62522805
],
[
259,
"",
-1.4069618
],
[
49076,
"contents",
-2.621994
],
[
304,
"of",
-1.3172221
],
[
287,
"the",
-0.3501925
],
[
259,
"",
-0.7219573
],
[
49076,
"contents",
-1.0494149
],
[
260,
".",
-1.0803378
],
[
259,
"",
-0.32933083
],
[
215100,
"\"\"\"",
-0.11268901
],
[
2978,
"test",
-1.5846587
],
[
290,
"_",
-0.49796978
],
[
4125,
"test",
-2.0026445
]
{
"id": 259,
"logprob": -1.3656927,
"text": ""
},
{
"id": 215100,
"logprob": -2.6551573,
"text": "\"\"\""
},
{
"id": 46138,
"logprob": -1.8059857,
"text": "Test"
},
{
"id": 287,
"logprob": -1.2102449,
"text": "the"
},
{
"id": 259,
"logprob": -1.6057279,
"text": ""
},
{
"id": 49076,
"logprob": -3.6060903,
"text": "contents"
},
{
"id": 304,
"logprob": -0.5270343,
"text": "of"
},
{
"id": 287,
"logprob": -0.62522805,
"text": "the"
},
{
"id": 259,
"logprob": -1.4069618,
"text": ""
},
{
"id": 49076,
"logprob": -2.621994,
"text": "contents"
},
{
"id": 304,
"logprob": -1.3172221,
"text": "of"
},
{
"id": 287,
"logprob": -0.3501925,
"text": "the"
},
{
"id": 259,
"logprob": -0.7219573,
"text": ""
},
{
"id": 49076,
"logprob": -1.0494149,
"text": "contents"
},
{
"id": 260,
"logprob": -1.0803378,
"text": "."
},
{
"id": 259,
"logprob": -0.32933083,
"text": ""
},
{
"id": 215100,
"logprob": -0.11268901,
"text": "\"\"\""
},
{
"id": 2978,
"logprob": -1.5846587,
"text": "test"
},
{
"id": 290,
"logprob": -0.49796978,
"text": "_"
},
{
"id": 4125,
"logprob": -2.0026445,
"text": "test"
}
]
},
"generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test"
}
]

View File

@ -22,7 +22,12 @@ pub(crate) struct GenerateParameters {
#[schema(exclusive_minimum = 0, nullable = true, default = "null")]
pub top_k: Option<i32>,
#[serde(default)]
#[schema(exclusive_minimum = 0.0, maximum = 1.0, nullable = true, default = "null")]
#[schema(
exclusive_minimum = 0.0,
maximum = 1.0,
nullable = true,
default = "null"
)]
pub top_p: Option<f32>,
#[serde(default = "default_do_sample")]
#[schema(default = "false")]
@ -31,7 +36,7 @@ pub(crate) struct GenerateParameters {
#[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
pub max_new_tokens: u32,
#[serde(default)]
#[schema(max_items = 4, default = "null")]
#[schema(inline, max_items = 4, example = json!(["photographer"]))]
pub stop: Vec<String>,
#[serde(default)]
#[schema(default = "true")]
@ -72,22 +77,33 @@ pub(crate) struct GenerateRequest {
#[derive(Debug, Serialize, ToSchema)]
pub struct Token {
#[schema(example = 0)]
id: u32,
#[schema(example = "test")]
text: String,
#[schema(nullable = true, example = -0.34)]
logprob: f32,
}
#[derive(Serialize, ToSchema)]
#[serde(rename_all(serialize = "snake_case"))]
pub(crate) enum FinishReason {
#[schema(rename = "length")]
Length,
#[serde(rename = "eos_token")]
#[schema(rename = "eos_token")]
EndOfSequenceToken,
#[schema(rename = "stop_sequence")]
StopSequence,
}
#[derive(Serialize, ToSchema)]
pub(crate) struct Details {
#[schema(example = "length")]
pub finish_reason: FinishReason,
#[schema(example = 1)]
pub generated_tokens: u32,
#[schema(example = 42)]
pub seed: Option<u64>,
pub prefill: Option<Vec<Token>>,
pub tokens: Option<Vec<Token>>,
@ -95,6 +111,7 @@ pub(crate) struct Details {
#[derive(Serialize, ToSchema)]
pub(crate) struct GenerateResponse {
#[schema(example = "test")]
pub generated_text: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub details: Option<Details>,
@ -102,31 +119,25 @@ pub(crate) struct GenerateResponse {
#[derive(Serialize, ToSchema)]
pub(crate) struct StreamDetails {
#[schema(example = "length")]
pub finish_reason: FinishReason,
#[schema(example = 1)]
pub generated_tokens: u32,
#[schema(example = 42)]
pub seed: Option<u64>,
}
#[derive(Serialize, ToSchema)]
pub(crate) struct StreamResponse {
pub token: Token,
#[schema(nullable = true, default = "null", example = "test")]
pub generated_text: Option<String>,
#[schema(nullable = true, default = "null")]
pub details: Option<StreamDetails>,
}
#[derive(Serialize, ToSchema)]
pub(crate) enum ErrorType {
#[schema(example = "Request failed during generation")]
GenerationError(String),
#[schema(example = "Model is overloaded")]
Overloaded(String),
#[schema(example = "Input validation error")]
ValidationError(String),
#[schema(example = "Incomplete generation")]
IncompleteGeneration(String),
}
#[derive(Serialize, ToSchema)]
pub(crate) struct ErrorResponse {
pub error: ErrorType,
#[schema(inline)]
pub error: String,
}

View File

@ -1,7 +1,7 @@
/// HTTP Server logic
use crate::infer::{InferError, InferStreamResponse};
use crate::{
Details, ErrorResponse, ErrorType, FinishReason, GenerateParameters, GenerateRequest,
Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
GenerateResponse, Infer, StreamDetails, StreamResponse, Token, Validation,
};
use axum::extract::Extension;
@ -49,8 +49,24 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
Ok(())
}
/// Generate method
#[utoipa::path(post, path = "/generate", request_body = GenerateRequest)]
/// Generate tokens
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/generate",
request_body = GenerateRequest,
responses(
(status = 200, description = "Generated Text", body = [GenerateResponse]),
(status = 424, description = "Generation Error", body = [ErrorResponse],
example = json!({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = [ErrorResponse],
example = json!({"error": "Model is overloaded"})),
(status = 422, description = "Input validation error", body = [ErrorResponse],
example = json!({"error": "Input validation error"})),
(status = 500, description = "Incomplete generation", body = [ErrorResponse],
example = json!({"error": "Incomplete generation"})),
)
)]
#[instrument(
skip(infer),
fields(
@ -135,9 +151,29 @@ async fn generate(
Ok((headers, Json(response)))
}
/// Generate stream method
#[utoipa::path(post, path = "/generate_stream")]
#[instrument(
/// Generate a stream of token using Server Side Events
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/generate_stream",
request_body = GenerateRequest,
responses(
(status = 200, description = "Generated Text", body = [StreamResponse],
content_type="text/event-stream "),
(status = 424, description = "Generation Error", body = [ErrorResponse],
example = json!({"error": "Request failed during generation"}),
content_type="text/event-stream "),
(status = 429, description = "Model is overloaded", body = [ErrorResponse],
example = json!({"error": "Model is overloaded"}),
content_type="text/event-stream "),
(status = 422, description = "Input validation error", body = [ErrorResponse],
example = json!({"error": "Input validation error"}),
content_type="text/event-stream "),
(status = 500, description = "Incomplete generation", body = [ErrorResponse],
example = json!({"error": "Incomplete generation"}),
content_type="text/event-stream "),
)
)]#[instrument(
skip(infer),
fields(
total_time,
@ -285,11 +321,17 @@ pub async fn run(
StreamResponse,
StreamDetails,
ErrorResponse,
ErrorType
)
),
tags(
(name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
),
info(
title = "Text Generation Inference",
license(
name = "Apache 2.0",
url = "https://www.apache.org/licenses/LICENSE-2.0"
)
)
)]
struct ApiDoc;
@ -361,19 +403,6 @@ impl From<i32> for FinishReason {
}
}
impl From<InferError> for ErrorResponse {
fn from(err: InferError) -> Self {
let err_string = err.to_string();
let error = match err {
InferError::GenerationError(_) => ErrorType::GenerationError(err_string),
InferError::Overloaded(_) => ErrorType::Overloaded(err_string),
InferError::ValidationError(_) => ErrorType::ValidationError(err_string),
InferError::IncompleteGeneration => ErrorType::IncompleteGeneration(err_string),
};
ErrorResponse { error }
}
}
/// Convert to Axum supported formats
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
fn from(err: InferError) -> Self {
@ -384,14 +413,14 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
};
(status_code, Json(ErrorResponse::from(err)))
(status_code, Json(ErrorResponse{error: err.to_string()}))
}
}
impl From<InferError> for Event {
fn from(err: InferError) -> Self {
Event::default()
.json_data(ErrorResponse::from(err))
.json_data(ErrorResponse{error: err.to_string()})
.unwrap()
}
}