mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
feat(router): add openAPI schemas
This commit is contained in:
parent
b1482d9048
commit
2878c43cc5
183
Cargo.lock
generated
183
Cargo.lock
generated
@ -83,9 +83,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "axum"
|
name = "axum"
|
||||||
version = "0.5.17"
|
version = "0.6.4"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "acee9fd5073ab6b045a275b3e709c163dd36c90685219cb21804a147b58dba43"
|
checksum = "e5694b64066a2459918d8074c2ce0d5a88f409431994c2356617c8ae0c4721fc"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"axum-core",
|
"axum-core",
|
||||||
@ -101,8 +101,10 @@ dependencies = [
|
|||||||
"mime",
|
"mime",
|
||||||
"percent-encoding",
|
"percent-encoding",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
|
"rustversion",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
"serde_path_to_error",
|
||||||
"serde_urlencoded",
|
"serde_urlencoded",
|
||||||
"sync_wrapper",
|
"sync_wrapper",
|
||||||
"tokio",
|
"tokio",
|
||||||
@ -114,9 +116,9 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "axum-core"
|
name = "axum-core"
|
||||||
version = "0.2.9"
|
version = "0.3.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "37e5939e02c56fecd5c017c37df4238c0a839fa76b7f97acdd7efb804fd181cc"
|
checksum = "1cae3e661676ffbacb30f1a824089a8c9150e71017f7e1e38f2aa32009188d34"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"bytes",
|
"bytes",
|
||||||
@ -124,6 +126,7 @@ dependencies = [
|
|||||||
"http",
|
"http",
|
||||||
"http-body",
|
"http-body",
|
||||||
"mime",
|
"mime",
|
||||||
|
"rustversion",
|
||||||
"tower-layer",
|
"tower-layer",
|
||||||
"tower-service",
|
"tower-service",
|
||||||
]
|
]
|
||||||
@ -207,7 +210,7 @@ dependencies = [
|
|||||||
"tar",
|
"tar",
|
||||||
"tempfile",
|
"tempfile",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
"zip",
|
"zip 0.5.13",
|
||||||
"zip-extensions",
|
"zip-extensions",
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -465,6 +468,15 @@ dependencies = [
|
|||||||
"dirs-sys",
|
"dirs-sys",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "dirs"
|
||||||
|
version = "4.0.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ca3aa72a6f96ea37bbc5aa912f6788242832f75369bdfdadcb0e38423f100059"
|
||||||
|
dependencies = [
|
||||||
|
"dirs-sys",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "dirs-sys"
|
name = "dirs-sys"
|
||||||
version = "0.3.7"
|
version = "0.3.7"
|
||||||
@ -867,6 +879,7 @@ checksum = "10a35a97730320ffe8e2d410b5d3b69279b98d2c14bdb8b70ea89ecf7888d41e"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"autocfg",
|
"autocfg",
|
||||||
"hashbrown",
|
"hashbrown",
|
||||||
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -999,9 +1012,9 @@ checksum = "58093314a45e00c77d5c508f76e77c3396afbbc0d01506e7fae47b018bac2b1d"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "matchit"
|
name = "matchit"
|
||||||
version = "0.5.0"
|
version = "0.7.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb"
|
checksum = "b87248edafb776e59e6ee64a79086f65890d3510f2c656c000bf2a7e8a0aea40"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "memchr"
|
name = "memchr"
|
||||||
@ -1024,6 +1037,16 @@ version = "0.3.16"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "2a60c7ce501c71e03a9c9c0d35b861413ae925bd979cc7a4e30d060069aaac8d"
|
checksum = "2a60c7ce501c71e03a9c9c0d35b861413ae925bd979cc7a4e30d060069aaac8d"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "mime_guess"
|
||||||
|
version = "2.0.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef"
|
||||||
|
dependencies = [
|
||||||
|
"mime",
|
||||||
|
"unicase",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "minimal-lexical"
|
name = "minimal-lexical"
|
||||||
version = "0.2.1"
|
version = "0.2.1"
|
||||||
@ -1552,12 +1575,62 @@ dependencies = [
|
|||||||
"winreg",
|
"winreg",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rust-embed"
|
||||||
|
version = "6.4.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "283ffe2f866869428c92e0d61c2f35dfb4355293cdfdc48f49e895c15f1333d1"
|
||||||
|
dependencies = [
|
||||||
|
"rust-embed-impl",
|
||||||
|
"rust-embed-utils",
|
||||||
|
"walkdir",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rust-embed-impl"
|
||||||
|
version = "6.3.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "31ab23d42d71fb9be1b643fe6765d292c5e14d46912d13f3ae2815ca048ea04d"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"rust-embed-utils",
|
||||||
|
"shellexpand",
|
||||||
|
"syn",
|
||||||
|
"walkdir",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rust-embed-utils"
|
||||||
|
version = "7.3.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c1669d81dfabd1b5f8e2856b8bbe146c6192b0ba22162edc738ac0a5de18f054"
|
||||||
|
dependencies = [
|
||||||
|
"sha2",
|
||||||
|
"walkdir",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rustversion"
|
||||||
|
version = "1.0.11"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5583e89e108996506031660fe09baa5011b9dd0341b89029313006d1fb508d70"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ryu"
|
name = "ryu"
|
||||||
version = "1.0.11"
|
version = "1.0.11"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "4501abdff3ae82a1c1b477a17252eb69cee9e66eb915c1abaa4f44d873df9f09"
|
checksum = "4501abdff3ae82a1c1b477a17252eb69cee9e66eb915c1abaa4f44d873df9f09"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "same-file"
|
||||||
|
version = "1.0.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
|
||||||
|
dependencies = [
|
||||||
|
"winapi-util",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "schannel"
|
name = "schannel"
|
||||||
version = "0.1.20"
|
version = "0.1.20"
|
||||||
@ -1628,6 +1701,15 @@ dependencies = [
|
|||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde_path_to_error"
|
||||||
|
version = "0.1.9"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "26b04f22b563c91331a10074bda3dd5492e3cc39d56bd557e91c0af42b6c7341"
|
||||||
|
dependencies = [
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde_urlencoded"
|
name = "serde_urlencoded"
|
||||||
version = "0.7.1"
|
version = "0.7.1"
|
||||||
@ -1660,6 +1742,15 @@ dependencies = [
|
|||||||
"lazy_static",
|
"lazy_static",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "shellexpand"
|
||||||
|
version = "2.1.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7ccc8076840c4da029af4f87e4e8daeb0fca6b87bbb02e10cb60b791450e11e4"
|
||||||
|
dependencies = [
|
||||||
|
"dirs 4.0.0",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "signal-hook-registry"
|
name = "signal-hook-registry"
|
||||||
version = "1.4.0"
|
version = "1.4.0"
|
||||||
@ -1845,6 +1936,8 @@ dependencies = [
|
|||||||
"tokio-stream",
|
"tokio-stream",
|
||||||
"tracing",
|
"tracing",
|
||||||
"tracing-subscriber",
|
"tracing-subscriber",
|
||||||
|
"utoipa",
|
||||||
|
"utoipa-swagger-ui",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -1921,7 +2014,7 @@ dependencies = [
|
|||||||
"cached-path",
|
"cached-path",
|
||||||
"clap 2.34.0",
|
"clap 2.34.0",
|
||||||
"derive_builder",
|
"derive_builder",
|
||||||
"dirs",
|
"dirs 3.0.2",
|
||||||
"esaxx-rs",
|
"esaxx-rs",
|
||||||
"getrandom",
|
"getrandom",
|
||||||
"indicatif 0.15.0",
|
"indicatif 0.15.0",
|
||||||
@ -2234,6 +2327,15 @@ version = "1.15.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987"
|
checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "unicase"
|
||||||
|
version = "2.6.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "50f37be617794602aabbeee0be4f259dc1778fabe05e2d67ee8f79326d5cb4f6"
|
||||||
|
dependencies = [
|
||||||
|
"version_check",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "unicode-bidi"
|
name = "unicode-bidi"
|
||||||
version = "0.3.8"
|
version = "0.3.8"
|
||||||
@ -2293,6 +2395,46 @@ dependencies = [
|
|||||||
"percent-encoding",
|
"percent-encoding",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "utoipa"
|
||||||
|
version = "3.0.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3920fa753064b1be7842bea26175ffa0dfc4a8f30bcb52b8ff03fddf8889914c"
|
||||||
|
dependencies = [
|
||||||
|
"indexmap",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"utoipa-gen",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "utoipa-gen"
|
||||||
|
version = "3.0.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "720298fac6efca20df9e457e67a1eab41a20d1c3101380b5c4dca1ca60ae0062"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro-error",
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "utoipa-swagger-ui"
|
||||||
|
version = "3.0.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ae3d4f4da6408f0f20ff58196ed619c94306ab32635aeca3d3fa0768c0bd0de2"
|
||||||
|
dependencies = [
|
||||||
|
"axum",
|
||||||
|
"mime_guess",
|
||||||
|
"regex",
|
||||||
|
"rust-embed",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"utoipa",
|
||||||
|
"zip 0.6.4",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "valuable"
|
name = "valuable"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
@ -2317,6 +2459,17 @@ version = "0.9.4"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
|
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "walkdir"
|
||||||
|
version = "2.3.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "808cf2735cd4b6866113f648b791c6adc5714537bc222d9347bb203386ffda56"
|
||||||
|
dependencies = [
|
||||||
|
"same-file",
|
||||||
|
"winapi",
|
||||||
|
"winapi-util",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "want"
|
name = "want"
|
||||||
version = "0.3.0"
|
version = "0.3.0"
|
||||||
@ -2589,11 +2742,23 @@ dependencies = [
|
|||||||
"time",
|
"time",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "zip"
|
||||||
|
version = "0.6.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0445d0fbc924bb93539b4316c11afb121ea39296f99a3c4c9edad09e3658cdef"
|
||||||
|
dependencies = [
|
||||||
|
"byteorder",
|
||||||
|
"crc32fast",
|
||||||
|
"crossbeam-utils",
|
||||||
|
"flate2",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "zip-extensions"
|
name = "zip-extensions"
|
||||||
version = "0.6.1"
|
version = "0.6.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "a64c3c977bc3434ce2d4bcea8ad3c644672de0f2c402b72b9171ca80a8885d14"
|
checksum = "a64c3c977bc3434ce2d4bcea8ad3c644672de0f2c402b72b9171ca80a8885d14"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"zip",
|
"zip 0.5.13",
|
||||||
]
|
]
|
||||||
|
@ -71,13 +71,19 @@ message Batch {
|
|||||||
uint32 size = 3;
|
uint32 size = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enum FinishReason {
|
||||||
|
FINISH_REASON_LENGTH = 0;
|
||||||
|
FINISH_REASON_EOS_TOKEN = 1;
|
||||||
|
FINISH_REASON_STOP_SEQUENCE = 2;
|
||||||
|
}
|
||||||
|
|
||||||
message GeneratedText {
|
message GeneratedText {
|
||||||
/// Output
|
/// Output
|
||||||
string text = 1;
|
string text = 1;
|
||||||
/// Number of generated tokens
|
/// Number of generated tokens
|
||||||
uint32 generated_tokens = 2;
|
uint32 generated_tokens = 2;
|
||||||
/// Finish reason
|
/// Finish reason
|
||||||
string finish_reason = 3;
|
FinishReason finish_reason = 3;
|
||||||
/// Seed
|
/// Seed
|
||||||
optional uint64 seed = 4;
|
optional uint64 seed = 4;
|
||||||
}
|
}
|
||||||
|
@ -14,7 +14,7 @@ path = "src/main.rs"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
async-stream = "0.3.3"
|
async-stream = "0.3.3"
|
||||||
axum = { version = "0.5.16", features = ["json", "serde_json"] }
|
axum = { version = "0.6.4", features = ["json"] }
|
||||||
text-generation-client = { path = "client" }
|
text-generation-client = { path = "client" }
|
||||||
clap = { version = "4.0.15", features = ["derive", "env"] }
|
clap = { version = "4.0.15", features = ["derive", "env"] }
|
||||||
futures = "0.3.24"
|
futures = "0.3.24"
|
||||||
@ -29,4 +29,6 @@ tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot"
|
|||||||
tokio-stream = "0.1.11"
|
tokio-stream = "0.1.11"
|
||||||
tracing = "0.1.36"
|
tracing = "0.1.36"
|
||||||
tracing-subscriber = { version = "0.3.15", features = ["json"] }
|
tracing-subscriber = { version = "0.3.15", features = ["json"] }
|
||||||
|
utoipa = { version = "3.0.1", features = ["axum_extras"] }
|
||||||
|
utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }
|
||||||
|
|
||||||
|
@ -7,8 +7,8 @@ mod sharded_client;
|
|||||||
|
|
||||||
pub use client::Client;
|
pub use client::Client;
|
||||||
pub use pb::generate::v1::{
|
pub use pb::generate::v1::{
|
||||||
Batch, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens, Request,
|
Batch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens,
|
||||||
StoppingCriteriaParameters,
|
Request, StoppingCriteriaParameters,
|
||||||
};
|
};
|
||||||
pub use sharded_client::ShardedClient;
|
pub use sharded_client::ShardedClient;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
@ -127,7 +127,7 @@ impl Infer {
|
|||||||
.into_iter()
|
.into_iter()
|
||||||
.zip(tokens.logprobs.into_iter())
|
.zip(tokens.logprobs.into_iter())
|
||||||
.zip(tokens.texts.into_iter())
|
.zip(tokens.texts.into_iter())
|
||||||
.map(|((id, logprob), text)| Token(id, text, logprob))
|
.map(|((id, logprob), text)| Token { id, text, logprob })
|
||||||
.collect();
|
.collect();
|
||||||
}
|
}
|
||||||
// Push last token
|
// Push last token
|
||||||
@ -282,11 +282,11 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create last Token
|
// Create last Token
|
||||||
let token = Token(
|
let token = Token {
|
||||||
generation.token_id,
|
id: generation.token_id,
|
||||||
generation.token_text,
|
text: generation.token_text,
|
||||||
generation.token_logprob,
|
logprob: generation.token_logprob,
|
||||||
);
|
};
|
||||||
|
|
||||||
if let Some(generated_text) = generation.generated_text {
|
if let Some(generated_text) = generation.generated_text {
|
||||||
// Remove entry as this is the last message
|
// Remove entry as this is the last message
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
/// Text Generation Inference Webserver
|
/// Text Generation Inference Webserver
|
||||||
|
|
||||||
mod infer;
|
mod infer;
|
||||||
mod queue;
|
mod queue;
|
||||||
pub mod server;
|
pub mod server;
|
||||||
@ -8,45 +7,39 @@ mod validation;
|
|||||||
use infer::Infer;
|
use infer::Infer;
|
||||||
use queue::{Entry, Queue};
|
use queue::{Entry, Queue};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use utoipa::ToSchema;
|
||||||
use validation::Validation;
|
use validation::Validation;
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize)]
|
#[derive(Clone, Debug, Deserialize, ToSchema)]
|
||||||
pub(crate) struct GenerateParameters {
|
pub(crate) struct GenerateParameters {
|
||||||
#[serde(default = "default_temperature")]
|
#[serde(default)]
|
||||||
pub temperature: f32,
|
#[schema(exclusive_minimum = 0.0, nullable = true, default = "null")]
|
||||||
#[serde(default = "default_repetition_penalty")]
|
pub temperature: Option<f32>,
|
||||||
pub repetition_penalty: f32,
|
#[serde(default)]
|
||||||
#[serde(default = "default_top_k")]
|
#[schema(exclusive_minimum = 0.0, nullable = true, default = "null")]
|
||||||
pub top_k: i32,
|
pub repetition_penalty: Option<f32>,
|
||||||
#[serde(default = "default_top_p")]
|
#[serde(default)]
|
||||||
pub top_p: f32,
|
#[schema(exclusive_minimum = 0, nullable = true, default = "null")]
|
||||||
|
pub top_k: Option<i32>,
|
||||||
|
#[serde(default)]
|
||||||
|
#[schema(exclusive_minimum = 0.0, maximum = 1.0, nullable = true, default = "null")]
|
||||||
|
pub top_p: Option<f32>,
|
||||||
#[serde(default = "default_do_sample")]
|
#[serde(default = "default_do_sample")]
|
||||||
|
#[schema(default = "false")]
|
||||||
pub do_sample: bool,
|
pub do_sample: bool,
|
||||||
#[serde(default = "default_max_new_tokens")]
|
#[serde(default = "default_max_new_tokens")]
|
||||||
|
#[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
|
||||||
pub max_new_tokens: u32,
|
pub max_new_tokens: u32,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
|
#[schema(max_items = 4, default = "null")]
|
||||||
pub stop: Vec<String>,
|
pub stop: Vec<String>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
|
#[schema(default = "true")]
|
||||||
pub details: bool,
|
pub details: bool,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub seed: Option<u64>,
|
pub seed: Option<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_temperature() -> f32 {
|
|
||||||
1.0
|
|
||||||
}
|
|
||||||
fn default_repetition_penalty() -> f32 {
|
|
||||||
1.0
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_top_k() -> i32 {
|
|
||||||
0
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_top_p() -> f32 {
|
|
||||||
1.0
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_do_sample() -> bool {
|
fn default_do_sample() -> bool {
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
@ -57,10 +50,10 @@ fn default_max_new_tokens() -> u32 {
|
|||||||
|
|
||||||
fn default_parameters() -> GenerateParameters {
|
fn default_parameters() -> GenerateParameters {
|
||||||
GenerateParameters {
|
GenerateParameters {
|
||||||
temperature: default_temperature(),
|
temperature: None,
|
||||||
repetition_penalty: default_repetition_penalty(),
|
repetition_penalty: None,
|
||||||
top_k: default_top_k(),
|
top_k: None,
|
||||||
top_p: default_top_p(),
|
top_p: None,
|
||||||
do_sample: default_do_sample(),
|
do_sample: default_do_sample(),
|
||||||
max_new_tokens: default_max_new_tokens(),
|
max_new_tokens: default_max_new_tokens(),
|
||||||
stop: vec![],
|
stop: vec![],
|
||||||
@ -69,42 +62,71 @@ fn default_parameters() -> GenerateParameters {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize)]
|
#[derive(Clone, Debug, Deserialize, ToSchema)]
|
||||||
pub(crate) struct GenerateRequest {
|
pub(crate) struct GenerateRequest {
|
||||||
|
#[schema(example = "My name is Olivier and I")]
|
||||||
pub inputs: String,
|
pub inputs: String,
|
||||||
#[serde(default = "default_parameters")]
|
#[serde(default = "default_parameters")]
|
||||||
pub parameters: GenerateParameters,
|
pub parameters: GenerateParameters,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize, ToSchema)]
|
||||||
pub struct Token(u32, String, f32);
|
pub struct Token {
|
||||||
|
id: u32,
|
||||||
|
text: String,
|
||||||
|
logprob: f32,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize, ToSchema)]
|
||||||
|
pub(crate) enum FinishReason {
|
||||||
|
Length,
|
||||||
|
EndOfSequenceToken,
|
||||||
|
StopSequence,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, ToSchema)]
|
||||||
pub(crate) struct Details {
|
pub(crate) struct Details {
|
||||||
pub finish_reason: String,
|
pub finish_reason: FinishReason,
|
||||||
pub generated_tokens: u32,
|
pub generated_tokens: u32,
|
||||||
pub seed: Option<u64>,
|
pub seed: Option<u64>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub prefill: Option<Vec<Token>>,
|
pub prefill: Option<Vec<Token>>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub tokens: Option<Vec<Token>>,
|
pub tokens: Option<Vec<Token>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize, ToSchema)]
|
||||||
pub(crate) struct GenerateResponse {
|
pub(crate) struct GenerateResponse {
|
||||||
pub generated_text: String,
|
pub generated_text: String,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub details: Option<Details>,
|
pub details: Option<Details>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize, ToSchema)]
|
||||||
|
pub(crate) struct StreamDetails {
|
||||||
|
pub finish_reason: FinishReason,
|
||||||
|
pub generated_tokens: u32,
|
||||||
|
pub seed: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, ToSchema)]
|
||||||
pub(crate) struct StreamResponse {
|
pub(crate) struct StreamResponse {
|
||||||
pub token: Token,
|
pub token: Token,
|
||||||
pub generated_text: Option<String>,
|
pub generated_text: Option<String>,
|
||||||
pub details: Option<Details>,
|
pub details: Option<StreamDetails>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize, ToSchema)]
|
||||||
pub(crate) struct ErrorResponse {
|
pub(crate) enum ErrorType {
|
||||||
pub error: String,
|
#[schema(example = "Request failed during generation")]
|
||||||
|
GenerationError(String),
|
||||||
|
#[schema(example = "Model is overloaded")]
|
||||||
|
Overloaded(String),
|
||||||
|
#[schema(example = "Input validation error")]
|
||||||
|
ValidationError(String),
|
||||||
|
#[schema(example = "Incomplete generation")]
|
||||||
|
IncompleteGeneration(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, ToSchema)]
|
||||||
|
pub(crate) struct ErrorResponse {
|
||||||
|
pub error: ErrorType,
|
||||||
}
|
}
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
/// HTTP Server logic
|
/// HTTP Server logic
|
||||||
use crate::infer::{InferError, InferStreamResponse};
|
use crate::infer::{InferError, InferStreamResponse};
|
||||||
use crate::{
|
use crate::{
|
||||||
Details, ErrorResponse, GenerateParameters, GenerateRequest, GenerateResponse, Infer,
|
Details, ErrorResponse, ErrorType, FinishReason, GenerateParameters, GenerateRequest,
|
||||||
StreamResponse, Validation,
|
GenerateResponse, Infer, StreamDetails, StreamResponse, Token, Validation,
|
||||||
};
|
};
|
||||||
use axum::extract::Extension;
|
use axum::extract::Extension;
|
||||||
use axum::http::{HeaderMap, StatusCode};
|
use axum::http::{HeaderMap, StatusCode};
|
||||||
@ -19,6 +19,8 @@ use tokio::signal;
|
|||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
use tokio_stream::StreamExt;
|
use tokio_stream::StreamExt;
|
||||||
use tracing::instrument;
|
use tracing::instrument;
|
||||||
|
use utoipa::OpenApi;
|
||||||
|
use utoipa_swagger_ui::SwaggerUi;
|
||||||
|
|
||||||
/// Health check method
|
/// Health check method
|
||||||
#[instrument(skip(infer))]
|
#[instrument(skip(infer))]
|
||||||
@ -32,13 +34,13 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
|
|||||||
.generate(GenerateRequest {
|
.generate(GenerateRequest {
|
||||||
inputs: "liveness".to_string(),
|
inputs: "liveness".to_string(),
|
||||||
parameters: GenerateParameters {
|
parameters: GenerateParameters {
|
||||||
temperature: 1.0,
|
temperature: None,
|
||||||
repetition_penalty: 1.0,
|
repetition_penalty: None,
|
||||||
top_k: 0,
|
top_k: None,
|
||||||
top_p: 1.0,
|
top_p: None,
|
||||||
do_sample: false,
|
do_sample: false,
|
||||||
max_new_tokens: 1,
|
max_new_tokens: 1,
|
||||||
stop: vec![],
|
stop: Vec::new(),
|
||||||
details: false,
|
details: false,
|
||||||
seed: None,
|
seed: None,
|
||||||
},
|
},
|
||||||
@ -48,6 +50,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Generate method
|
/// Generate method
|
||||||
|
#[utoipa::path(post, path = "/generate", request_body = GenerateRequest)]
|
||||||
#[instrument(
|
#[instrument(
|
||||||
skip(infer),
|
skip(infer),
|
||||||
fields(
|
fields(
|
||||||
@ -76,7 +79,7 @@ async fn generate(
|
|||||||
// Token details
|
// Token details
|
||||||
let details = match details {
|
let details = match details {
|
||||||
true => Some(Details {
|
true => Some(Details {
|
||||||
finish_reason: response.generated_text.finish_reason,
|
finish_reason: FinishReason::from(response.generated_text.finish_reason),
|
||||||
generated_tokens: response.generated_text.generated_tokens,
|
generated_tokens: response.generated_text.generated_tokens,
|
||||||
prefill: Some(response.prefill),
|
prefill: Some(response.prefill),
|
||||||
tokens: Some(response.tokens),
|
tokens: Some(response.tokens),
|
||||||
@ -133,6 +136,7 @@ async fn generate(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Generate stream method
|
/// Generate stream method
|
||||||
|
#[utoipa::path(post, path = "/generate_stream")]
|
||||||
#[instrument(
|
#[instrument(
|
||||||
skip(infer),
|
skip(infer),
|
||||||
fields(
|
fields(
|
||||||
@ -185,11 +189,9 @@ async fn generate_stream(
|
|||||||
} => {
|
} => {
|
||||||
// Token details
|
// Token details
|
||||||
let details = match details {
|
let details = match details {
|
||||||
true => Some(Details {
|
true => Some(StreamDetails {
|
||||||
finish_reason: generated_text.finish_reason,
|
finish_reason: FinishReason::from(generated_text.finish_reason),
|
||||||
generated_tokens: generated_text.generated_tokens,
|
generated_tokens: generated_text.generated_tokens,
|
||||||
prefill: None,
|
|
||||||
tokens: None,
|
|
||||||
seed: generated_text.seed,
|
seed: generated_text.seed,
|
||||||
}),
|
}),
|
||||||
false => None,
|
false => None,
|
||||||
@ -265,6 +267,33 @@ pub async fn run(
|
|||||||
validation_workers: usize,
|
validation_workers: usize,
|
||||||
addr: SocketAddr,
|
addr: SocketAddr,
|
||||||
) {
|
) {
|
||||||
|
// OpenAPI documentation
|
||||||
|
#[derive(OpenApi)]
|
||||||
|
#[openapi(
|
||||||
|
paths(
|
||||||
|
generate,
|
||||||
|
generate_stream,
|
||||||
|
),
|
||||||
|
components(
|
||||||
|
schemas(
|
||||||
|
GenerateRequest,
|
||||||
|
GenerateParameters,
|
||||||
|
Token,
|
||||||
|
GenerateResponse,
|
||||||
|
Details,
|
||||||
|
FinishReason,
|
||||||
|
StreamResponse,
|
||||||
|
StreamDetails,
|
||||||
|
ErrorResponse,
|
||||||
|
ErrorType
|
||||||
|
)
|
||||||
|
),
|
||||||
|
tags(
|
||||||
|
(name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
struct ApiDoc;
|
||||||
|
|
||||||
// Create state
|
// Create state
|
||||||
let validation = Validation::new(validation_workers, tokenizer, max_input_length);
|
let validation = Validation::new(validation_workers, tokenizer, max_input_length);
|
||||||
let infer = Infer::new(
|
let infer = Infer::new(
|
||||||
@ -277,6 +306,7 @@ pub async fn run(
|
|||||||
|
|
||||||
// Create router
|
// Create router
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
|
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
|
||||||
.route("/", post(generate))
|
.route("/", post(generate))
|
||||||
.route("/generate", post(generate))
|
.route("/generate", post(generate))
|
||||||
.route("/generate_stream", post(generate_stream))
|
.route("/generate_stream", post(generate_stream))
|
||||||
@ -320,6 +350,30 @@ async fn shutdown_signal() {
|
|||||||
tracing::info!("signal received, starting graceful shutdown");
|
tracing::info!("signal received, starting graceful shutdown");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<i32> for FinishReason {
|
||||||
|
fn from(finish_reason: i32) -> Self {
|
||||||
|
let finish_reason = text_generation_client::FinishReason::from_i32(finish_reason).unwrap();
|
||||||
|
match finish_reason {
|
||||||
|
text_generation_client::FinishReason::Length => FinishReason::Length,
|
||||||
|
text_generation_client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
|
||||||
|
text_generation_client::FinishReason::StopSequence => FinishReason::StopSequence,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<InferError> for ErrorResponse {
|
||||||
|
fn from(err: InferError) -> Self {
|
||||||
|
let err_string = err.to_string();
|
||||||
|
let error = match err {
|
||||||
|
InferError::GenerationError(_) => ErrorType::GenerationError(err_string),
|
||||||
|
InferError::Overloaded(_) => ErrorType::Overloaded(err_string),
|
||||||
|
InferError::ValidationError(_) => ErrorType::ValidationError(err_string),
|
||||||
|
InferError::IncompleteGeneration => ErrorType::IncompleteGeneration(err_string),
|
||||||
|
};
|
||||||
|
ErrorResponse { error }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Convert to Axum supported formats
|
/// Convert to Axum supported formats
|
||||||
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
|
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
|
||||||
fn from(err: InferError) -> Self {
|
fn from(err: InferError) -> Self {
|
||||||
@ -330,21 +384,14 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
|
|||||||
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
|
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
};
|
};
|
||||||
|
|
||||||
(
|
(status_code, Json(ErrorResponse::from(err)))
|
||||||
status_code,
|
|
||||||
Json(ErrorResponse {
|
|
||||||
error: err.to_string(),
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<InferError> for Event {
|
impl From<InferError> for Event {
|
||||||
fn from(err: InferError) -> Self {
|
fn from(err: InferError) -> Self {
|
||||||
Event::default()
|
Event::default()
|
||||||
.json_data(ErrorResponse {
|
.json_data(ErrorResponse::from(err))
|
||||||
error: err.to_string(),
|
|
||||||
})
|
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -110,30 +110,58 @@ fn validate(
|
|||||||
max_input_length: usize,
|
max_input_length: usize,
|
||||||
rng: &mut ThreadRng,
|
rng: &mut ThreadRng,
|
||||||
) -> Result<ValidGenerateRequest, ValidationError> {
|
) -> Result<ValidGenerateRequest, ValidationError> {
|
||||||
if request.parameters.temperature <= 0.0 {
|
let GenerateParameters {
|
||||||
|
temperature,
|
||||||
|
repetition_penalty,
|
||||||
|
top_k,
|
||||||
|
top_p,
|
||||||
|
do_sample,
|
||||||
|
max_new_tokens,
|
||||||
|
stop: stop_sequences,
|
||||||
|
seed,
|
||||||
|
..
|
||||||
|
} = request.parameters;
|
||||||
|
|
||||||
|
let temperature = temperature.unwrap_or(1.0);
|
||||||
|
if temperature <= 0.0 {
|
||||||
return Err(ValidationError::Temperature);
|
return Err(ValidationError::Temperature);
|
||||||
}
|
}
|
||||||
if request.parameters.repetition_penalty <= 0.0 {
|
|
||||||
|
let repetition_penalty = repetition_penalty.unwrap_or(1.0);
|
||||||
|
if repetition_penalty <= 0.0 {
|
||||||
return Err(ValidationError::RepetitionPenalty);
|
return Err(ValidationError::RepetitionPenalty);
|
||||||
}
|
}
|
||||||
if request.parameters.top_p <= 0.0 || request.parameters.top_p > 1.0 {
|
|
||||||
|
let top_p = top_p.unwrap_or(1.0);
|
||||||
|
if top_p <= 0.0 || top_p > 1.0 {
|
||||||
return Err(ValidationError::TopP);
|
return Err(ValidationError::TopP);
|
||||||
}
|
}
|
||||||
if request.parameters.top_k < 0 {
|
|
||||||
return Err(ValidationError::TopK);
|
// Different because the proto default value is 0 while it is not a valid value
|
||||||
}
|
// for the user
|
||||||
if request.parameters.max_new_tokens > MAX_MAX_NEW_TOKENS {
|
let top_k: u32 = match top_k {
|
||||||
|
None => Ok(0),
|
||||||
|
Some(top_k) => {
|
||||||
|
if top_k <= 0 {
|
||||||
|
return Err(ValidationError::TopK);
|
||||||
|
}
|
||||||
|
Ok(top_k as u32)
|
||||||
|
}
|
||||||
|
}?;
|
||||||
|
|
||||||
|
if max_new_tokens <= 0 || max_new_tokens > MAX_MAX_NEW_TOKENS {
|
||||||
return Err(ValidationError::MaxNewTokens(MAX_MAX_NEW_TOKENS));
|
return Err(ValidationError::MaxNewTokens(MAX_MAX_NEW_TOKENS));
|
||||||
}
|
}
|
||||||
if request.parameters.stop.len() > MAX_STOP_SEQUENCES {
|
|
||||||
|
if stop_sequences.len() > MAX_STOP_SEQUENCES {
|
||||||
return Err(ValidationError::StopSequence(
|
return Err(ValidationError::StopSequence(
|
||||||
MAX_STOP_SEQUENCES,
|
MAX_STOP_SEQUENCES,
|
||||||
request.parameters.stop.len(),
|
stop_sequences.len(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
// If seed is None, assign a random one
|
// If seed is None, assign a random one
|
||||||
let seed = match request.parameters.seed {
|
let seed = match seed {
|
||||||
None => rng.gen(),
|
None => rng.gen(),
|
||||||
Some(seed) => seed,
|
Some(seed) => seed,
|
||||||
};
|
};
|
||||||
@ -147,21 +175,10 @@ fn validate(
|
|||||||
Err(ValidationError::InputLength(input_length, max_input_length))
|
Err(ValidationError::InputLength(input_length, max_input_length))
|
||||||
} else {
|
} else {
|
||||||
// Return ValidGenerateRequest
|
// Return ValidGenerateRequest
|
||||||
let GenerateParameters {
|
|
||||||
temperature,
|
|
||||||
repetition_penalty,
|
|
||||||
top_k,
|
|
||||||
top_p,
|
|
||||||
do_sample,
|
|
||||||
max_new_tokens,
|
|
||||||
stop: stop_sequences,
|
|
||||||
..
|
|
||||||
} = request.parameters;
|
|
||||||
|
|
||||||
let parameters = NextTokenChooserParameters {
|
let parameters = NextTokenChooserParameters {
|
||||||
temperature,
|
temperature,
|
||||||
repetition_penalty,
|
repetition_penalty,
|
||||||
top_k: top_k as u32,
|
top_k,
|
||||||
top_p,
|
top_p,
|
||||||
do_sample,
|
do_sample,
|
||||||
seed,
|
seed,
|
||||||
@ -206,7 +223,7 @@ pub enum ValidationError {
|
|||||||
TopP,
|
TopP,
|
||||||
#[error("top_k must be strictly positive")]
|
#[error("top_k must be strictly positive")]
|
||||||
TopK,
|
TopK,
|
||||||
#[error("max_new_tokens must be <= {0}")]
|
#[error("max_new_tokens must be strictly positive and <= {0}")]
|
||||||
MaxNewTokens(u32),
|
MaxNewTokens(u32),
|
||||||
#[error("inputs must have less than {1} tokens. Given: {0}")]
|
#[error("inputs must have less than {1} tokens. Given: {0}")]
|
||||||
InputLength(usize, usize),
|
InputLength(usize, usize),
|
||||||
|
@ -9,6 +9,7 @@ from text_generation.utils import (
|
|||||||
StopSequenceCriteria,
|
StopSequenceCriteria,
|
||||||
StoppingCriteria,
|
StoppingCriteria,
|
||||||
LocalEntryNotFoundError,
|
LocalEntryNotFoundError,
|
||||||
|
FinishReason
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -24,13 +25,13 @@ def test_stop_sequence_criteria():
|
|||||||
def test_stopping_criteria():
|
def test_stopping_criteria():
|
||||||
criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
|
criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
|
||||||
assert criteria(65827, "/test") == (False, None)
|
assert criteria(65827, "/test") == (False, None)
|
||||||
assert criteria(30, ";") == (True, "stop_sequence")
|
assert criteria(30, ";") == (True, FinishReason.FINISH_REASON_STOP_SEQUENCE)
|
||||||
|
|
||||||
|
|
||||||
def test_stopping_criteria_eos():
|
def test_stopping_criteria_eos():
|
||||||
criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
|
criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
|
||||||
assert criteria(1, "") == (False, None)
|
assert criteria(1, "") == (False, None)
|
||||||
assert criteria(0, "") == (True, "eos_token")
|
assert criteria(0, "") == (True, FinishReason.FINISH_REASON_EOS_TOKEN)
|
||||||
|
|
||||||
|
|
||||||
def test_stopping_criteria_max():
|
def test_stopping_criteria_max():
|
||||||
@ -39,7 +40,7 @@ def test_stopping_criteria_max():
|
|||||||
assert criteria(1, "") == (False, None)
|
assert criteria(1, "") == (False, None)
|
||||||
assert criteria(1, "") == (False, None)
|
assert criteria(1, "") == (False, None)
|
||||||
assert criteria(1, "") == (False, None)
|
assert criteria(1, "") == (False, None)
|
||||||
assert criteria(1, "") == (True, "length")
|
assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
|
||||||
|
|
||||||
|
|
||||||
def test_weight_hub_files():
|
def test_weight_hub_files():
|
||||||
|
@ -7,6 +7,7 @@ from typing import List, Optional
|
|||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
from text_generation.pb import generate_pb2
|
from text_generation.pb import generate_pb2
|
||||||
|
from text_generation.pb.generate_pb2 import FinishReason
|
||||||
|
|
||||||
|
|
||||||
class Batch(ABC):
|
class Batch(ABC):
|
||||||
@ -38,7 +39,7 @@ class Batch(ABC):
|
|||||||
class GeneratedText:
|
class GeneratedText:
|
||||||
text: str
|
text: str
|
||||||
generated_tokens: int
|
generated_tokens: int
|
||||||
finish_reason: str
|
finish_reason: FinishReason
|
||||||
seed: Optional[int]
|
seed: Optional[int]
|
||||||
|
|
||||||
def to_pb(self) -> generate_pb2.GeneratedText:
|
def to_pb(self) -> generate_pb2.GeneratedText:
|
||||||
|
@ -24,6 +24,7 @@ from transformers.generation.logits_process import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from text_generation.pb import generate_pb2
|
from text_generation.pb import generate_pb2
|
||||||
|
from text_generation.pb.generate_pb2 import FinishReason
|
||||||
|
|
||||||
WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
|
WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
|
||||||
|
|
||||||
@ -129,15 +130,15 @@ class StoppingCriteria:
|
|||||||
def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
|
def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
|
||||||
self.current_tokens += 1
|
self.current_tokens += 1
|
||||||
if self.current_tokens >= self.max_new_tokens:
|
if self.current_tokens >= self.max_new_tokens:
|
||||||
return True, "length"
|
return True, FinishReason.FINISH_REASON_LENGTH
|
||||||
|
|
||||||
if last_token == self.eos_token_id:
|
if last_token == self.eos_token_id:
|
||||||
return True, "eos_token"
|
return True, FinishReason.FINISH_REASON_EOS_TOKEN
|
||||||
|
|
||||||
self.current_output += last_output
|
self.current_output += last_output
|
||||||
for stop_sequence_criteria in self.stop_sequence_criterias:
|
for stop_sequence_criteria in self.stop_sequence_criterias:
|
||||||
if stop_sequence_criteria(self.current_output):
|
if stop_sequence_criteria(self.current_output):
|
||||||
return True, "stop_sequence"
|
return True, FinishReason.FINISH_REASON_STOP_SEQUENCE
|
||||||
|
|
||||||
return False, None
|
return False, None
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user