mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
feat(router): add openAPI schemas
This commit is contained in:
parent
b1482d9048
commit
2878c43cc5
183
Cargo.lock
generated
183
Cargo.lock
generated
@ -83,9 +83,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
|
||||
|
||||
[[package]]
|
||||
name = "axum"
|
||||
version = "0.5.17"
|
||||
version = "0.6.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "acee9fd5073ab6b045a275b3e709c163dd36c90685219cb21804a147b58dba43"
|
||||
checksum = "e5694b64066a2459918d8074c2ce0d5a88f409431994c2356617c8ae0c4721fc"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"axum-core",
|
||||
@ -101,8 +101,10 @@ dependencies = [
|
||||
"mime",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"rustversion",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_path_to_error",
|
||||
"serde_urlencoded",
|
||||
"sync_wrapper",
|
||||
"tokio",
|
||||
@ -114,9 +116,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "axum-core"
|
||||
version = "0.2.9"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "37e5939e02c56fecd5c017c37df4238c0a839fa76b7f97acdd7efb804fd181cc"
|
||||
checksum = "1cae3e661676ffbacb30f1a824089a8c9150e71017f7e1e38f2aa32009188d34"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"bytes",
|
||||
@ -124,6 +126,7 @@ dependencies = [
|
||||
"http",
|
||||
"http-body",
|
||||
"mime",
|
||||
"rustversion",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
]
|
||||
@ -207,7 +210,7 @@ dependencies = [
|
||||
"tar",
|
||||
"tempfile",
|
||||
"thiserror",
|
||||
"zip",
|
||||
"zip 0.5.13",
|
||||
"zip-extensions",
|
||||
]
|
||||
|
||||
@ -465,6 +468,15 @@ dependencies = [
|
||||
"dirs-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dirs"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ca3aa72a6f96ea37bbc5aa912f6788242832f75369bdfdadcb0e38423f100059"
|
||||
dependencies = [
|
||||
"dirs-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dirs-sys"
|
||||
version = "0.3.7"
|
||||
@ -867,6 +879,7 @@ checksum = "10a35a97730320ffe8e2d410b5d3b69279b98d2c14bdb8b70ea89ecf7888d41e"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"hashbrown",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -999,9 +1012,9 @@ checksum = "58093314a45e00c77d5c508f76e77c3396afbbc0d01506e7fae47b018bac2b1d"
|
||||
|
||||
[[package]]
|
||||
name = "matchit"
|
||||
version = "0.5.0"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb"
|
||||
checksum = "b87248edafb776e59e6ee64a79086f65890d3510f2c656c000bf2a7e8a0aea40"
|
||||
|
||||
[[package]]
|
||||
name = "memchr"
|
||||
@ -1024,6 +1037,16 @@ version = "0.3.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2a60c7ce501c71e03a9c9c0d35b861413ae925bd979cc7a4e30d060069aaac8d"
|
||||
|
||||
[[package]]
|
||||
name = "mime_guess"
|
||||
version = "2.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef"
|
||||
dependencies = [
|
||||
"mime",
|
||||
"unicase",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "minimal-lexical"
|
||||
version = "0.2.1"
|
||||
@ -1552,12 +1575,62 @@ dependencies = [
|
||||
"winreg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rust-embed"
|
||||
version = "6.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "283ffe2f866869428c92e0d61c2f35dfb4355293cdfdc48f49e895c15f1333d1"
|
||||
dependencies = [
|
||||
"rust-embed-impl",
|
||||
"rust-embed-utils",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rust-embed-impl"
|
||||
version = "6.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "31ab23d42d71fb9be1b643fe6765d292c5e14d46912d13f3ae2815ca048ea04d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"rust-embed-utils",
|
||||
"shellexpand",
|
||||
"syn",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rust-embed-utils"
|
||||
version = "7.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c1669d81dfabd1b5f8e2856b8bbe146c6192b0ba22162edc738ac0a5de18f054"
|
||||
dependencies = [
|
||||
"sha2",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustversion"
|
||||
version = "1.0.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5583e89e108996506031660fe09baa5011b9dd0341b89029313006d1fb508d70"
|
||||
|
||||
[[package]]
|
||||
name = "ryu"
|
||||
version = "1.0.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4501abdff3ae82a1c1b477a17252eb69cee9e66eb915c1abaa4f44d873df9f09"
|
||||
|
||||
[[package]]
|
||||
name = "same-file"
|
||||
version = "1.0.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
|
||||
dependencies = [
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "schannel"
|
||||
version = "0.1.20"
|
||||
@ -1628,6 +1701,15 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_path_to_error"
|
||||
version = "0.1.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "26b04f22b563c91331a10074bda3dd5492e3cc39d56bd557e91c0af42b6c7341"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_urlencoded"
|
||||
version = "0.7.1"
|
||||
@ -1660,6 +1742,15 @@ dependencies = [
|
||||
"lazy_static",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "shellexpand"
|
||||
version = "2.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7ccc8076840c4da029af4f87e4e8daeb0fca6b87bbb02e10cb60b791450e11e4"
|
||||
dependencies = [
|
||||
"dirs 4.0.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "signal-hook-registry"
|
||||
version = "1.4.0"
|
||||
@ -1845,6 +1936,8 @@ dependencies = [
|
||||
"tokio-stream",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"utoipa",
|
||||
"utoipa-swagger-ui",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -1921,7 +2014,7 @@ dependencies = [
|
||||
"cached-path",
|
||||
"clap 2.34.0",
|
||||
"derive_builder",
|
||||
"dirs",
|
||||
"dirs 3.0.2",
|
||||
"esaxx-rs",
|
||||
"getrandom",
|
||||
"indicatif 0.15.0",
|
||||
@ -2234,6 +2327,15 @@ version = "1.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987"
|
||||
|
||||
[[package]]
|
||||
name = "unicase"
|
||||
version = "2.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "50f37be617794602aabbeee0be4f259dc1778fabe05e2d67ee8f79326d5cb4f6"
|
||||
dependencies = [
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unicode-bidi"
|
||||
version = "0.3.8"
|
||||
@ -2293,6 +2395,46 @@ dependencies = [
|
||||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "utoipa"
|
||||
version = "3.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3920fa753064b1be7842bea26175ffa0dfc4a8f30bcb52b8ff03fddf8889914c"
|
||||
dependencies = [
|
||||
"indexmap",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"utoipa-gen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "utoipa-gen"
|
||||
version = "3.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "720298fac6efca20df9e457e67a1eab41a20d1c3101380b5c4dca1ca60ae0062"
|
||||
dependencies = [
|
||||
"proc-macro-error",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "utoipa-swagger-ui"
|
||||
version = "3.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ae3d4f4da6408f0f20ff58196ed619c94306ab32635aeca3d3fa0768c0bd0de2"
|
||||
dependencies = [
|
||||
"axum",
|
||||
"mime_guess",
|
||||
"regex",
|
||||
"rust-embed",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"utoipa",
|
||||
"zip 0.6.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "valuable"
|
||||
version = "0.1.0"
|
||||
@ -2317,6 +2459,17 @@ version = "0.9.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
|
||||
|
||||
[[package]]
|
||||
name = "walkdir"
|
||||
version = "2.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "808cf2735cd4b6866113f648b791c6adc5714537bc222d9347bb203386ffda56"
|
||||
dependencies = [
|
||||
"same-file",
|
||||
"winapi",
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "want"
|
||||
version = "0.3.0"
|
||||
@ -2589,11 +2742,23 @@ dependencies = [
|
||||
"time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zip"
|
||||
version = "0.6.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0445d0fbc924bb93539b4316c11afb121ea39296f99a3c4c9edad09e3658cdef"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"crc32fast",
|
||||
"crossbeam-utils",
|
||||
"flate2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zip-extensions"
|
||||
version = "0.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a64c3c977bc3434ce2d4bcea8ad3c644672de0f2c402b72b9171ca80a8885d14"
|
||||
dependencies = [
|
||||
"zip",
|
||||
"zip 0.5.13",
|
||||
]
|
||||
|
@ -71,13 +71,19 @@ message Batch {
|
||||
uint32 size = 3;
|
||||
}
|
||||
|
||||
enum FinishReason {
|
||||
FINISH_REASON_LENGTH = 0;
|
||||
FINISH_REASON_EOS_TOKEN = 1;
|
||||
FINISH_REASON_STOP_SEQUENCE = 2;
|
||||
}
|
||||
|
||||
message GeneratedText {
|
||||
/// Output
|
||||
string text = 1;
|
||||
/// Number of generated tokens
|
||||
uint32 generated_tokens = 2;
|
||||
/// Finish reason
|
||||
string finish_reason = 3;
|
||||
FinishReason finish_reason = 3;
|
||||
/// Seed
|
||||
optional uint64 seed = 4;
|
||||
}
|
||||
|
@ -14,7 +14,7 @@ path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
async-stream = "0.3.3"
|
||||
axum = { version = "0.5.16", features = ["json", "serde_json"] }
|
||||
axum = { version = "0.6.4", features = ["json"] }
|
||||
text-generation-client = { path = "client" }
|
||||
clap = { version = "4.0.15", features = ["derive", "env"] }
|
||||
futures = "0.3.24"
|
||||
@ -29,4 +29,6 @@ tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot"
|
||||
tokio-stream = "0.1.11"
|
||||
tracing = "0.1.36"
|
||||
tracing-subscriber = { version = "0.3.15", features = ["json"] }
|
||||
utoipa = { version = "3.0.1", features = ["axum_extras"] }
|
||||
utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }
|
||||
|
||||
|
@ -7,8 +7,8 @@ mod sharded_client;
|
||||
|
||||
pub use client::Client;
|
||||
pub use pb::generate::v1::{
|
||||
Batch, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens, Request,
|
||||
StoppingCriteriaParameters,
|
||||
Batch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens,
|
||||
Request, StoppingCriteriaParameters,
|
||||
};
|
||||
pub use sharded_client::ShardedClient;
|
||||
use thiserror::Error;
|
||||
|
@ -127,7 +127,7 @@ impl Infer {
|
||||
.into_iter()
|
||||
.zip(tokens.logprobs.into_iter())
|
||||
.zip(tokens.texts.into_iter())
|
||||
.map(|((id, logprob), text)| Token(id, text, logprob))
|
||||
.map(|((id, logprob), text)| Token { id, text, logprob })
|
||||
.collect();
|
||||
}
|
||||
// Push last token
|
||||
@ -282,11 +282,11 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
|
||||
}
|
||||
|
||||
// Create last Token
|
||||
let token = Token(
|
||||
generation.token_id,
|
||||
generation.token_text,
|
||||
generation.token_logprob,
|
||||
);
|
||||
let token = Token {
|
||||
id: generation.token_id,
|
||||
text: generation.token_text,
|
||||
logprob: generation.token_logprob,
|
||||
};
|
||||
|
||||
if let Some(generated_text) = generation.generated_text {
|
||||
// Remove entry as this is the last message
|
||||
|
@ -1,5 +1,4 @@
|
||||
/// Text Generation Inference Webserver
|
||||
|
||||
mod infer;
|
||||
mod queue;
|
||||
pub mod server;
|
||||
@ -8,45 +7,39 @@ mod validation;
|
||||
use infer::Infer;
|
||||
use queue::{Entry, Queue};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use utoipa::ToSchema;
|
||||
use validation::Validation;
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
#[derive(Clone, Debug, Deserialize, ToSchema)]
|
||||
pub(crate) struct GenerateParameters {
|
||||
#[serde(default = "default_temperature")]
|
||||
pub temperature: f32,
|
||||
#[serde(default = "default_repetition_penalty")]
|
||||
pub repetition_penalty: f32,
|
||||
#[serde(default = "default_top_k")]
|
||||
pub top_k: i32,
|
||||
#[serde(default = "default_top_p")]
|
||||
pub top_p: f32,
|
||||
#[serde(default)]
|
||||
#[schema(exclusive_minimum = 0.0, nullable = true, default = "null")]
|
||||
pub temperature: Option<f32>,
|
||||
#[serde(default)]
|
||||
#[schema(exclusive_minimum = 0.0, nullable = true, default = "null")]
|
||||
pub repetition_penalty: Option<f32>,
|
||||
#[serde(default)]
|
||||
#[schema(exclusive_minimum = 0, nullable = true, default = "null")]
|
||||
pub top_k: Option<i32>,
|
||||
#[serde(default)]
|
||||
#[schema(exclusive_minimum = 0.0, maximum = 1.0, nullable = true, default = "null")]
|
||||
pub top_p: Option<f32>,
|
||||
#[serde(default = "default_do_sample")]
|
||||
#[schema(default = "false")]
|
||||
pub do_sample: bool,
|
||||
#[serde(default = "default_max_new_tokens")]
|
||||
#[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
|
||||
pub max_new_tokens: u32,
|
||||
#[serde(default)]
|
||||
#[schema(max_items = 4, default = "null")]
|
||||
pub stop: Vec<String>,
|
||||
#[serde(default)]
|
||||
#[schema(default = "true")]
|
||||
pub details: bool,
|
||||
#[serde(default)]
|
||||
pub seed: Option<u64>,
|
||||
}
|
||||
|
||||
fn default_temperature() -> f32 {
|
||||
1.0
|
||||
}
|
||||
fn default_repetition_penalty() -> f32 {
|
||||
1.0
|
||||
}
|
||||
|
||||
fn default_top_k() -> i32 {
|
||||
0
|
||||
}
|
||||
|
||||
fn default_top_p() -> f32 {
|
||||
1.0
|
||||
}
|
||||
|
||||
fn default_do_sample() -> bool {
|
||||
false
|
||||
}
|
||||
@ -57,10 +50,10 @@ fn default_max_new_tokens() -> u32 {
|
||||
|
||||
fn default_parameters() -> GenerateParameters {
|
||||
GenerateParameters {
|
||||
temperature: default_temperature(),
|
||||
repetition_penalty: default_repetition_penalty(),
|
||||
top_k: default_top_k(),
|
||||
top_p: default_top_p(),
|
||||
temperature: None,
|
||||
repetition_penalty: None,
|
||||
top_k: None,
|
||||
top_p: None,
|
||||
do_sample: default_do_sample(),
|
||||
max_new_tokens: default_max_new_tokens(),
|
||||
stop: vec![],
|
||||
@ -69,42 +62,71 @@ fn default_parameters() -> GenerateParameters {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
#[derive(Clone, Debug, Deserialize, ToSchema)]
|
||||
pub(crate) struct GenerateRequest {
|
||||
#[schema(example = "My name is Olivier and I")]
|
||||
pub inputs: String,
|
||||
#[serde(default = "default_parameters")]
|
||||
pub parameters: GenerateParameters,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct Token(u32, String, f32);
|
||||
#[derive(Debug, Serialize, ToSchema)]
|
||||
pub struct Token {
|
||||
id: u32,
|
||||
text: String,
|
||||
logprob: f32,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[derive(Serialize, ToSchema)]
|
||||
pub(crate) enum FinishReason {
|
||||
Length,
|
||||
EndOfSequenceToken,
|
||||
StopSequence,
|
||||
}
|
||||
|
||||
#[derive(Serialize, ToSchema)]
|
||||
pub(crate) struct Details {
|
||||
pub finish_reason: String,
|
||||
pub finish_reason: FinishReason,
|
||||
pub generated_tokens: u32,
|
||||
pub seed: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub prefill: Option<Vec<Token>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tokens: Option<Vec<Token>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[derive(Serialize, ToSchema)]
|
||||
pub(crate) struct GenerateResponse {
|
||||
pub generated_text: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub details: Option<Details>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[derive(Serialize, ToSchema)]
|
||||
pub(crate) struct StreamDetails {
|
||||
pub finish_reason: FinishReason,
|
||||
pub generated_tokens: u32,
|
||||
pub seed: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, ToSchema)]
|
||||
pub(crate) struct StreamResponse {
|
||||
pub token: Token,
|
||||
pub generated_text: Option<String>,
|
||||
pub details: Option<Details>,
|
||||
pub details: Option<StreamDetails>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub(crate) struct ErrorResponse {
|
||||
pub error: String,
|
||||
#[derive(Serialize, ToSchema)]
|
||||
pub(crate) enum ErrorType {
|
||||
#[schema(example = "Request failed during generation")]
|
||||
GenerationError(String),
|
||||
#[schema(example = "Model is overloaded")]
|
||||
Overloaded(String),
|
||||
#[schema(example = "Input validation error")]
|
||||
ValidationError(String),
|
||||
#[schema(example = "Incomplete generation")]
|
||||
IncompleteGeneration(String),
|
||||
}
|
||||
|
||||
#[derive(Serialize, ToSchema)]
|
||||
pub(crate) struct ErrorResponse {
|
||||
pub error: ErrorType,
|
||||
}
|
||||
|
@ -1,8 +1,8 @@
|
||||
/// HTTP Server logic
|
||||
use crate::infer::{InferError, InferStreamResponse};
|
||||
use crate::{
|
||||
Details, ErrorResponse, GenerateParameters, GenerateRequest, GenerateResponse, Infer,
|
||||
StreamResponse, Validation,
|
||||
Details, ErrorResponse, ErrorType, FinishReason, GenerateParameters, GenerateRequest,
|
||||
GenerateResponse, Infer, StreamDetails, StreamResponse, Token, Validation,
|
||||
};
|
||||
use axum::extract::Extension;
|
||||
use axum::http::{HeaderMap, StatusCode};
|
||||
@ -19,6 +19,8 @@ use tokio::signal;
|
||||
use tokio::time::Instant;
|
||||
use tokio_stream::StreamExt;
|
||||
use tracing::instrument;
|
||||
use utoipa::OpenApi;
|
||||
use utoipa_swagger_ui::SwaggerUi;
|
||||
|
||||
/// Health check method
|
||||
#[instrument(skip(infer))]
|
||||
@ -32,13 +34,13 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
|
||||
.generate(GenerateRequest {
|
||||
inputs: "liveness".to_string(),
|
||||
parameters: GenerateParameters {
|
||||
temperature: 1.0,
|
||||
repetition_penalty: 1.0,
|
||||
top_k: 0,
|
||||
top_p: 1.0,
|
||||
temperature: None,
|
||||
repetition_penalty: None,
|
||||
top_k: None,
|
||||
top_p: None,
|
||||
do_sample: false,
|
||||
max_new_tokens: 1,
|
||||
stop: vec![],
|
||||
stop: Vec::new(),
|
||||
details: false,
|
||||
seed: None,
|
||||
},
|
||||
@ -48,6 +50,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
|
||||
}
|
||||
|
||||
/// Generate method
|
||||
#[utoipa::path(post, path = "/generate", request_body = GenerateRequest)]
|
||||
#[instrument(
|
||||
skip(infer),
|
||||
fields(
|
||||
@ -76,7 +79,7 @@ async fn generate(
|
||||
// Token details
|
||||
let details = match details {
|
||||
true => Some(Details {
|
||||
finish_reason: response.generated_text.finish_reason,
|
||||
finish_reason: FinishReason::from(response.generated_text.finish_reason),
|
||||
generated_tokens: response.generated_text.generated_tokens,
|
||||
prefill: Some(response.prefill),
|
||||
tokens: Some(response.tokens),
|
||||
@ -133,6 +136,7 @@ async fn generate(
|
||||
}
|
||||
|
||||
/// Generate stream method
|
||||
#[utoipa::path(post, path = "/generate_stream")]
|
||||
#[instrument(
|
||||
skip(infer),
|
||||
fields(
|
||||
@ -185,11 +189,9 @@ async fn generate_stream(
|
||||
} => {
|
||||
// Token details
|
||||
let details = match details {
|
||||
true => Some(Details {
|
||||
finish_reason: generated_text.finish_reason,
|
||||
true => Some(StreamDetails {
|
||||
finish_reason: FinishReason::from(generated_text.finish_reason),
|
||||
generated_tokens: generated_text.generated_tokens,
|
||||
prefill: None,
|
||||
tokens: None,
|
||||
seed: generated_text.seed,
|
||||
}),
|
||||
false => None,
|
||||
@ -265,6 +267,33 @@ pub async fn run(
|
||||
validation_workers: usize,
|
||||
addr: SocketAddr,
|
||||
) {
|
||||
// OpenAPI documentation
|
||||
#[derive(OpenApi)]
|
||||
#[openapi(
|
||||
paths(
|
||||
generate,
|
||||
generate_stream,
|
||||
),
|
||||
components(
|
||||
schemas(
|
||||
GenerateRequest,
|
||||
GenerateParameters,
|
||||
Token,
|
||||
GenerateResponse,
|
||||
Details,
|
||||
FinishReason,
|
||||
StreamResponse,
|
||||
StreamDetails,
|
||||
ErrorResponse,
|
||||
ErrorType
|
||||
)
|
||||
),
|
||||
tags(
|
||||
(name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
|
||||
)
|
||||
)]
|
||||
struct ApiDoc;
|
||||
|
||||
// Create state
|
||||
let validation = Validation::new(validation_workers, tokenizer, max_input_length);
|
||||
let infer = Infer::new(
|
||||
@ -277,6 +306,7 @@ pub async fn run(
|
||||
|
||||
// Create router
|
||||
let app = Router::new()
|
||||
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
|
||||
.route("/", post(generate))
|
||||
.route("/generate", post(generate))
|
||||
.route("/generate_stream", post(generate_stream))
|
||||
@ -320,6 +350,30 @@ async fn shutdown_signal() {
|
||||
tracing::info!("signal received, starting graceful shutdown");
|
||||
}
|
||||
|
||||
impl From<i32> for FinishReason {
|
||||
fn from(finish_reason: i32) -> Self {
|
||||
let finish_reason = text_generation_client::FinishReason::from_i32(finish_reason).unwrap();
|
||||
match finish_reason {
|
||||
text_generation_client::FinishReason::Length => FinishReason::Length,
|
||||
text_generation_client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
|
||||
text_generation_client::FinishReason::StopSequence => FinishReason::StopSequence,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<InferError> for ErrorResponse {
|
||||
fn from(err: InferError) -> Self {
|
||||
let err_string = err.to_string();
|
||||
let error = match err {
|
||||
InferError::GenerationError(_) => ErrorType::GenerationError(err_string),
|
||||
InferError::Overloaded(_) => ErrorType::Overloaded(err_string),
|
||||
InferError::ValidationError(_) => ErrorType::ValidationError(err_string),
|
||||
InferError::IncompleteGeneration => ErrorType::IncompleteGeneration(err_string),
|
||||
};
|
||||
ErrorResponse { error }
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert to Axum supported formats
|
||||
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
|
||||
fn from(err: InferError) -> Self {
|
||||
@ -330,21 +384,14 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
|
||||
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
};
|
||||
|
||||
(
|
||||
status_code,
|
||||
Json(ErrorResponse {
|
||||
error: err.to_string(),
|
||||
}),
|
||||
)
|
||||
(status_code, Json(ErrorResponse::from(err)))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<InferError> for Event {
|
||||
fn from(err: InferError) -> Self {
|
||||
Event::default()
|
||||
.json_data(ErrorResponse {
|
||||
error: err.to_string(),
|
||||
})
|
||||
.json_data(ErrorResponse::from(err))
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
@ -110,30 +110,58 @@ fn validate(
|
||||
max_input_length: usize,
|
||||
rng: &mut ThreadRng,
|
||||
) -> Result<ValidGenerateRequest, ValidationError> {
|
||||
if request.parameters.temperature <= 0.0 {
|
||||
let GenerateParameters {
|
||||
temperature,
|
||||
repetition_penalty,
|
||||
top_k,
|
||||
top_p,
|
||||
do_sample,
|
||||
max_new_tokens,
|
||||
stop: stop_sequences,
|
||||
seed,
|
||||
..
|
||||
} = request.parameters;
|
||||
|
||||
let temperature = temperature.unwrap_or(1.0);
|
||||
if temperature <= 0.0 {
|
||||
return Err(ValidationError::Temperature);
|
||||
}
|
||||
if request.parameters.repetition_penalty <= 0.0 {
|
||||
|
||||
let repetition_penalty = repetition_penalty.unwrap_or(1.0);
|
||||
if repetition_penalty <= 0.0 {
|
||||
return Err(ValidationError::RepetitionPenalty);
|
||||
}
|
||||
if request.parameters.top_p <= 0.0 || request.parameters.top_p > 1.0 {
|
||||
|
||||
let top_p = top_p.unwrap_or(1.0);
|
||||
if top_p <= 0.0 || top_p > 1.0 {
|
||||
return Err(ValidationError::TopP);
|
||||
}
|
||||
if request.parameters.top_k < 0 {
|
||||
return Err(ValidationError::TopK);
|
||||
}
|
||||
if request.parameters.max_new_tokens > MAX_MAX_NEW_TOKENS {
|
||||
|
||||
// Different because the proto default value is 0 while it is not a valid value
|
||||
// for the user
|
||||
let top_k: u32 = match top_k {
|
||||
None => Ok(0),
|
||||
Some(top_k) => {
|
||||
if top_k <= 0 {
|
||||
return Err(ValidationError::TopK);
|
||||
}
|
||||
Ok(top_k as u32)
|
||||
}
|
||||
}?;
|
||||
|
||||
if max_new_tokens <= 0 || max_new_tokens > MAX_MAX_NEW_TOKENS {
|
||||
return Err(ValidationError::MaxNewTokens(MAX_MAX_NEW_TOKENS));
|
||||
}
|
||||
if request.parameters.stop.len() > MAX_STOP_SEQUENCES {
|
||||
|
||||
if stop_sequences.len() > MAX_STOP_SEQUENCES {
|
||||
return Err(ValidationError::StopSequence(
|
||||
MAX_STOP_SEQUENCES,
|
||||
request.parameters.stop.len(),
|
||||
stop_sequences.len(),
|
||||
));
|
||||
}
|
||||
|
||||
// If seed is None, assign a random one
|
||||
let seed = match request.parameters.seed {
|
||||
let seed = match seed {
|
||||
None => rng.gen(),
|
||||
Some(seed) => seed,
|
||||
};
|
||||
@ -147,21 +175,10 @@ fn validate(
|
||||
Err(ValidationError::InputLength(input_length, max_input_length))
|
||||
} else {
|
||||
// Return ValidGenerateRequest
|
||||
let GenerateParameters {
|
||||
temperature,
|
||||
repetition_penalty,
|
||||
top_k,
|
||||
top_p,
|
||||
do_sample,
|
||||
max_new_tokens,
|
||||
stop: stop_sequences,
|
||||
..
|
||||
} = request.parameters;
|
||||
|
||||
let parameters = NextTokenChooserParameters {
|
||||
temperature,
|
||||
repetition_penalty,
|
||||
top_k: top_k as u32,
|
||||
top_k,
|
||||
top_p,
|
||||
do_sample,
|
||||
seed,
|
||||
@ -206,7 +223,7 @@ pub enum ValidationError {
|
||||
TopP,
|
||||
#[error("top_k must be strictly positive")]
|
||||
TopK,
|
||||
#[error("max_new_tokens must be <= {0}")]
|
||||
#[error("max_new_tokens must be strictly positive and <= {0}")]
|
||||
MaxNewTokens(u32),
|
||||
#[error("inputs must have less than {1} tokens. Given: {0}")]
|
||||
InputLength(usize, usize),
|
||||
|
@ -9,6 +9,7 @@ from text_generation.utils import (
|
||||
StopSequenceCriteria,
|
||||
StoppingCriteria,
|
||||
LocalEntryNotFoundError,
|
||||
FinishReason
|
||||
)
|
||||
|
||||
|
||||
@ -24,13 +25,13 @@ def test_stop_sequence_criteria():
|
||||
def test_stopping_criteria():
|
||||
criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
|
||||
assert criteria(65827, "/test") == (False, None)
|
||||
assert criteria(30, ";") == (True, "stop_sequence")
|
||||
assert criteria(30, ";") == (True, FinishReason.FINISH_REASON_STOP_SEQUENCE)
|
||||
|
||||
|
||||
def test_stopping_criteria_eos():
|
||||
criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
|
||||
assert criteria(1, "") == (False, None)
|
||||
assert criteria(0, "") == (True, "eos_token")
|
||||
assert criteria(0, "") == (True, FinishReason.FINISH_REASON_EOS_TOKEN)
|
||||
|
||||
|
||||
def test_stopping_criteria_max():
|
||||
@ -39,7 +40,7 @@ def test_stopping_criteria_max():
|
||||
assert criteria(1, "") == (False, None)
|
||||
assert criteria(1, "") == (False, None)
|
||||
assert criteria(1, "") == (False, None)
|
||||
assert criteria(1, "") == (True, "length")
|
||||
assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
|
||||
|
||||
|
||||
def test_weight_hub_files():
|
||||
|
@ -7,6 +7,7 @@ from typing import List, Optional
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from text_generation.pb import generate_pb2
|
||||
from text_generation.pb.generate_pb2 import FinishReason
|
||||
|
||||
|
||||
class Batch(ABC):
|
||||
@ -38,7 +39,7 @@ class Batch(ABC):
|
||||
class GeneratedText:
|
||||
text: str
|
||||
generated_tokens: int
|
||||
finish_reason: str
|
||||
finish_reason: FinishReason
|
||||
seed: Optional[int]
|
||||
|
||||
def to_pb(self) -> generate_pb2.GeneratedText:
|
||||
|
@ -24,6 +24,7 @@ from transformers.generation.logits_process import (
|
||||
)
|
||||
|
||||
from text_generation.pb import generate_pb2
|
||||
from text_generation.pb.generate_pb2 import FinishReason
|
||||
|
||||
WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
|
||||
|
||||
@ -129,15 +130,15 @@ class StoppingCriteria:
|
||||
def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
|
||||
self.current_tokens += 1
|
||||
if self.current_tokens >= self.max_new_tokens:
|
||||
return True, "length"
|
||||
return True, FinishReason.FINISH_REASON_LENGTH
|
||||
|
||||
if last_token == self.eos_token_id:
|
||||
return True, "eos_token"
|
||||
return True, FinishReason.FINISH_REASON_EOS_TOKEN
|
||||
|
||||
self.current_output += last_output
|
||||
for stop_sequence_criteria in self.stop_sequence_criterias:
|
||||
if stop_sequence_criteria(self.current_output):
|
||||
return True, "stop_sequence"
|
||||
return True, FinishReason.FINISH_REASON_STOP_SEQUENCE
|
||||
|
||||
return False, None
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user