mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
finalize openAPI schemas
This commit is contained in:
parent
2878c43cc5
commit
109c5af615
@ -1,123 +1,122 @@
|
|||||||
[
|
{
|
||||||
{
|
"details": {
|
||||||
"details": {
|
"finish_reason": "length",
|
||||||
"finish_reason": "length",
|
"generated_tokens": 20,
|
||||||
"generated_tokens": 20,
|
"prefill": [
|
||||||
"prefill": [
|
{
|
||||||
[
|
"id": 10264,
|
||||||
10264,
|
"logprob": null,
|
||||||
"Test",
|
"text": "Test"
|
||||||
null
|
},
|
||||||
],
|
{
|
||||||
[
|
"id": 8821,
|
||||||
8821,
|
"logprob": -11.894989,
|
||||||
" request",
|
"text": " request"
|
||||||
-11.895094
|
}
|
||||||
]
|
],
|
||||||
],
|
"seed": null,
|
||||||
"tokens": [
|
"tokens": [
|
||||||
[
|
{
|
||||||
17,
|
"id": 17,
|
||||||
".",
|
"logprob": -1.8267672,
|
||||||
-1.8267941
|
"text": "."
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
1587,
|
"id": 1587,
|
||||||
"get",
|
"logprob": -2.4674969,
|
||||||
-2.4674964
|
"text": "get"
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
11,
|
"id": 11,
|
||||||
"(",
|
"logprob": -1.906001,
|
||||||
-1.9060438
|
"text": "("
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
5,
|
"id": 5,
|
||||||
"\"",
|
"logprob": -1.2279545,
|
||||||
-1.2279553
|
"text": "\""
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
4899,
|
"id": 4899,
|
||||||
"action",
|
"logprob": -4.170299,
|
||||||
-4.170306
|
"text": "action"
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
5,
|
"id": 5,
|
||||||
"\"",
|
"logprob": -0.32478866,
|
||||||
-0.3247902
|
"text": "\""
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
12,
|
"id": 12,
|
||||||
")",
|
"logprob": -1.0773665,
|
||||||
-1.0773602
|
"text": ")"
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
30,
|
"id": 30,
|
||||||
";",
|
"logprob": -0.27640742,
|
||||||
-0.27640444
|
"text": ";"
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
837,
|
"id": 837,
|
||||||
"\n ",
|
"logprob": -1.6970354,
|
||||||
-1.6970599
|
"text": "\n "
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
1320,
|
"id": 1320,
|
||||||
" if",
|
"logprob": -1.4495516,
|
||||||
-1.4495552
|
"text": " if"
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
375,
|
"id": 375,
|
||||||
" (",
|
"logprob": -0.23609057,
|
||||||
-0.2360998
|
"text": " ("
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
4899,
|
"id": 4899,
|
||||||
"action",
|
"logprob": -1.1916996,
|
||||||
-1.1916926
|
"text": "action"
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
3535,
|
"id": 3535,
|
||||||
" ==",
|
"logprob": -0.8918753,
|
||||||
-0.8918663
|
"text": " =="
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
5109,
|
"id": 5109,
|
||||||
" null",
|
"logprob": -0.3933342,
|
||||||
-0.39334255
|
"text": " null"
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
12,
|
"id": 12,
|
||||||
")",
|
"logprob": -0.43212673,
|
||||||
-0.4321134
|
"text": ")"
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
731,
|
"id": 731,
|
||||||
" {",
|
"logprob": -0.17702064,
|
||||||
-0.17701954
|
"text": " {"
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
1260,
|
"id": 1260,
|
||||||
"\n ",
|
"logprob": -0.07027565,
|
||||||
-0.07027287
|
"text": "\n "
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
10519,
|
"id": 10519,
|
||||||
" throw",
|
"logprob": -1.3915029,
|
||||||
-1.3915133
|
"text": " throw"
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
2084,
|
"id": 2084,
|
||||||
" new",
|
"logprob": -0.04201372,
|
||||||
-0.042013377
|
"text": " new"
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
150858,
|
"id": 150858,
|
||||||
" RuntimeException",
|
"logprob": -1.7329919,
|
||||||
-1.7330077
|
"text": " RuntimeException"
|
||||||
]
|
}
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException"
|
"generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException"
|
||||||
}
|
}
|
||||||
]
|
|
@ -9,11 +9,18 @@ use std::thread::sleep;
|
|||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use subprocess::{Popen, PopenConfig, Redirection};
|
use subprocess::{Popen, PopenConfig, Redirection};
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
pub struct Token {
|
||||||
|
id: u32,
|
||||||
|
text: String,
|
||||||
|
logprob: Option<f32>,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct Details {
|
struct Details {
|
||||||
finish_reason: String,
|
finish_reason: String,
|
||||||
generated_tokens: u32,
|
generated_tokens: u32,
|
||||||
tokens: Vec<(u32, String, Option<f32>)>,
|
tokens: Vec<Token>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
@ -109,8 +116,8 @@ fn read_json(name: &str) -> GeneratedText {
|
|||||||
let file = File::open(d).unwrap();
|
let file = File::open(d).unwrap();
|
||||||
let reader = BufReader::new(file);
|
let reader = BufReader::new(file);
|
||||||
|
|
||||||
let mut results: Vec<GeneratedText> = serde_json::from_reader(reader).unwrap();
|
let result: GeneratedText = serde_json::from_reader(reader).unwrap();
|
||||||
results.pop().unwrap()
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
fn compare_results(result: GeneratedText, expected: GeneratedText) {
|
fn compare_results(result: GeneratedText, expected: GeneratedText) {
|
||||||
@ -127,13 +134,13 @@ fn compare_results(result: GeneratedText, expected: GeneratedText) {
|
|||||||
.into_iter()
|
.into_iter()
|
||||||
.zip(expected.details.tokens.into_iter())
|
.zip(expected.details.tokens.into_iter())
|
||||||
{
|
{
|
||||||
assert_eq!(token.0, expected_token.0);
|
assert_eq!(token.id, expected_token.id);
|
||||||
assert_eq!(token.1, expected_token.1);
|
assert_eq!(token.text, expected_token.text);
|
||||||
if let Some(logprob) = token.2 {
|
if let Some(logprob) = token.logprob {
|
||||||
let expected_logprob = expected_token.2.unwrap();
|
let expected_logprob = expected_token.logprob.unwrap();
|
||||||
assert_float_eq!(logprob, expected_logprob, abs <= 0.001);
|
assert_float_eq!(logprob, expected_logprob, abs <= 0.001);
|
||||||
} else {
|
} else {
|
||||||
assert_eq!(token.2, expected_token.2);
|
assert_eq!(token.logprob, expected_token.logprob);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,118 +1,117 @@
|
|||||||
[
|
{
|
||||||
{
|
"details": {
|
||||||
"details": {
|
"finish_reason": "length",
|
||||||
"finish_reason": "length",
|
"generated_tokens": 20,
|
||||||
"generated_tokens": 20,
|
"prefill": [
|
||||||
"prefill": [
|
{
|
||||||
[
|
"id": 0,
|
||||||
0,
|
"logprob": null,
|
||||||
"<pad>",
|
"text": "<pad>"
|
||||||
null
|
}
|
||||||
]
|
],
|
||||||
],
|
"seed": null,
|
||||||
"tokens": [
|
"tokens": [
|
||||||
[
|
{
|
||||||
259,
|
"id": 259,
|
||||||
"",
|
"logprob": -1.3656927,
|
||||||
-1.3656927
|
"text": ""
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
215100,
|
"id": 215100,
|
||||||
"\"\"\"",
|
"logprob": -2.6551573,
|
||||||
-2.6551573
|
"text": "\"\"\""
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
46138,
|
"id": 46138,
|
||||||
"Test",
|
"logprob": -1.8059857,
|
||||||
-1.8059857
|
"text": "Test"
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
287,
|
"id": 287,
|
||||||
"the",
|
"logprob": -1.2102449,
|
||||||
-1.2102449
|
"text": "the"
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
259,
|
"id": 259,
|
||||||
"",
|
"logprob": -1.6057279,
|
||||||
-1.6057279
|
"text": ""
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
49076,
|
"id": 49076,
|
||||||
"contents",
|
"logprob": -3.6060903,
|
||||||
-3.6060903
|
"text": "contents"
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
304,
|
"id": 304,
|
||||||
"of",
|
"logprob": -0.5270343,
|
||||||
-0.5270343
|
"text": "of"
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
287,
|
"id": 287,
|
||||||
"the",
|
"logprob": -0.62522805,
|
||||||
-0.62522805
|
"text": "the"
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
259,
|
"id": 259,
|
||||||
"",
|
"logprob": -1.4069618,
|
||||||
-1.4069618
|
"text": ""
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
49076,
|
"id": 49076,
|
||||||
"contents",
|
"logprob": -2.621994,
|
||||||
-2.621994
|
"text": "contents"
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
304,
|
"id": 304,
|
||||||
"of",
|
"logprob": -1.3172221,
|
||||||
-1.3172221
|
"text": "of"
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
287,
|
"id": 287,
|
||||||
"the",
|
"logprob": -0.3501925,
|
||||||
-0.3501925
|
"text": "the"
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
259,
|
"id": 259,
|
||||||
"",
|
"logprob": -0.7219573,
|
||||||
-0.7219573
|
"text": ""
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
49076,
|
"id": 49076,
|
||||||
"contents",
|
"logprob": -1.0494149,
|
||||||
-1.0494149
|
"text": "contents"
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
260,
|
"id": 260,
|
||||||
".",
|
"logprob": -1.0803378,
|
||||||
-1.0803378
|
"text": "."
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
259,
|
"id": 259,
|
||||||
"",
|
"logprob": -0.32933083,
|
||||||
-0.32933083
|
"text": ""
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
215100,
|
"id": 215100,
|
||||||
"\"\"\"",
|
"logprob": -0.11268901,
|
||||||
-0.11268901
|
"text": "\"\"\""
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
2978,
|
"id": 2978,
|
||||||
"test",
|
"logprob": -1.5846587,
|
||||||
-1.5846587
|
"text": "test"
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
290,
|
"id": 290,
|
||||||
"_",
|
"logprob": -0.49796978,
|
||||||
-0.49796978
|
"text": "_"
|
||||||
],
|
},
|
||||||
[
|
{
|
||||||
4125,
|
"id": 4125,
|
||||||
"test",
|
"logprob": -2.0026445,
|
||||||
-2.0026445
|
"text": "test"
|
||||||
]
|
}
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test"
|
"generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test"
|
||||||
}
|
}
|
||||||
]
|
|
@ -22,7 +22,12 @@ pub(crate) struct GenerateParameters {
|
|||||||
#[schema(exclusive_minimum = 0, nullable = true, default = "null")]
|
#[schema(exclusive_minimum = 0, nullable = true, default = "null")]
|
||||||
pub top_k: Option<i32>,
|
pub top_k: Option<i32>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(exclusive_minimum = 0.0, maximum = 1.0, nullable = true, default = "null")]
|
#[schema(
|
||||||
|
exclusive_minimum = 0.0,
|
||||||
|
maximum = 1.0,
|
||||||
|
nullable = true,
|
||||||
|
default = "null"
|
||||||
|
)]
|
||||||
pub top_p: Option<f32>,
|
pub top_p: Option<f32>,
|
||||||
#[serde(default = "default_do_sample")]
|
#[serde(default = "default_do_sample")]
|
||||||
#[schema(default = "false")]
|
#[schema(default = "false")]
|
||||||
@ -31,7 +36,7 @@ pub(crate) struct GenerateParameters {
|
|||||||
#[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
|
#[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
|
||||||
pub max_new_tokens: u32,
|
pub max_new_tokens: u32,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(max_items = 4, default = "null")]
|
#[schema(inline, max_items = 4, example = json!(["photographer"]))]
|
||||||
pub stop: Vec<String>,
|
pub stop: Vec<String>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(default = "true")]
|
#[schema(default = "true")]
|
||||||
@ -72,22 +77,33 @@ pub(crate) struct GenerateRequest {
|
|||||||
|
|
||||||
#[derive(Debug, Serialize, ToSchema)]
|
#[derive(Debug, Serialize, ToSchema)]
|
||||||
pub struct Token {
|
pub struct Token {
|
||||||
|
#[schema(example = 0)]
|
||||||
id: u32,
|
id: u32,
|
||||||
|
#[schema(example = "test")]
|
||||||
text: String,
|
text: String,
|
||||||
|
#[schema(nullable = true, example = -0.34)]
|
||||||
logprob: f32,
|
logprob: f32,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, ToSchema)]
|
#[derive(Serialize, ToSchema)]
|
||||||
|
#[serde(rename_all(serialize = "snake_case"))]
|
||||||
pub(crate) enum FinishReason {
|
pub(crate) enum FinishReason {
|
||||||
|
#[schema(rename = "length")]
|
||||||
Length,
|
Length,
|
||||||
|
#[serde(rename = "eos_token")]
|
||||||
|
#[schema(rename = "eos_token")]
|
||||||
EndOfSequenceToken,
|
EndOfSequenceToken,
|
||||||
|
#[schema(rename = "stop_sequence")]
|
||||||
StopSequence,
|
StopSequence,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, ToSchema)]
|
#[derive(Serialize, ToSchema)]
|
||||||
pub(crate) struct Details {
|
pub(crate) struct Details {
|
||||||
|
#[schema(example = "length")]
|
||||||
pub finish_reason: FinishReason,
|
pub finish_reason: FinishReason,
|
||||||
|
#[schema(example = 1)]
|
||||||
pub generated_tokens: u32,
|
pub generated_tokens: u32,
|
||||||
|
#[schema(example = 42)]
|
||||||
pub seed: Option<u64>,
|
pub seed: Option<u64>,
|
||||||
pub prefill: Option<Vec<Token>>,
|
pub prefill: Option<Vec<Token>>,
|
||||||
pub tokens: Option<Vec<Token>>,
|
pub tokens: Option<Vec<Token>>,
|
||||||
@ -95,6 +111,7 @@ pub(crate) struct Details {
|
|||||||
|
|
||||||
#[derive(Serialize, ToSchema)]
|
#[derive(Serialize, ToSchema)]
|
||||||
pub(crate) struct GenerateResponse {
|
pub(crate) struct GenerateResponse {
|
||||||
|
#[schema(example = "test")]
|
||||||
pub generated_text: String,
|
pub generated_text: String,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub details: Option<Details>,
|
pub details: Option<Details>,
|
||||||
@ -102,31 +119,25 @@ pub(crate) struct GenerateResponse {
|
|||||||
|
|
||||||
#[derive(Serialize, ToSchema)]
|
#[derive(Serialize, ToSchema)]
|
||||||
pub(crate) struct StreamDetails {
|
pub(crate) struct StreamDetails {
|
||||||
|
#[schema(example = "length")]
|
||||||
pub finish_reason: FinishReason,
|
pub finish_reason: FinishReason,
|
||||||
|
#[schema(example = 1)]
|
||||||
pub generated_tokens: u32,
|
pub generated_tokens: u32,
|
||||||
|
#[schema(example = 42)]
|
||||||
pub seed: Option<u64>,
|
pub seed: Option<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, ToSchema)]
|
#[derive(Serialize, ToSchema)]
|
||||||
pub(crate) struct StreamResponse {
|
pub(crate) struct StreamResponse {
|
||||||
pub token: Token,
|
pub token: Token,
|
||||||
|
#[schema(nullable = true, default = "null", example = "test")]
|
||||||
pub generated_text: Option<String>,
|
pub generated_text: Option<String>,
|
||||||
|
#[schema(nullable = true, default = "null")]
|
||||||
pub details: Option<StreamDetails>,
|
pub details: Option<StreamDetails>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, ToSchema)]
|
|
||||||
pub(crate) enum ErrorType {
|
|
||||||
#[schema(example = "Request failed during generation")]
|
|
||||||
GenerationError(String),
|
|
||||||
#[schema(example = "Model is overloaded")]
|
|
||||||
Overloaded(String),
|
|
||||||
#[schema(example = "Input validation error")]
|
|
||||||
ValidationError(String),
|
|
||||||
#[schema(example = "Incomplete generation")]
|
|
||||||
IncompleteGeneration(String),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, ToSchema)]
|
#[derive(Serialize, ToSchema)]
|
||||||
pub(crate) struct ErrorResponse {
|
pub(crate) struct ErrorResponse {
|
||||||
pub error: ErrorType,
|
#[schema(inline)]
|
||||||
|
pub error: String,
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
/// HTTP Server logic
|
/// HTTP Server logic
|
||||||
use crate::infer::{InferError, InferStreamResponse};
|
use crate::infer::{InferError, InferStreamResponse};
|
||||||
use crate::{
|
use crate::{
|
||||||
Details, ErrorResponse, ErrorType, FinishReason, GenerateParameters, GenerateRequest,
|
Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
|
||||||
GenerateResponse, Infer, StreamDetails, StreamResponse, Token, Validation,
|
GenerateResponse, Infer, StreamDetails, StreamResponse, Token, Validation,
|
||||||
};
|
};
|
||||||
use axum::extract::Extension;
|
use axum::extract::Extension;
|
||||||
@ -49,8 +49,24 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate method
|
/// Generate tokens
|
||||||
#[utoipa::path(post, path = "/generate", request_body = GenerateRequest)]
|
#[utoipa::path(
|
||||||
|
post,
|
||||||
|
tag = "Text Generation Inference",
|
||||||
|
path = "/generate",
|
||||||
|
request_body = GenerateRequest,
|
||||||
|
responses(
|
||||||
|
(status = 200, description = "Generated Text", body = [GenerateResponse]),
|
||||||
|
(status = 424, description = "Generation Error", body = [ErrorResponse],
|
||||||
|
example = json!({"error": "Request failed during generation"})),
|
||||||
|
(status = 429, description = "Model is overloaded", body = [ErrorResponse],
|
||||||
|
example = json!({"error": "Model is overloaded"})),
|
||||||
|
(status = 422, description = "Input validation error", body = [ErrorResponse],
|
||||||
|
example = json!({"error": "Input validation error"})),
|
||||||
|
(status = 500, description = "Incomplete generation", body = [ErrorResponse],
|
||||||
|
example = json!({"error": "Incomplete generation"})),
|
||||||
|
)
|
||||||
|
)]
|
||||||
#[instrument(
|
#[instrument(
|
||||||
skip(infer),
|
skip(infer),
|
||||||
fields(
|
fields(
|
||||||
@ -135,9 +151,29 @@ async fn generate(
|
|||||||
Ok((headers, Json(response)))
|
Ok((headers, Json(response)))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate stream method
|
/// Generate a stream of token using Server Side Events
|
||||||
#[utoipa::path(post, path = "/generate_stream")]
|
#[utoipa::path(
|
||||||
#[instrument(
|
post,
|
||||||
|
tag = "Text Generation Inference",
|
||||||
|
path = "/generate_stream",
|
||||||
|
request_body = GenerateRequest,
|
||||||
|
responses(
|
||||||
|
(status = 200, description = "Generated Text", body = [StreamResponse],
|
||||||
|
content_type="text/event-stream "),
|
||||||
|
(status = 424, description = "Generation Error", body = [ErrorResponse],
|
||||||
|
example = json!({"error": "Request failed during generation"}),
|
||||||
|
content_type="text/event-stream "),
|
||||||
|
(status = 429, description = "Model is overloaded", body = [ErrorResponse],
|
||||||
|
example = json!({"error": "Model is overloaded"}),
|
||||||
|
content_type="text/event-stream "),
|
||||||
|
(status = 422, description = "Input validation error", body = [ErrorResponse],
|
||||||
|
example = json!({"error": "Input validation error"}),
|
||||||
|
content_type="text/event-stream "),
|
||||||
|
(status = 500, description = "Incomplete generation", body = [ErrorResponse],
|
||||||
|
example = json!({"error": "Incomplete generation"}),
|
||||||
|
content_type="text/event-stream "),
|
||||||
|
)
|
||||||
|
)]#[instrument(
|
||||||
skip(infer),
|
skip(infer),
|
||||||
fields(
|
fields(
|
||||||
total_time,
|
total_time,
|
||||||
@ -285,11 +321,17 @@ pub async fn run(
|
|||||||
StreamResponse,
|
StreamResponse,
|
||||||
StreamDetails,
|
StreamDetails,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
ErrorType
|
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
tags(
|
tags(
|
||||||
(name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
|
(name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
|
||||||
|
),
|
||||||
|
info(
|
||||||
|
title = "Text Generation Inference",
|
||||||
|
license(
|
||||||
|
name = "Apache 2.0",
|
||||||
|
url = "https://www.apache.org/licenses/LICENSE-2.0"
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
struct ApiDoc;
|
struct ApiDoc;
|
||||||
@ -361,19 +403,6 @@ impl From<i32> for FinishReason {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<InferError> for ErrorResponse {
|
|
||||||
fn from(err: InferError) -> Self {
|
|
||||||
let err_string = err.to_string();
|
|
||||||
let error = match err {
|
|
||||||
InferError::GenerationError(_) => ErrorType::GenerationError(err_string),
|
|
||||||
InferError::Overloaded(_) => ErrorType::Overloaded(err_string),
|
|
||||||
InferError::ValidationError(_) => ErrorType::ValidationError(err_string),
|
|
||||||
InferError::IncompleteGeneration => ErrorType::IncompleteGeneration(err_string),
|
|
||||||
};
|
|
||||||
ErrorResponse { error }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Convert to Axum supported formats
|
/// Convert to Axum supported formats
|
||||||
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
|
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
|
||||||
fn from(err: InferError) -> Self {
|
fn from(err: InferError) -> Self {
|
||||||
@ -384,14 +413,14 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
|
|||||||
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
|
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
};
|
};
|
||||||
|
|
||||||
(status_code, Json(ErrorResponse::from(err)))
|
(status_code, Json(ErrorResponse{error: err.to_string()}))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<InferError> for Event {
|
impl From<InferError> for Event {
|
||||||
fn from(err: InferError) -> Self {
|
fn from(err: InferError) -> Self {
|
||||||
Event::default()
|
Event::default()
|
||||||
.json_data(ErrorResponse::from(err))
|
.json_data(ErrorResponse{error: err.to_string()})
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user