rework validation

This commit is contained in:
OlivierDehaene 2023-04-04 14:43:56 +02:00
parent 47e93409f3
commit 1dd2c24b9c
11 changed files with 424 additions and 485 deletions

356
Cargo.lock generated
View File

@ -8,18 +8,6 @@ version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]]
name = "aes"
version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e8b47f52ea9bae42228d07ec09eb676433d7c4ed1ebdf0f1d1c29ed446f1ab8"
dependencies = [
"cfg-if",
"cipher",
"cpufeatures",
"opaque-debug",
]
[[package]]
name = "ahash"
version = "0.7.6"
@ -40,6 +28,15 @@ dependencies = [
"memchr",
]
[[package]]
name = "ansi_term"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2"
dependencies = [
"winapi",
]
[[package]]
name = "anstream"
version = "0.2.6"
@ -119,6 +116,17 @@ dependencies = [
"syn 2.0.12",
]
[[package]]
name = "atty"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
dependencies = [
"hermit-abi 0.1.19",
"libc",
"winapi",
]
[[package]]
name = "autocfg"
version = "1.1.0"
@ -202,12 +210,6 @@ version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4a4ddaa51a5bc52a6948f74c06d20aaaddb71924eab79b8c97a8c556e942d6a"
[[package]]
name = "base64ct"
version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
[[package]]
name = "bitflags"
version = "1.3.2"
@ -264,9 +266,9 @@ dependencies = [
[[package]]
name = "cached-path"
version = "0.6.1"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "097968e38f1319207f057d0f4d76452e4f4f847a5de61c5215379f297fa034f3"
checksum = "5f1c56d30236522ab3393a08746b138d4e16372001f42d29c88d513aeb8ab7ef"
dependencies = [
"flate2",
"fs2",
@ -281,7 +283,8 @@ dependencies = [
"tar",
"tempfile",
"thiserror",
"zip",
"zip 0.5.13",
"zip-extensions",
]
[[package]]
@ -289,9 +292,6 @@ name = "cc"
version = "1.0.79"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f"
dependencies = [
"jobserver",
]
[[package]]
name = "cfg-if"
@ -300,12 +300,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "cipher"
version = "0.3.0"
name = "clap"
version = "2.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ee52072ec15386f770805afd189a01c8841be8696bed250fa2f13c4c0d6dfb7"
checksum = "a0610544180c38b88101fecf2dd634b174a62eef6946f84dfc6a7127512b381c"
dependencies = [
"generic-array",
"ansi_term",
"atty",
"bitflags",
"strsim 0.8.0",
"textwrap",
"unicode-width",
"vec_map",
]
[[package]]
@ -329,7 +335,7 @@ dependencies = [
"anstyle",
"bitflags",
"clap_lex",
"strsim",
"strsim 0.10.0",
]
[[package]]
@ -378,12 +384,6 @@ dependencies = [
"windows-sys 0.42.0",
]
[[package]]
name = "constant_time_eq"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc"
[[package]]
name = "core-foundation"
version = "0.9.3"
@ -483,9 +483,9 @@ dependencies = [
[[package]]
name = "darling"
version = "0.14.4"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850"
checksum = "0d706e75d87e35569db781a9b5e2416cff1236a47ed380831f959382ccd5f858"
dependencies = [
"darling_core",
"darling_macro",
@ -493,23 +493,23 @@ dependencies = [
[[package]]
name = "darling_core"
version = "0.14.4"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0"
checksum = "f0c960ae2da4de88a91b2d920c2a7233b400bc33cb28453a2987822d8392519b"
dependencies = [
"fnv",
"ident_case",
"proc-macro2",
"quote",
"strsim",
"strsim 0.9.3",
"syn 1.0.109",
]
[[package]]
name = "darling_macro"
version = "0.14.4"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e"
checksum = "d9b5a2f4ac4969822c62224815d069952656cadc7084fdca9751e6d959189b72"
dependencies = [
"darling_core",
"quote",
@ -531,32 +531,26 @@ dependencies = [
[[package]]
name = "derive_builder"
version = "0.12.0"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d67778784b508018359cbc8696edb3db78160bab2c2a28ba7f56ef6932997f8"
dependencies = [
"derive_builder_macro",
]
[[package]]
name = "derive_builder_core"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c11bdc11a0c47bc7d37d582b5285da6849c96681023680b906673c5707af7b0f"
checksum = "a2658621297f2cf68762a6f7dc0bb7e1ff2cfd6583daef8ee0fed6f7ec468ec0"
dependencies = [
"darling",
"derive_builder_core",
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "derive_builder_macro"
version = "0.12.0"
name = "derive_builder_core"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebcda35c7a396850a55ffeac740804b40ffec779b98fffbb1738f4033f0ee79e"
checksum = "2791ea3e372c8495c0bc2033991d76b512cd799d07491fbd6890124db9458bef"
dependencies = [
"derive_builder_core",
"darling",
"proc-macro2",
"quote",
"syn 1.0.109",
]
@ -568,7 +562,15 @@ checksum = "8168378f4e5023e7218c89c891c0fd8ecdb5e5e4f18cb78f38cf245dd021e76f"
dependencies = [
"block-buffer",
"crypto-common",
"subtle",
]
[[package]]
name = "dirs"
version = "3.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30baa043103c9d0c2a57cf537cc2f35623889dc0d405e6c3cccfadbc81c71309"
dependencies = [
"dirs-sys",
]
[[package]]
@ -685,6 +687,19 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28a80e3145d8ad11ba0995949bbcf48b9df2be62772b3d351ef017dff6ecb853"
[[package]]
name = "flume"
version = "0.10.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1657b4441c3403d9f7b3409e47575237dac27b1b5726df654a6ecbf92f0f7577"
dependencies = [
"futures-core",
"futures-sink",
"nanorand",
"pin-project",
"spin",
]
[[package]]
name = "fnv"
version = "1.0.7"
@ -831,8 +846,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31"
dependencies = [
"cfg-if",
"js-sys",
"libc",
"wasi 0.11.0+wasi-snapshot-preview1",
"wasm-bindgen",
]
[[package]]
@ -885,6 +902,15 @@ version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
[[package]]
name = "hermit-abi"
version = "0.1.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33"
dependencies = [
"libc",
]
[[package]]
name = "hermit-abi"
version = "0.2.6"
@ -900,15 +926,6 @@ version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286"
[[package]]
name = "hmac"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e"
dependencies = [
"digest",
]
[[package]]
name = "http"
version = "0.2.9"
@ -1120,15 +1137,6 @@ version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6"
[[package]]
name = "jobserver"
version = "0.1.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "936cfd212a0155903bcbc060e316fb6cc7cbf2e1907329391ebadc1fe0ce77c2"
dependencies = [
"libc",
]
[[package]]
name = "js-sys"
version = "0.3.61"
@ -1331,33 +1339,21 @@ dependencies = [
"windows-sys 0.45.0",
]
[[package]]
name = "monostate"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0230b703f1ac35df1e24f6d0d2255472bcccaf657ecdfa4f1fcbcad1ad5bb98a"
dependencies = [
"monostate-impl",
"serde",
]
[[package]]
name = "monostate-impl"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8795add3e14028f11f8e848bd3294898a8294767b3776b6f733560d33bd2530b"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.12",
]
[[package]]
name = "multimap"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a"
[[package]]
name = "nanorand"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3"
dependencies = [
"getrandom",
]
[[package]]
name = "native-tls"
version = "0.2.11"
@ -1464,12 +1460,6 @@ dependencies = [
"pkg-config",
]
[[package]]
name = "opaque-debug"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5"
[[package]]
name = "openssl"
version = "0.10.48"
@ -1624,35 +1614,12 @@ dependencies = [
"windows-sys 0.45.0",
]
[[package]]
name = "password-hash"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7676374caaee8a325c9e7a2ae557f216c5563a171d6997b0ef8a65af35147700"
dependencies = [
"base64ct",
"rand_core",
"subtle",
]
[[package]]
name = "paste"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9f746c4065a8fa3fe23974dd82f15431cc8d40779821001404d10d2e79ca7d79"
[[package]]
name = "pbkdf2"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "83a0692ec44e4cf1ef28ca317f14f8f07da2d95ec3fa01f86e4467b725e60917"
dependencies = [
"digest",
"hmac",
"password-hash",
"sha2",
]
[[package]]
name = "percent-encoding"
version = "2.2.0"
@ -2165,17 +2132,6 @@ dependencies = [
"serde",
]
[[package]]
name = "sha1"
version = "0.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3"
dependencies = [
"cfg-if",
"cpufeatures",
"digest",
]
[[package]]
name = "sha2"
version = "0.10.6"
@ -2202,7 +2158,7 @@ version = "2.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ccc8076840c4da029af4f87e4e8daeb0fca6b87bbb02e10cb60b791450e11e4"
dependencies = [
"dirs",
"dirs 4.0.0",
]
[[package]]
@ -2245,6 +2201,15 @@ dependencies = [
"winapi",
]
[[package]]
name = "spin"
version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
dependencies = [
"lock_api",
]
[[package]]
name = "spm_precompiled"
version = "0.1.4"
@ -2263,6 +2228,18 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
[[package]]
name = "strsim"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a"
[[package]]
name = "strsim"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6446ced80d6c486436db5c078dde11a9f73d42b57fb273121e160b84f63d894c"
[[package]]
name = "strsim"
version = "0.10.0"
@ -2279,12 +2256,6 @@ dependencies = [
"winapi",
]
[[package]]
name = "subtle"
version = "2.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601"
[[package]]
name = "syn"
version = "1.0.109"
@ -2358,7 +2329,7 @@ dependencies = [
name = "text-generation-launcher"
version = "0.4.3"
dependencies = [
"clap",
"clap 4.2.1",
"ctrlc",
"float_eq",
"reqwest",
@ -2376,14 +2347,14 @@ dependencies = [
"async-stream",
"axum",
"axum-tracing-opentelemetry",
"clap",
"clap 4.2.1",
"flume",
"futures",
"metrics",
"metrics-exporter-prometheus",
"nohash-hasher",
"opentelemetry",
"opentelemetry-otlp",
"parking_lot",
"rand",
"reqwest",
"serde",
@ -2392,7 +2363,6 @@ dependencies = [
"thiserror",
"tokenizers",
"tokio",
"tokio-stream",
"tower-http",
"tracing",
"tracing-opentelemetry",
@ -2401,6 +2371,15 @@ dependencies = [
"utoipa-swagger-ui",
]
[[package]]
name = "textwrap"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060"
dependencies = [
"unicode-width",
]
[[package]]
name = "thiserror"
version = "1.0.40"
@ -2433,20 +2412,14 @@ dependencies = [
[[package]]
name = "time"
version = "0.3.20"
version = "0.1.43"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cd0cbfecb4d19b5ea75bb31ad904eb5b9fa13f21079c3b92017ebdf4999a5890"
checksum = "ca8a50ef2360fbd1eeb0ecd46795a87a19024eb4b53c5dc916ca1fd95fe62438"
dependencies = [
"serde",
"time-core",
"libc",
"winapi",
]
[[package]]
name = "time-core"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2e153e1f1acaef8acc537e68b44906d2db6436e2b35ac2c6b42640fff91f00fd"
[[package]]
name = "tinyvec"
version = "1.6.0"
@ -2465,13 +2438,14 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokenizers"
version = "0.13.2"
source = "git+https://github.com/huggingface/tokenizers.git#3aaf4946b3c82e7a04db6ecde7c8bb4e474e54af"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f4ff2dd291eac98dcea13e8cf7a0b28c373a90dc9210ccdab0fa9e69ee0cac69"
dependencies = [
"aho-corasick",
"cached-path",
"clap",
"clap 2.34.0",
"derive_builder",
"dirs",
"dirs 3.0.2",
"esaxx-rs",
"getrandom",
"indicatif 0.15.0",
@ -2479,7 +2453,6 @@ dependencies = [
"lazy_static",
"log",
"macro_rules_attribute",
"monostate",
"onig",
"paste",
"rand",
@ -2901,7 +2874,7 @@ dependencies = [
"serde",
"serde_json",
"utoipa",
"zip",
"zip 0.6.4",
]
[[package]]
@ -2916,6 +2889,12 @@ version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
[[package]]
name = "vec_map"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1bddf1187be692e79c5ffeab891132dfb0f236ed36a43c7ed39f1165ee20191"
[[package]]
name = "version_check"
version = "0.9.4"
@ -3171,52 +3150,37 @@ dependencies = [
"libc",
]
[[package]]
name = "zip"
version = "0.5.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93ab48844d61251bb3835145c521d88aa4031d7139e8485990f60ca911fa0815"
dependencies = [
"byteorder",
"bzip2",
"crc32fast",
"flate2",
"thiserror",
"time",
]
[[package]]
name = "zip"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0445d0fbc924bb93539b4316c11afb121ea39296f99a3c4c9edad09e3658cdef"
dependencies = [
"aes",
"byteorder",
"bzip2",
"constant_time_eq",
"crc32fast",
"crossbeam-utils",
"flate2",
"hmac",
"pbkdf2",
"sha1",
"time",
"zstd",
]
[[package]]
name = "zstd"
version = "0.11.2+zstd.1.5.2"
name = "zip-extensions"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4"
checksum = "a64c3c977bc3434ce2d4bcea8ad3c644672de0f2c402b72b9171ca80a8885d14"
dependencies = [
"zstd-safe",
]
[[package]]
name = "zstd-safe"
version = "5.0.2+zstd.1.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d2a5585e04f9eea4b2a3d1eca508c4dee9592a89ef6f450c11719da0726f4db"
dependencies = [
"libc",
"zstd-sys",
]
[[package]]
name = "zstd-sys"
version = "2.0.7+zstd.1.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94509c3ba2fe55294d752b79842c530ccfab760192521df74a081a78d2b3c7f5"
dependencies = [
"cc",
"libc",
"pkg-config",
"zip 0.5.13",
]

View File

@ -27,8 +27,7 @@ serde = {version = "1.0.142", features = ["derive"]}
serde_json = "1.0"
text-generation-client = { path = "../router/client" }
thiserror = "1.0.38"
#tokenizers = "0.13.2"
tokenizers = { git = "https://github.com/huggingface/tokenizers.git" }
tokenizers = "0.13.2"
tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tui = {package = "ratatui", version = "0.20", default-features = false, features = ["crossterm"]}
tracing = "0.1.37"

View File

@ -18,22 +18,20 @@ axum = { version = "0.6.4", features = ["json"] }
axum-tracing-opentelemetry = "0.9.0"
text-generation-client = { path = "client" }
clap = { version = "4.1.4", features = ["derive", "env"] }
flume = "0.10.14"
futures = "0.3.26"
metrics = "0.20.1"
metrics-exporter-prometheus = { version = "0.11.0", features = [] }
nohash-hasher = "0.2.0"
opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.11.0"
parking_lot = "0.12.1"
rand = "0.8.5"
reqwest = { version = "0.11.14", features = [] }
serde = "1.0.152"
serde_json = "1.0.93"
thiserror = "1.0.38"
#tokenizers = "0.13.2"
tokenizers = { git = "https://github.com/huggingface/tokenizers.git" }
tokenizers = "0.13.2"
tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.11"
tower-http = { version = "0.3.5", features = ["cors"] }
tracing = "0.1.37"
tracing-opentelemetry = "0.18.0"

View File

@ -2,17 +2,17 @@
use crate::validation::{Validation, ValidationError};
use crate::{Entry, Queue, Token};
use crate::{GenerateRequest, PrefillToken};
use flume::r#async::RecvStream;
use futures::future::try_join_all;
use futures::stream::StreamExt;
use nohash_hasher::IntMap;
use std::sync::Arc;
use text_generation_client::{
Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
};
use thiserror::Error;
use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
use tokio::sync::{Notify, Semaphore, TryAcquireError};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;
use tracing::{info_span, instrument, Instrument, Span};
/// Inference struct
@ -73,7 +73,7 @@ impl Infer {
pub(crate) async fn generate_stream(
&self,
request: GenerateRequest,
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
) -> Result<RecvStream<Result<InferStreamResponse, InferError>>, InferError> {
// Limit concurrent requests by acquiring a permit from the semaphore
// This permit will live as long as Entry
let permit = self
@ -90,7 +90,7 @@ impl Infer {
let valid_request = self.validation.validate(request).await?;
// MPSC channel to communicate with the background batching task
let (response_tx, response_rx) = mpsc::unbounded_channel();
let (response_tx, response_rx) = flume::unbounded();
// Append the request to the queue
self.queue.append(Entry {
@ -108,7 +108,8 @@ impl Infer {
self.shared.batching_task.notify_one();
// Return stream
Ok(UnboundedReceiverStream::new(response_rx))
Ok(response_rx.into_stream())
// Ok(UnboundedReceiverStream::new(response_rx))
}
/// Add a new request to the queue and return a InferResponse

View File

@ -4,8 +4,7 @@ use crate::validation::ValidGenerateRequest;
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::min;
use text_generation_client::{Batch, Request};
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit};
use tokio::sync::{oneshot, OwnedSemaphorePermit};
use tokio::time::Instant;
use tracing::{info_span, instrument, Span};
@ -15,7 +14,7 @@ pub(crate) struct Entry {
/// Request
pub request: ValidGenerateRequest,
/// Response sender to communicate between the Infer struct and the batching_task
pub response_tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
pub response_tx: flume::Sender<Result<InferStreamResponse, InferError>>,
/// Span that will live as long as entry
pub span: Span,
/// Temporary span used as a guard when logging inference, wait times...
@ -32,13 +31,13 @@ pub(crate) struct Entry {
#[derive(Debug, Clone)]
pub(crate) struct Queue {
/// Channel to communicate with the background queue task
queue_sender: UnboundedSender<QueueCommand>,
queue_sender: flume::Sender<QueueCommand>,
}
impl Queue {
pub(crate) fn new() -> Self {
// Create channel
let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
let (queue_sender, queue_receiver) = flume::unbounded();
// Launch background queue task
tokio::spawn(queue_task(queue_receiver));
@ -82,10 +81,10 @@ impl Queue {
}
// Background task responsible of the queue state
async fn queue_task(mut receiver: UnboundedReceiver<QueueCommand>) {
async fn queue_task(receiver: flume::Receiver<QueueCommand>) {
let mut state = State::new();
while let Some(cmd) = receiver.recv().await {
while let Ok(cmd) = receiver.recv_async().await {
match cmd {
QueueCommand::Append(entry, span) => span.in_scope(|| state.append(entry)),
QueueCommand::NextBatch {
@ -216,12 +215,12 @@ mod tests {
use super::*;
use std::sync::Arc;
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
use tokio::sync::{mpsc, Semaphore};
use tokio::sync::Semaphore;
use tracing::info_span;
fn default_entry() -> Entry {
let semaphore = Arc::new(Semaphore::new(1));
let (response_tx, _) = mpsc::unbounded_channel();
let (response_tx, _) = flume::unbounded();
let permit = semaphore.try_acquire_owned().unwrap();
Entry {

View File

@ -13,6 +13,7 @@ use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use axum::{http, Json, Router};
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
use futures::stream::StreamExt;
use futures::Stream;
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
use std::convert::Infallible;
@ -21,7 +22,6 @@ use text_generation_client::ShardedClient;
use tokenizers::Tokenizer;
use tokio::signal;
use tokio::time::Instant;
use tokio_stream::StreamExt;
use tower_http::cors::{AllowOrigin, CorsLayer};
use tracing::{info_span, instrument, Instrument};
use utoipa::OpenApi;

View File

@ -1,23 +1,24 @@
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
/// Payload validation logic
use crate::{GenerateParameters, GenerateRequest};
use rand::rngs::ThreadRng;
use rand::Rng;
use rand::{thread_rng, Rng};
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
use tokenizers::TruncationDirection;
use tokio::sync::{mpsc, oneshot};
use tokio::sync::oneshot;
use tracing::{instrument, Span};
/// Validation
#[derive(Debug, Clone)]
pub struct Validation {
/// maximum value for the best_of parameter
#[allow(dead_code)]
/// Validation parameters
max_best_of: usize,
/// Channel to communicate with the background validation task
sender: mpsc::UnboundedSender<ValidationRequest>,
max_stop_sequences: usize,
max_input_length: usize,
max_total_tokens: usize,
/// Channel to communicate with the background tokenization task
sender: Option<flume::Sender<TokenizerRequest>>,
}
impl Validation {
@ -29,22 +30,81 @@ impl Validation {
max_input_length: usize,
max_total_tokens: usize,
) -> Self {
// Create channel
let (validation_sender, validation_receiver) = mpsc::unbounded_channel();
// If we have a fast tokenizer
let sender = if let Some(tokenizer) = tokenizer {
// Create channel
let (validation_sender, validation_receiver) = flume::unbounded();
// Launch background validation task
tokio::spawn(validation_task(
workers,
tokenizer,
max_stop_sequences,
max_input_length,
max_total_tokens,
validation_receiver,
));
// Create workers
for _ in 0..workers {
let tokenizer_clone = tokenizer.clone();
let receiver_clone = validation_receiver.clone();
// Spawn worker
tokio::task::spawn_blocking(move || {
tokenizer_worker(tokenizer_clone.into(), receiver_clone)
});
}
Some(validation_sender)
} else {
None
};
Self {
max_best_of,
sender: validation_sender,
sender,
max_stop_sequences,
max_input_length,
max_total_tokens,
}
}
async fn validate_input(
&self,
inputs: String,
truncate: Option<usize>,
max_new_tokens: u32,
) -> Result<String, ValidationError> {
// If we have a fast tokenizer
if let Some(sender) = &self.sender {
// Create response channel
let (response_sender, response_receiver) = oneshot::channel();
// Send request to the background validation task
// Unwrap is safe here
sender
.send(((inputs, truncate), response_sender, Span::current()))
.unwrap();
// Await on response channel
// Unwrap is safe here
let (inputs, input_length) = response_receiver.await.unwrap()?;
// Get total tokens
let total_tokens = input_length + max_new_tokens as usize;
// Validate MaxTotalTokens
if total_tokens > self.max_total_tokens {
return Err(ValidationError::MaxTotalTokens(
self.max_total_tokens,
input_length,
max_new_tokens,
));
}
// Validate InputLength
if input_length > self.max_input_length {
return Err(ValidationError::InputLength(
self.max_input_length,
input_length,
));
}
metrics::histogram!("tgi_request_input_length", input_length as f64);
Ok(inputs)
}
// Return inputs without validation
else {
Ok(inputs)
}
}
@ -54,16 +114,139 @@ impl Validation {
&self,
request: GenerateRequest,
) -> Result<ValidGenerateRequest, ValidationError> {
// Create response channel
let (sender, receiver) = oneshot::channel();
// Send request to the background validation task
// Unwrap is safe here
self.sender
.send((request, sender, Span::current()))
.unwrap();
// Await on response channel
// Unwrap is safe here
receiver.await.unwrap()
let GenerateParameters {
best_of,
temperature,
repetition_penalty,
top_k,
top_p,
typical_p,
do_sample,
max_new_tokens,
stop: stop_sequences,
truncate,
seed,
watermark,
..
} = request.parameters;
// sampling must be true when best_of > 1
let best_of = best_of.unwrap_or(1);
let sampling = do_sample
|| temperature.is_some()
|| top_k.is_some()
|| top_p.is_some()
|| typical_p.is_some();
if best_of > 1 && !sampling {
return Err(BestOfSampling);
}
let temperature = temperature.unwrap_or(1.0);
if temperature <= 0.0 {
return Err(ValidationError::Temperature);
}
let repetition_penalty = repetition_penalty.unwrap_or(1.0);
if repetition_penalty <= 0.0 {
return Err(ValidationError::RepetitionPenalty);
}
// Different because the proto default value is not a valid value
// for the user
let top_p = top_p
.map(|value| {
if value <= 0.0 || value >= 1.0 {
return Err(ValidationError::TopP);
}
Ok(value)
})
.unwrap_or(Ok(1.0))?;
let typical_p = typical_p
.map(|value| {
if value <= 0.0 || value >= 1.0 {
return Err(ValidationError::TypicalP);
}
Ok(value)
})
.unwrap_or(Ok(1.0))?;
let top_k: u32 = top_k
.map(|value| {
if value <= 0 {
return Err(ValidationError::TopK);
}
Ok(value as u32)
})
.unwrap_or(Ok(0))?;
if max_new_tokens == 0 {
return Err(ValidationError::MaxNewTokens);
}
if stop_sequences.len() > self.max_stop_sequences {
return Err(ValidationError::StopSequence(
self.max_stop_sequences,
stop_sequences.len(),
));
}
// If seed is None, assign a random one
let seed = match seed {
None => thread_rng().gen(),
Some(seed) => {
if best_of > 1 {
return Err(BestOfSeed);
}
seed
}
};
// Check if inputs is empty
if request.inputs.is_empty() {
return Err(EmptyInput);
}
// Check if truncate is strictly positive and less than max_input_length
let truncate = truncate
.map(|value| {
if value == 0 || value > self.max_input_length {
return Err(ValidationError::Truncate(self.max_input_length, value));
}
Ok(Some(value))
})
.unwrap_or(Ok(None))?;
// Validate inputs
let inputs = self
.validate_input(request.inputs, truncate, max_new_tokens)
.await?;
let parameters = NextTokenChooserParameters {
temperature,
repetition_penalty,
top_k,
top_p,
typical_p,
do_sample,
seed,
watermark,
};
let stopping_parameters = StoppingCriteriaParameters {
max_new_tokens,
stop_sequences,
ignore_eos_token: false,
};
metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);
Ok(ValidGenerateRequest {
inputs,
truncate: truncate.unwrap_or(self.max_input_length) as u32,
parameters,
stopping_parameters,
})
}
/// Validate the best_of parameter
@ -81,264 +264,54 @@ impl Validation {
}
}
/// Validation task
/// Load balance the validation requests between multiple validation workers
async fn validation_task(
workers: usize,
tokenizer: Option<Tokenizer>,
max_stop_sequences: usize,
max_input_length: usize,
max_total_tokens: usize,
mut receiver: mpsc::UnboundedReceiver<ValidationRequest>,
) {
let mut workers_senders = Vec::with_capacity(workers);
// Create workers
for _ in 0..workers {
let tokenizer_clone: Option<Tokenizer> = tokenizer.clone().into();
// Create channel to communicate with worker
let (worker_sender, worker_receiver) = mpsc::channel(workers);
workers_senders.push(worker_sender);
// Spawn worker
tokio::task::spawn_blocking(move || {
validation_worker(
tokenizer_clone,
max_stop_sequences,
max_input_length,
max_total_tokens,
worker_receiver,
)
});
}
loop {
// Load balance requests between workers
for sender in workers_senders.iter() {
if let Some(validation_request) = receiver.recv().await {
sender.send(validation_request).await.unwrap();
} else {
return;
}
}
}
}
/// Check the parameters inside the payload and get the number of tokens inside the input using
/// the tokenizer
fn validation_worker(
tokenizer: Option<Tokenizer>,
max_stop_sequences: usize,
max_input_length: usize,
max_total_tokens: usize,
mut receiver: mpsc::Receiver<ValidationRequest>,
) {
// Seed rng
let mut rng = rand::thread_rng();
/// Start tokenization workers
fn tokenizer_worker(tokenizer: Tokenizer, receiver: flume::Receiver<TokenizerRequest>) {
// Loop over requests
while let Some((request, response_tx, parent_span)) = receiver.blocking_recv() {
while let Ok(((inputs, truncate), response_tx, parent_span)) = receiver.recv() {
parent_span.in_scope(|| {
response_tx
.send(
validate(
request,
tokenizer.as_ref(),
max_stop_sequences,
max_input_length,
max_total_tokens,
&mut rng,
)
.map_err(|err| {
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
tracing::error!("{err}");
err
}),
)
.send(prepare_input(inputs, truncate, &tokenizer).map_err(|err| {
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
tracing::error!("{err}");
err
}))
.unwrap_or(())
})
}
}
fn validate(
request: GenerateRequest,
tokenizer: Option<&Tokenizer>,
max_stop_sequences: usize,
max_input_length: usize,
max_total_tokens: usize,
rng: &mut ThreadRng,
) -> Result<ValidGenerateRequest, ValidationError> {
let GenerateParameters {
best_of,
temperature,
repetition_penalty,
top_k,
top_p,
typical_p,
do_sample,
max_new_tokens,
stop: stop_sequences,
truncate,
seed,
watermark,
..
} = request.parameters;
/// Get input length and optionally truncate it
fn prepare_input(
inputs: String,
truncate: Option<usize>,
tokenizer: &Tokenizer,
) -> Result<(String, usize), ValidationError> {
// Get the number of tokens in the input
let mut encoding = tokenizer
.encode(inputs.clone(), true)
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
// sampling must be true when best_of > 1
let best_of = best_of.unwrap_or(1);
let sampling = do_sample
|| temperature.is_some()
|| top_k.is_some()
|| top_p.is_some()
|| typical_p.is_some();
if best_of > 1 && !sampling {
return Err(BestOfSampling);
}
let temperature = temperature.unwrap_or(1.0);
if temperature <= 0.0 {
return Err(ValidationError::Temperature);
}
let repetition_penalty = repetition_penalty.unwrap_or(1.0);
if repetition_penalty <= 0.0 {
return Err(ValidationError::RepetitionPenalty);
}
// Different because the proto default value is not a valid value
// for the user
let top_p = top_p
.map(|value| {
if value <= 0.0 || value >= 1.0 {
return Err(ValidationError::TopP);
}
Ok(value)
})
.unwrap_or(Ok(1.0))?;
let typical_p = typical_p
.map(|value| {
if value <= 0.0 || value >= 1.0 {
return Err(ValidationError::TypicalP);
}
Ok(value)
})
.unwrap_or(Ok(1.0))?;
let top_k: u32 = top_k
.map(|value| {
if value <= 0 {
return Err(ValidationError::TopK);
}
Ok(value as u32)
})
.unwrap_or(Ok(0))?;
if max_new_tokens == 0 {
return Err(ValidationError::MaxNewTokens);
}
if stop_sequences.len() > max_stop_sequences {
return Err(ValidationError::StopSequence(
max_stop_sequences,
stop_sequences.len(),
));
}
// If seed is None, assign a random one
let seed = match seed {
None => rng.gen(),
Some(seed) => {
if best_of > 1 {
return Err(BestOfSeed);
}
seed
}
};
// Check if inputs is empty
if request.inputs.is_empty() {
return Err(EmptyInput);
}
// Check if truncate is strictly positive and less than max_input_length
let truncate = truncate
.map(|value| {
if value == 0 || value > max_input_length {
return Err(ValidationError::Truncate(max_input_length, value));
}
Ok(Some(value))
})
.unwrap_or(Ok(None))?;
// If we have a fast tokenizer
let inputs = if let Some(tokenizer) = tokenizer {
// Get the number of tokens in the input
let mut encoding = tokenizer
.encode(request.inputs.clone(), true)
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
let (inputs, input_length) = if let Some(truncate) = truncate {
// Optionally truncate
let (inputs, input_length) = match truncate {
// Truncate is some and > encoding length
Some(truncate) if truncate > encoding.len() => {
// truncate encoding and decode new inputs
encoding.truncate(truncate, 0, TruncationDirection::Left);
let inputs = tokenizer
.decode(Vec::from(encoding.get_ids()), false)
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
(inputs, encoding.len())
} else {
(request.inputs, encoding.len())
};
if input_length > max_input_length {
return Err(ValidationError::InputLength(max_input_length, input_length));
}
let total_tokens = input_length + max_new_tokens as usize;
if total_tokens > max_total_tokens {
return Err(ValidationError::MaxTotalTokens(
max_total_tokens,
input_length,
max_new_tokens,
));
}
metrics::histogram!("tgi_request_input_length", input_length as f64);
inputs
} else {
request.inputs
// Nothing to do
_ => (inputs, encoding.len()),
};
// Return ValidGenerateRequest
let parameters = NextTokenChooserParameters {
temperature,
repetition_penalty,
top_k,
top_p,
typical_p,
do_sample,
seed,
watermark,
};
let stopping_parameters = StoppingCriteriaParameters {
max_new_tokens,
stop_sequences,
ignore_eos_token: false,
};
metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);
Ok(ValidGenerateRequest {
inputs,
truncate: truncate.unwrap_or(max_input_length) as u32,
parameters,
stopping_parameters,
})
Ok((inputs, input_length))
}
type ValidationRequest = (
GenerateRequest,
oneshot::Sender<Result<ValidGenerateRequest, ValidationError>>,
type TokenizerRequest = (
(String, Option<usize>),
oneshot::Sender<Result<(String, usize), ValidationError>>,
Span,
);

View File

@ -24,6 +24,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
truncate=100,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
)

View File

@ -25,6 +25,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
truncate=100,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
)

View File

@ -15,6 +15,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="def",
truncate=100,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
)
@ -30,6 +31,7 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
truncate=100,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
)

View File

@ -28,6 +28,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
truncate=100,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
)