mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
rework validation
This commit is contained in:
parent
47e93409f3
commit
1dd2c24b9c
356
Cargo.lock
generated
356
Cargo.lock
generated
@ -8,18 +8,6 @@ version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
|
||||
|
||||
[[package]]
|
||||
name = "aes"
|
||||
version = "0.7.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9e8b47f52ea9bae42228d07ec09eb676433d7c4ed1ebdf0f1d1c29ed446f1ab8"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cipher",
|
||||
"cpufeatures",
|
||||
"opaque-debug",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ahash"
|
||||
version = "0.7.6"
|
||||
@ -40,6 +28,15 @@ dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ansi_term"
|
||||
version = "0.12.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2"
|
||||
dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anstream"
|
||||
version = "0.2.6"
|
||||
@ -119,6 +116,17 @@ dependencies = [
|
||||
"syn 2.0.12",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "atty"
|
||||
version = "0.2.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
|
||||
dependencies = [
|
||||
"hermit-abi 0.1.19",
|
||||
"libc",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "autocfg"
|
||||
version = "1.1.0"
|
||||
@ -202,12 +210,6 @@ version = "0.21.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a4a4ddaa51a5bc52a6948f74c06d20aaaddb71924eab79b8c97a8c556e942d6a"
|
||||
|
||||
[[package]]
|
||||
name = "base64ct"
|
||||
version = "1.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
version = "1.3.2"
|
||||
@ -264,9 +266,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "cached-path"
|
||||
version = "0.6.1"
|
||||
version = "0.5.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "097968e38f1319207f057d0f4d76452e4f4f847a5de61c5215379f297fa034f3"
|
||||
checksum = "5f1c56d30236522ab3393a08746b138d4e16372001f42d29c88d513aeb8ab7ef"
|
||||
dependencies = [
|
||||
"flate2",
|
||||
"fs2",
|
||||
@ -281,7 +283,8 @@ dependencies = [
|
||||
"tar",
|
||||
"tempfile",
|
||||
"thiserror",
|
||||
"zip",
|
||||
"zip 0.5.13",
|
||||
"zip-extensions",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -289,9 +292,6 @@ name = "cc"
|
||||
version = "1.0.79"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f"
|
||||
dependencies = [
|
||||
"jobserver",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
@ -300,12 +300,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||
|
||||
[[package]]
|
||||
name = "cipher"
|
||||
version = "0.3.0"
|
||||
name = "clap"
|
||||
version = "2.34.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7ee52072ec15386f770805afd189a01c8841be8696bed250fa2f13c4c0d6dfb7"
|
||||
checksum = "a0610544180c38b88101fecf2dd634b174a62eef6946f84dfc6a7127512b381c"
|
||||
dependencies = [
|
||||
"generic-array",
|
||||
"ansi_term",
|
||||
"atty",
|
||||
"bitflags",
|
||||
"strsim 0.8.0",
|
||||
"textwrap",
|
||||
"unicode-width",
|
||||
"vec_map",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -329,7 +335,7 @@ dependencies = [
|
||||
"anstyle",
|
||||
"bitflags",
|
||||
"clap_lex",
|
||||
"strsim",
|
||||
"strsim 0.10.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -378,12 +384,6 @@ dependencies = [
|
||||
"windows-sys 0.42.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "constant_time_eq"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc"
|
||||
|
||||
[[package]]
|
||||
name = "core-foundation"
|
||||
version = "0.9.3"
|
||||
@ -483,9 +483,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "darling"
|
||||
version = "0.14.4"
|
||||
version = "0.10.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850"
|
||||
checksum = "0d706e75d87e35569db781a9b5e2416cff1236a47ed380831f959382ccd5f858"
|
||||
dependencies = [
|
||||
"darling_core",
|
||||
"darling_macro",
|
||||
@ -493,23 +493,23 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "darling_core"
|
||||
version = "0.14.4"
|
||||
version = "0.10.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0"
|
||||
checksum = "f0c960ae2da4de88a91b2d920c2a7233b400bc33cb28453a2987822d8392519b"
|
||||
dependencies = [
|
||||
"fnv",
|
||||
"ident_case",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"strsim",
|
||||
"strsim 0.9.3",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling_macro"
|
||||
version = "0.14.4"
|
||||
version = "0.10.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e"
|
||||
checksum = "d9b5a2f4ac4969822c62224815d069952656cadc7084fdca9751e6d959189b72"
|
||||
dependencies = [
|
||||
"darling_core",
|
||||
"quote",
|
||||
@ -531,32 +531,26 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "derive_builder"
|
||||
version = "0.12.0"
|
||||
version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8d67778784b508018359cbc8696edb3db78160bab2c2a28ba7f56ef6932997f8"
|
||||
dependencies = [
|
||||
"derive_builder_macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive_builder_core"
|
||||
version = "0.12.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c11bdc11a0c47bc7d37d582b5285da6849c96681023680b906673c5707af7b0f"
|
||||
checksum = "a2658621297f2cf68762a6f7dc0bb7e1ff2cfd6583daef8ee0fed6f7ec468ec0"
|
||||
dependencies = [
|
||||
"darling",
|
||||
"derive_builder_core",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive_builder_macro"
|
||||
version = "0.12.0"
|
||||
name = "derive_builder_core"
|
||||
version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ebcda35c7a396850a55ffeac740804b40ffec779b98fffbb1738f4033f0ee79e"
|
||||
checksum = "2791ea3e372c8495c0bc2033991d76b512cd799d07491fbd6890124db9458bef"
|
||||
dependencies = [
|
||||
"derive_builder_core",
|
||||
"darling",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
@ -568,7 +562,15 @@ checksum = "8168378f4e5023e7218c89c891c0fd8ecdb5e5e4f18cb78f38cf245dd021e76f"
|
||||
dependencies = [
|
||||
"block-buffer",
|
||||
"crypto-common",
|
||||
"subtle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dirs"
|
||||
version = "3.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "30baa043103c9d0c2a57cf537cc2f35623889dc0d405e6c3cccfadbc81c71309"
|
||||
dependencies = [
|
||||
"dirs-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -685,6 +687,19 @@ version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "28a80e3145d8ad11ba0995949bbcf48b9df2be62772b3d351ef017dff6ecb853"
|
||||
|
||||
[[package]]
|
||||
name = "flume"
|
||||
version = "0.10.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1657b4441c3403d9f7b3409e47575237dac27b1b5726df654a6ecbf92f0f7577"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-sink",
|
||||
"nanorand",
|
||||
"pin-project",
|
||||
"spin",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fnv"
|
||||
version = "1.0.7"
|
||||
@ -831,8 +846,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"js-sys",
|
||||
"libc",
|
||||
"wasi 0.11.0+wasi-snapshot-preview1",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -885,6 +902,15 @@ version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
|
||||
|
||||
[[package]]
|
||||
name = "hermit-abi"
|
||||
version = "0.1.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hermit-abi"
|
||||
version = "0.2.6"
|
||||
@ -900,15 +926,6 @@ version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286"
|
||||
|
||||
[[package]]
|
||||
name = "hmac"
|
||||
version = "0.12.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e"
|
||||
dependencies = [
|
||||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http"
|
||||
version = "0.2.9"
|
||||
@ -1120,15 +1137,6 @@ version = "1.0.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6"
|
||||
|
||||
[[package]]
|
||||
name = "jobserver"
|
||||
version = "0.1.26"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "936cfd212a0155903bcbc060e316fb6cc7cbf2e1907329391ebadc1fe0ce77c2"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "js-sys"
|
||||
version = "0.3.61"
|
||||
@ -1331,33 +1339,21 @@ dependencies = [
|
||||
"windows-sys 0.45.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "monostate"
|
||||
version = "0.1.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0230b703f1ac35df1e24f6d0d2255472bcccaf657ecdfa4f1fcbcad1ad5bb98a"
|
||||
dependencies = [
|
||||
"monostate-impl",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "monostate-impl"
|
||||
version = "0.1.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8795add3e14028f11f8e848bd3294898a8294767b3776b6f733560d33bd2530b"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.12",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "multimap"
|
||||
version = "0.8.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a"
|
||||
|
||||
[[package]]
|
||||
name = "nanorand"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3"
|
||||
dependencies = [
|
||||
"getrandom",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "native-tls"
|
||||
version = "0.2.11"
|
||||
@ -1464,12 +1460,6 @@ dependencies = [
|
||||
"pkg-config",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opaque-debug"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5"
|
||||
|
||||
[[package]]
|
||||
name = "openssl"
|
||||
version = "0.10.48"
|
||||
@ -1624,35 +1614,12 @@ dependencies = [
|
||||
"windows-sys 0.45.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "password-hash"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7676374caaee8a325c9e7a2ae557f216c5563a171d6997b0ef8a65af35147700"
|
||||
dependencies = [
|
||||
"base64ct",
|
||||
"rand_core",
|
||||
"subtle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "paste"
|
||||
version = "1.0.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9f746c4065a8fa3fe23974dd82f15431cc8d40779821001404d10d2e79ca7d79"
|
||||
|
||||
[[package]]
|
||||
name = "pbkdf2"
|
||||
version = "0.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "83a0692ec44e4cf1ef28ca317f14f8f07da2d95ec3fa01f86e4467b725e60917"
|
||||
dependencies = [
|
||||
"digest",
|
||||
"hmac",
|
||||
"password-hash",
|
||||
"sha2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "percent-encoding"
|
||||
version = "2.2.0"
|
||||
@ -2165,17 +2132,6 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha1"
|
||||
version = "0.10.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cpufeatures",
|
||||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha2"
|
||||
version = "0.10.6"
|
||||
@ -2202,7 +2158,7 @@ version = "2.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7ccc8076840c4da029af4f87e4e8daeb0fca6b87bbb02e10cb60b791450e11e4"
|
||||
dependencies = [
|
||||
"dirs",
|
||||
"dirs 4.0.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -2245,6 +2201,15 @@ dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "spin"
|
||||
version = "0.9.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
|
||||
dependencies = [
|
||||
"lock_api",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "spm_precompiled"
|
||||
version = "0.1.4"
|
||||
@ -2263,6 +2228,18 @@ version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
|
||||
|
||||
[[package]]
|
||||
name = "strsim"
|
||||
version = "0.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a"
|
||||
|
||||
[[package]]
|
||||
name = "strsim"
|
||||
version = "0.9.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6446ced80d6c486436db5c078dde11a9f73d42b57fb273121e160b84f63d894c"
|
||||
|
||||
[[package]]
|
||||
name = "strsim"
|
||||
version = "0.10.0"
|
||||
@ -2279,12 +2256,6 @@ dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "subtle"
|
||||
version = "2.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601"
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "1.0.109"
|
||||
@ -2358,7 +2329,7 @@ dependencies = [
|
||||
name = "text-generation-launcher"
|
||||
version = "0.4.3"
|
||||
dependencies = [
|
||||
"clap",
|
||||
"clap 4.2.1",
|
||||
"ctrlc",
|
||||
"float_eq",
|
||||
"reqwest",
|
||||
@ -2376,14 +2347,14 @@ dependencies = [
|
||||
"async-stream",
|
||||
"axum",
|
||||
"axum-tracing-opentelemetry",
|
||||
"clap",
|
||||
"clap 4.2.1",
|
||||
"flume",
|
||||
"futures",
|
||||
"metrics",
|
||||
"metrics-exporter-prometheus",
|
||||
"nohash-hasher",
|
||||
"opentelemetry",
|
||||
"opentelemetry-otlp",
|
||||
"parking_lot",
|
||||
"rand",
|
||||
"reqwest",
|
||||
"serde",
|
||||
@ -2392,7 +2363,6 @@ dependencies = [
|
||||
"thiserror",
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tower-http",
|
||||
"tracing",
|
||||
"tracing-opentelemetry",
|
||||
@ -2401,6 +2371,15 @@ dependencies = [
|
||||
"utoipa-swagger-ui",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "textwrap"
|
||||
version = "0.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060"
|
||||
dependencies = [
|
||||
"unicode-width",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "1.0.40"
|
||||
@ -2433,20 +2412,14 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "time"
|
||||
version = "0.3.20"
|
||||
version = "0.1.43"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cd0cbfecb4d19b5ea75bb31ad904eb5b9fa13f21079c3b92017ebdf4999a5890"
|
||||
checksum = "ca8a50ef2360fbd1eeb0ecd46795a87a19024eb4b53c5dc916ca1fd95fe62438"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"time-core",
|
||||
"libc",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "time-core"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2e153e1f1acaef8acc537e68b44906d2db6436e2b35ac2c6b42640fff91f00fd"
|
||||
|
||||
[[package]]
|
||||
name = "tinyvec"
|
||||
version = "1.6.0"
|
||||
@ -2465,13 +2438,14 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
|
||||
[[package]]
|
||||
name = "tokenizers"
|
||||
version = "0.13.2"
|
||||
source = "git+https://github.com/huggingface/tokenizers.git#3aaf4946b3c82e7a04db6ecde7c8bb4e474e54af"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f4ff2dd291eac98dcea13e8cf7a0b28c373a90dc9210ccdab0fa9e69ee0cac69"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"cached-path",
|
||||
"clap",
|
||||
"clap 2.34.0",
|
||||
"derive_builder",
|
||||
"dirs",
|
||||
"dirs 3.0.2",
|
||||
"esaxx-rs",
|
||||
"getrandom",
|
||||
"indicatif 0.15.0",
|
||||
@ -2479,7 +2453,6 @@ dependencies = [
|
||||
"lazy_static",
|
||||
"log",
|
||||
"macro_rules_attribute",
|
||||
"monostate",
|
||||
"onig",
|
||||
"paste",
|
||||
"rand",
|
||||
@ -2901,7 +2874,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"utoipa",
|
||||
"zip",
|
||||
"zip 0.6.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -2916,6 +2889,12 @@ version = "0.2.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
|
||||
|
||||
[[package]]
|
||||
name = "vec_map"
|
||||
version = "0.8.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f1bddf1187be692e79c5ffeab891132dfb0f236ed36a43c7ed39f1165ee20191"
|
||||
|
||||
[[package]]
|
||||
name = "version_check"
|
||||
version = "0.9.4"
|
||||
@ -3171,52 +3150,37 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zip"
|
||||
version = "0.5.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "93ab48844d61251bb3835145c521d88aa4031d7139e8485990f60ca911fa0815"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"bzip2",
|
||||
"crc32fast",
|
||||
"flate2",
|
||||
"thiserror",
|
||||
"time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zip"
|
||||
version = "0.6.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0445d0fbc924bb93539b4316c11afb121ea39296f99a3c4c9edad09e3658cdef"
|
||||
dependencies = [
|
||||
"aes",
|
||||
"byteorder",
|
||||
"bzip2",
|
||||
"constant_time_eq",
|
||||
"crc32fast",
|
||||
"crossbeam-utils",
|
||||
"flate2",
|
||||
"hmac",
|
||||
"pbkdf2",
|
||||
"sha1",
|
||||
"time",
|
||||
"zstd",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zstd"
|
||||
version = "0.11.2+zstd.1.5.2"
|
||||
name = "zip-extensions"
|
||||
version = "0.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4"
|
||||
checksum = "a64c3c977bc3434ce2d4bcea8ad3c644672de0f2c402b72b9171ca80a8885d14"
|
||||
dependencies = [
|
||||
"zstd-safe",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zstd-safe"
|
||||
version = "5.0.2+zstd.1.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1d2a5585e04f9eea4b2a3d1eca508c4dee9592a89ef6f450c11719da0726f4db"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"zstd-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zstd-sys"
|
||||
version = "2.0.7+zstd.1.5.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "94509c3ba2fe55294d752b79842c530ccfab760192521df74a081a78d2b3c7f5"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"libc",
|
||||
"pkg-config",
|
||||
"zip 0.5.13",
|
||||
]
|
||||
|
@ -27,8 +27,7 @@ serde = {version = "1.0.142", features = ["derive"]}
|
||||
serde_json = "1.0"
|
||||
text-generation-client = { path = "../router/client" }
|
||||
thiserror = "1.0.38"
|
||||
#tokenizers = "0.13.2"
|
||||
tokenizers = { git = "https://github.com/huggingface/tokenizers.git" }
|
||||
tokenizers = "0.13.2"
|
||||
tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
||||
tui = {package = "ratatui", version = "0.20", default-features = false, features = ["crossterm"]}
|
||||
tracing = "0.1.37"
|
||||
|
@ -18,22 +18,20 @@ axum = { version = "0.6.4", features = ["json"] }
|
||||
axum-tracing-opentelemetry = "0.9.0"
|
||||
text-generation-client = { path = "client" }
|
||||
clap = { version = "4.1.4", features = ["derive", "env"] }
|
||||
flume = "0.10.14"
|
||||
futures = "0.3.26"
|
||||
metrics = "0.20.1"
|
||||
metrics-exporter-prometheus = { version = "0.11.0", features = [] }
|
||||
nohash-hasher = "0.2.0"
|
||||
opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
|
||||
opentelemetry-otlp = "0.11.0"
|
||||
parking_lot = "0.12.1"
|
||||
rand = "0.8.5"
|
||||
reqwest = { version = "0.11.14", features = [] }
|
||||
serde = "1.0.152"
|
||||
serde_json = "1.0.93"
|
||||
thiserror = "1.0.38"
|
||||
#tokenizers = "0.13.2"
|
||||
tokenizers = { git = "https://github.com/huggingface/tokenizers.git" }
|
||||
tokenizers = "0.13.2"
|
||||
tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
||||
tokio-stream = "0.1.11"
|
||||
tower-http = { version = "0.3.5", features = ["cors"] }
|
||||
tracing = "0.1.37"
|
||||
tracing-opentelemetry = "0.18.0"
|
||||
|
@ -2,17 +2,17 @@
|
||||
use crate::validation::{Validation, ValidationError};
|
||||
use crate::{Entry, Queue, Token};
|
||||
use crate::{GenerateRequest, PrefillToken};
|
||||
use flume::r#async::RecvStream;
|
||||
use futures::future::try_join_all;
|
||||
use futures::stream::StreamExt;
|
||||
use nohash_hasher::IntMap;
|
||||
use std::sync::Arc;
|
||||
use text_generation_client::{
|
||||
Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
|
||||
use tokio::sync::{Notify, Semaphore, TryAcquireError};
|
||||
use tokio::time::Instant;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tokio_stream::StreamExt;
|
||||
use tracing::{info_span, instrument, Instrument, Span};
|
||||
|
||||
/// Inference struct
|
||||
@ -73,7 +73,7 @@ impl Infer {
|
||||
pub(crate) async fn generate_stream(
|
||||
&self,
|
||||
request: GenerateRequest,
|
||||
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
|
||||
) -> Result<RecvStream<Result<InferStreamResponse, InferError>>, InferError> {
|
||||
// Limit concurrent requests by acquiring a permit from the semaphore
|
||||
// This permit will live as long as Entry
|
||||
let permit = self
|
||||
@ -90,7 +90,7 @@ impl Infer {
|
||||
let valid_request = self.validation.validate(request).await?;
|
||||
|
||||
// MPSC channel to communicate with the background batching task
|
||||
let (response_tx, response_rx) = mpsc::unbounded_channel();
|
||||
let (response_tx, response_rx) = flume::unbounded();
|
||||
|
||||
// Append the request to the queue
|
||||
self.queue.append(Entry {
|
||||
@ -108,7 +108,8 @@ impl Infer {
|
||||
self.shared.batching_task.notify_one();
|
||||
|
||||
// Return stream
|
||||
Ok(UnboundedReceiverStream::new(response_rx))
|
||||
Ok(response_rx.into_stream())
|
||||
// Ok(UnboundedReceiverStream::new(response_rx))
|
||||
}
|
||||
|
||||
/// Add a new request to the queue and return a InferResponse
|
||||
|
@ -4,8 +4,7 @@ use crate::validation::ValidGenerateRequest;
|
||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||
use std::cmp::min;
|
||||
use text_generation_client::{Batch, Request};
|
||||
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
|
||||
use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit};
|
||||
use tokio::sync::{oneshot, OwnedSemaphorePermit};
|
||||
use tokio::time::Instant;
|
||||
use tracing::{info_span, instrument, Span};
|
||||
|
||||
@ -15,7 +14,7 @@ pub(crate) struct Entry {
|
||||
/// Request
|
||||
pub request: ValidGenerateRequest,
|
||||
/// Response sender to communicate between the Infer struct and the batching_task
|
||||
pub response_tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
|
||||
pub response_tx: flume::Sender<Result<InferStreamResponse, InferError>>,
|
||||
/// Span that will live as long as entry
|
||||
pub span: Span,
|
||||
/// Temporary span used as a guard when logging inference, wait times...
|
||||
@ -32,13 +31,13 @@ pub(crate) struct Entry {
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct Queue {
|
||||
/// Channel to communicate with the background queue task
|
||||
queue_sender: UnboundedSender<QueueCommand>,
|
||||
queue_sender: flume::Sender<QueueCommand>,
|
||||
}
|
||||
|
||||
impl Queue {
|
||||
pub(crate) fn new() -> Self {
|
||||
// Create channel
|
||||
let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
|
||||
let (queue_sender, queue_receiver) = flume::unbounded();
|
||||
|
||||
// Launch background queue task
|
||||
tokio::spawn(queue_task(queue_receiver));
|
||||
@ -82,10 +81,10 @@ impl Queue {
|
||||
}
|
||||
|
||||
// Background task responsible of the queue state
|
||||
async fn queue_task(mut receiver: UnboundedReceiver<QueueCommand>) {
|
||||
async fn queue_task(receiver: flume::Receiver<QueueCommand>) {
|
||||
let mut state = State::new();
|
||||
|
||||
while let Some(cmd) = receiver.recv().await {
|
||||
while let Ok(cmd) = receiver.recv_async().await {
|
||||
match cmd {
|
||||
QueueCommand::Append(entry, span) => span.in_scope(|| state.append(entry)),
|
||||
QueueCommand::NextBatch {
|
||||
@ -216,12 +215,12 @@ mod tests {
|
||||
use super::*;
|
||||
use std::sync::Arc;
|
||||
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
|
||||
use tokio::sync::{mpsc, Semaphore};
|
||||
use tokio::sync::Semaphore;
|
||||
use tracing::info_span;
|
||||
|
||||
fn default_entry() -> Entry {
|
||||
let semaphore = Arc::new(Semaphore::new(1));
|
||||
let (response_tx, _) = mpsc::unbounded_channel();
|
||||
let (response_tx, _) = flume::unbounded();
|
||||
let permit = semaphore.try_acquire_owned().unwrap();
|
||||
|
||||
Entry {
|
||||
|
@ -13,6 +13,7 @@ use axum::response::{IntoResponse, Response};
|
||||
use axum::routing::{get, post};
|
||||
use axum::{http, Json, Router};
|
||||
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
|
||||
use futures::stream::StreamExt;
|
||||
use futures::Stream;
|
||||
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
|
||||
use std::convert::Infallible;
|
||||
@ -21,7 +22,6 @@ use text_generation_client::ShardedClient;
|
||||
use tokenizers::Tokenizer;
|
||||
use tokio::signal;
|
||||
use tokio::time::Instant;
|
||||
use tokio_stream::StreamExt;
|
||||
use tower_http::cors::{AllowOrigin, CorsLayer};
|
||||
use tracing::{info_span, instrument, Instrument};
|
||||
use utoipa::OpenApi;
|
||||
|
@ -1,23 +1,24 @@
|
||||
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
|
||||
/// Payload validation logic
|
||||
use crate::{GenerateParameters, GenerateRequest};
|
||||
use rand::rngs::ThreadRng;
|
||||
use rand::Rng;
|
||||
use rand::{thread_rng, Rng};
|
||||
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
|
||||
use thiserror::Error;
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
use tokenizers::TruncationDirection;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tokio::sync::oneshot;
|
||||
use tracing::{instrument, Span};
|
||||
|
||||
/// Validation
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Validation {
|
||||
/// maximum value for the best_of parameter
|
||||
#[allow(dead_code)]
|
||||
/// Validation parameters
|
||||
max_best_of: usize,
|
||||
/// Channel to communicate with the background validation task
|
||||
sender: mpsc::UnboundedSender<ValidationRequest>,
|
||||
max_stop_sequences: usize,
|
||||
max_input_length: usize,
|
||||
max_total_tokens: usize,
|
||||
/// Channel to communicate with the background tokenization task
|
||||
sender: Option<flume::Sender<TokenizerRequest>>,
|
||||
}
|
||||
|
||||
impl Validation {
|
||||
@ -29,22 +30,81 @@ impl Validation {
|
||||
max_input_length: usize,
|
||||
max_total_tokens: usize,
|
||||
) -> Self {
|
||||
// If we have a fast tokenizer
|
||||
let sender = if let Some(tokenizer) = tokenizer {
|
||||
// Create channel
|
||||
let (validation_sender, validation_receiver) = mpsc::unbounded_channel();
|
||||
let (validation_sender, validation_receiver) = flume::unbounded();
|
||||
|
||||
// Launch background validation task
|
||||
tokio::spawn(validation_task(
|
||||
workers,
|
||||
tokenizer,
|
||||
max_stop_sequences,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
validation_receiver,
|
||||
));
|
||||
// Create workers
|
||||
for _ in 0..workers {
|
||||
let tokenizer_clone = tokenizer.clone();
|
||||
let receiver_clone = validation_receiver.clone();
|
||||
|
||||
// Spawn worker
|
||||
tokio::task::spawn_blocking(move || {
|
||||
tokenizer_worker(tokenizer_clone.into(), receiver_clone)
|
||||
});
|
||||
}
|
||||
Some(validation_sender)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Self {
|
||||
max_best_of,
|
||||
sender: validation_sender,
|
||||
sender,
|
||||
max_stop_sequences,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
}
|
||||
}
|
||||
|
||||
async fn validate_input(
|
||||
&self,
|
||||
inputs: String,
|
||||
truncate: Option<usize>,
|
||||
max_new_tokens: u32,
|
||||
) -> Result<String, ValidationError> {
|
||||
// If we have a fast tokenizer
|
||||
if let Some(sender) = &self.sender {
|
||||
// Create response channel
|
||||
let (response_sender, response_receiver) = oneshot::channel();
|
||||
// Send request to the background validation task
|
||||
// Unwrap is safe here
|
||||
sender
|
||||
.send(((inputs, truncate), response_sender, Span::current()))
|
||||
.unwrap();
|
||||
|
||||
// Await on response channel
|
||||
// Unwrap is safe here
|
||||
let (inputs, input_length) = response_receiver.await.unwrap()?;
|
||||
|
||||
// Get total tokens
|
||||
let total_tokens = input_length + max_new_tokens as usize;
|
||||
|
||||
// Validate MaxTotalTokens
|
||||
if total_tokens > self.max_total_tokens {
|
||||
return Err(ValidationError::MaxTotalTokens(
|
||||
self.max_total_tokens,
|
||||
input_length,
|
||||
max_new_tokens,
|
||||
));
|
||||
}
|
||||
|
||||
// Validate InputLength
|
||||
if input_length > self.max_input_length {
|
||||
return Err(ValidationError::InputLength(
|
||||
self.max_input_length,
|
||||
input_length,
|
||||
));
|
||||
}
|
||||
|
||||
metrics::histogram!("tgi_request_input_length", input_length as f64);
|
||||
Ok(inputs)
|
||||
}
|
||||
// Return inputs without validation
|
||||
else {
|
||||
Ok(inputs)
|
||||
}
|
||||
}
|
||||
|
||||
@ -53,120 +113,6 @@ impl Validation {
|
||||
pub(crate) async fn validate(
|
||||
&self,
|
||||
request: GenerateRequest,
|
||||
) -> Result<ValidGenerateRequest, ValidationError> {
|
||||
// Create response channel
|
||||
let (sender, receiver) = oneshot::channel();
|
||||
// Send request to the background validation task
|
||||
// Unwrap is safe here
|
||||
self.sender
|
||||
.send((request, sender, Span::current()))
|
||||
.unwrap();
|
||||
// Await on response channel
|
||||
// Unwrap is safe here
|
||||
receiver.await.unwrap()
|
||||
}
|
||||
|
||||
/// Validate the best_of parameter
|
||||
#[instrument(skip_all)]
|
||||
pub(crate) fn validate_best_of(&self, best_of: usize) -> Result<usize, ValidationError> {
|
||||
if self.max_best_of == 1 && best_of != 1 {
|
||||
return Err(ValidationError::BestOfDisabled);
|
||||
}
|
||||
|
||||
if best_of > self.max_best_of {
|
||||
return Err(ValidationError::BestOf(self.max_best_of, best_of));
|
||||
}
|
||||
|
||||
Ok(best_of)
|
||||
}
|
||||
}
|
||||
|
||||
/// Validation task
|
||||
/// Load balance the validation requests between multiple validation workers
|
||||
async fn validation_task(
|
||||
workers: usize,
|
||||
tokenizer: Option<Tokenizer>,
|
||||
max_stop_sequences: usize,
|
||||
max_input_length: usize,
|
||||
max_total_tokens: usize,
|
||||
mut receiver: mpsc::UnboundedReceiver<ValidationRequest>,
|
||||
) {
|
||||
let mut workers_senders = Vec::with_capacity(workers);
|
||||
|
||||
// Create workers
|
||||
for _ in 0..workers {
|
||||
let tokenizer_clone: Option<Tokenizer> = tokenizer.clone().into();
|
||||
// Create channel to communicate with worker
|
||||
let (worker_sender, worker_receiver) = mpsc::channel(workers);
|
||||
workers_senders.push(worker_sender);
|
||||
|
||||
// Spawn worker
|
||||
tokio::task::spawn_blocking(move || {
|
||||
validation_worker(
|
||||
tokenizer_clone,
|
||||
max_stop_sequences,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
worker_receiver,
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
loop {
|
||||
// Load balance requests between workers
|
||||
for sender in workers_senders.iter() {
|
||||
if let Some(validation_request) = receiver.recv().await {
|
||||
sender.send(validation_request).await.unwrap();
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Check the parameters inside the payload and get the number of tokens inside the input using
|
||||
/// the tokenizer
|
||||
fn validation_worker(
|
||||
tokenizer: Option<Tokenizer>,
|
||||
max_stop_sequences: usize,
|
||||
max_input_length: usize,
|
||||
max_total_tokens: usize,
|
||||
mut receiver: mpsc::Receiver<ValidationRequest>,
|
||||
) {
|
||||
// Seed rng
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
// Loop over requests
|
||||
while let Some((request, response_tx, parent_span)) = receiver.blocking_recv() {
|
||||
parent_span.in_scope(|| {
|
||||
response_tx
|
||||
.send(
|
||||
validate(
|
||||
request,
|
||||
tokenizer.as_ref(),
|
||||
max_stop_sequences,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
&mut rng,
|
||||
)
|
||||
.map_err(|err| {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||
tracing::error!("{err}");
|
||||
err
|
||||
}),
|
||||
)
|
||||
.unwrap_or(())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn validate(
|
||||
request: GenerateRequest,
|
||||
tokenizer: Option<&Tokenizer>,
|
||||
max_stop_sequences: usize,
|
||||
max_input_length: usize,
|
||||
max_total_tokens: usize,
|
||||
rng: &mut ThreadRng,
|
||||
) -> Result<ValidGenerateRequest, ValidationError> {
|
||||
let GenerateParameters {
|
||||
best_of,
|
||||
@ -239,16 +185,16 @@ fn validate(
|
||||
return Err(ValidationError::MaxNewTokens);
|
||||
}
|
||||
|
||||
if stop_sequences.len() > max_stop_sequences {
|
||||
if stop_sequences.len() > self.max_stop_sequences {
|
||||
return Err(ValidationError::StopSequence(
|
||||
max_stop_sequences,
|
||||
self.max_stop_sequences,
|
||||
stop_sequences.len(),
|
||||
));
|
||||
}
|
||||
|
||||
// If seed is None, assign a random one
|
||||
let seed = match seed {
|
||||
None => rng.gen(),
|
||||
None => thread_rng().gen(),
|
||||
Some(seed) => {
|
||||
if best_of > 1 {
|
||||
return Err(BestOfSeed);
|
||||
@ -265,51 +211,18 @@ fn validate(
|
||||
// Check if truncate is strictly positive and less than max_input_length
|
||||
let truncate = truncate
|
||||
.map(|value| {
|
||||
if value == 0 || value > max_input_length {
|
||||
return Err(ValidationError::Truncate(max_input_length, value));
|
||||
if value == 0 || value > self.max_input_length {
|
||||
return Err(ValidationError::Truncate(self.max_input_length, value));
|
||||
}
|
||||
Ok(Some(value))
|
||||
})
|
||||
.unwrap_or(Ok(None))?;
|
||||
|
||||
// If we have a fast tokenizer
|
||||
let inputs = if let Some(tokenizer) = tokenizer {
|
||||
// Get the number of tokens in the input
|
||||
let mut encoding = tokenizer
|
||||
.encode(request.inputs.clone(), true)
|
||||
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
|
||||
// Validate inputs
|
||||
let inputs = self
|
||||
.validate_input(request.inputs, truncate, max_new_tokens)
|
||||
.await?;
|
||||
|
||||
let (inputs, input_length) = if let Some(truncate) = truncate {
|
||||
// truncate encoding and decode new inputs
|
||||
encoding.truncate(truncate, 0, TruncationDirection::Left);
|
||||
let inputs = tokenizer
|
||||
.decode(Vec::from(encoding.get_ids()), false)
|
||||
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
|
||||
(inputs, encoding.len())
|
||||
} else {
|
||||
(request.inputs, encoding.len())
|
||||
};
|
||||
|
||||
if input_length > max_input_length {
|
||||
return Err(ValidationError::InputLength(max_input_length, input_length));
|
||||
}
|
||||
|
||||
let total_tokens = input_length + max_new_tokens as usize;
|
||||
if total_tokens > max_total_tokens {
|
||||
return Err(ValidationError::MaxTotalTokens(
|
||||
max_total_tokens,
|
||||
input_length,
|
||||
max_new_tokens,
|
||||
));
|
||||
}
|
||||
|
||||
metrics::histogram!("tgi_request_input_length", input_length as f64);
|
||||
inputs
|
||||
} else {
|
||||
request.inputs
|
||||
};
|
||||
|
||||
// Return ValidGenerateRequest
|
||||
let parameters = NextTokenChooserParameters {
|
||||
temperature,
|
||||
repetition_penalty,
|
||||
@ -330,15 +243,75 @@ fn validate(
|
||||
|
||||
Ok(ValidGenerateRequest {
|
||||
inputs,
|
||||
truncate: truncate.unwrap_or(max_input_length) as u32,
|
||||
truncate: truncate.unwrap_or(self.max_input_length) as u32,
|
||||
parameters,
|
||||
stopping_parameters,
|
||||
})
|
||||
}
|
||||
|
||||
type ValidationRequest = (
|
||||
GenerateRequest,
|
||||
oneshot::Sender<Result<ValidGenerateRequest, ValidationError>>,
|
||||
/// Validate the best_of parameter
|
||||
#[instrument(skip_all)]
|
||||
pub(crate) fn validate_best_of(&self, best_of: usize) -> Result<usize, ValidationError> {
|
||||
if self.max_best_of == 1 && best_of != 1 {
|
||||
return Err(ValidationError::BestOfDisabled);
|
||||
}
|
||||
|
||||
if best_of > self.max_best_of {
|
||||
return Err(ValidationError::BestOf(self.max_best_of, best_of));
|
||||
}
|
||||
|
||||
Ok(best_of)
|
||||
}
|
||||
}
|
||||
|
||||
/// Start tokenization workers
|
||||
fn tokenizer_worker(tokenizer: Tokenizer, receiver: flume::Receiver<TokenizerRequest>) {
|
||||
// Loop over requests
|
||||
while let Ok(((inputs, truncate), response_tx, parent_span)) = receiver.recv() {
|
||||
parent_span.in_scope(|| {
|
||||
response_tx
|
||||
.send(prepare_input(inputs, truncate, &tokenizer).map_err(|err| {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||
tracing::error!("{err}");
|
||||
err
|
||||
}))
|
||||
.unwrap_or(())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Get input length and optionally truncate it
|
||||
fn prepare_input(
|
||||
inputs: String,
|
||||
truncate: Option<usize>,
|
||||
tokenizer: &Tokenizer,
|
||||
) -> Result<(String, usize), ValidationError> {
|
||||
// Get the number of tokens in the input
|
||||
let mut encoding = tokenizer
|
||||
.encode(inputs.clone(), true)
|
||||
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
|
||||
|
||||
// Optionally truncate
|
||||
let (inputs, input_length) = match truncate {
|
||||
// Truncate is some and > encoding length
|
||||
Some(truncate) if truncate > encoding.len() => {
|
||||
// truncate encoding and decode new inputs
|
||||
encoding.truncate(truncate, 0, TruncationDirection::Left);
|
||||
let inputs = tokenizer
|
||||
.decode(Vec::from(encoding.get_ids()), false)
|
||||
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
|
||||
(inputs, encoding.len())
|
||||
}
|
||||
// Nothing to do
|
||||
_ => (inputs, encoding.len()),
|
||||
};
|
||||
|
||||
Ok((inputs, input_length))
|
||||
}
|
||||
|
||||
type TokenizerRequest = (
|
||||
(String, Option<usize>),
|
||||
oneshot::Sender<Result<(String, usize), ValidationError>>,
|
||||
Span,
|
||||
);
|
||||
|
||||
|
@ -24,6 +24,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
|
||||
return generate_pb2.Request(
|
||||
id=0,
|
||||
inputs="Test",
|
||||
truncate=100,
|
||||
parameters=default_pb_parameters,
|
||||
stopping_parameters=default_pb_stop_parameters,
|
||||
)
|
||||
|
@ -25,6 +25,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
|
||||
return generate_pb2.Request(
|
||||
id=0,
|
||||
inputs="Test",
|
||||
truncate=100,
|
||||
parameters=default_pb_parameters,
|
||||
stopping_parameters=default_pb_stop_parameters,
|
||||
)
|
||||
|
@ -15,6 +15,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
|
||||
return generate_pb2.Request(
|
||||
id=0,
|
||||
inputs="def",
|
||||
truncate=100,
|
||||
parameters=default_pb_parameters,
|
||||
stopping_parameters=default_pb_stop_parameters,
|
||||
)
|
||||
@ -30,6 +31,7 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
|
||||
return generate_pb2.Request(
|
||||
id=0,
|
||||
inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
|
||||
truncate=100,
|
||||
parameters=default_pb_parameters,
|
||||
stopping_parameters=default_pb_stop_parameters,
|
||||
)
|
||||
|
@ -28,6 +28,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
|
||||
return generate_pb2.Request(
|
||||
id=0,
|
||||
inputs="Test",
|
||||
truncate=100,
|
||||
parameters=default_pb_parameters,
|
||||
stopping_parameters=default_pb_stop_parameters,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user