diff --git a/Cargo.lock b/Cargo.lock index ca60f997..fd4e133c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8,18 +8,6 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" -[[package]] -name = "aes" -version = "0.7.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e8b47f52ea9bae42228d07ec09eb676433d7c4ed1ebdf0f1d1c29ed446f1ab8" -dependencies = [ - "cfg-if", - "cipher", - "cpufeatures", - "opaque-debug", -] - [[package]] name = "ahash" version = "0.7.6" @@ -40,6 +28,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "ansi_term" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2" +dependencies = [ + "winapi", +] + [[package]] name = "anstream" version = "0.2.6" @@ -119,6 +116,17 @@ dependencies = [ "syn 2.0.12", ] +[[package]] +name = "atty" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" +dependencies = [ + "hermit-abi 0.1.19", + "libc", + "winapi", +] + [[package]] name = "autocfg" version = "1.1.0" @@ -202,12 +210,6 @@ version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a4a4ddaa51a5bc52a6948f74c06d20aaaddb71924eab79b8c97a8c556e942d6a" -[[package]] -name = "base64ct" -version = "1.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" - [[package]] name = "bitflags" version = "1.3.2" @@ -264,9 +266,9 @@ dependencies = [ [[package]] name = "cached-path" -version = "0.6.1" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "097968e38f1319207f057d0f4d76452e4f4f847a5de61c5215379f297fa034f3" +checksum = "5f1c56d30236522ab3393a08746b138d4e16372001f42d29c88d513aeb8ab7ef" dependencies = [ "flate2", "fs2", @@ -281,7 +283,8 @@ dependencies = [ "tar", "tempfile", "thiserror", - "zip", + "zip 0.5.13", + "zip-extensions", ] [[package]] @@ -289,9 +292,6 @@ name = "cc" version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f" -dependencies = [ - "jobserver", -] [[package]] name = "cfg-if" @@ -300,12 +300,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] -name = "cipher" -version = "0.3.0" +name = "clap" +version = "2.34.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ee52072ec15386f770805afd189a01c8841be8696bed250fa2f13c4c0d6dfb7" +checksum = "a0610544180c38b88101fecf2dd634b174a62eef6946f84dfc6a7127512b381c" dependencies = [ - "generic-array", + "ansi_term", + "atty", + "bitflags", + "strsim 0.8.0", + "textwrap", + "unicode-width", + "vec_map", ] [[package]] @@ -329,7 +335,7 @@ dependencies = [ "anstyle", "bitflags", "clap_lex", - "strsim", + "strsim 0.10.0", ] [[package]] @@ -378,12 +384,6 @@ dependencies = [ "windows-sys 0.42.0", ] -[[package]] -name = "constant_time_eq" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" - [[package]] name = "core-foundation" version = "0.9.3" @@ -483,9 +483,9 @@ dependencies = [ [[package]] name = "darling" -version = "0.14.4" +version = "0.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850" +checksum = "0d706e75d87e35569db781a9b5e2416cff1236a47ed380831f959382ccd5f858" dependencies = [ "darling_core", "darling_macro", @@ -493,23 +493,23 @@ dependencies = [ [[package]] name = "darling_core" -version = "0.14.4" +version = "0.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0" +checksum = "f0c960ae2da4de88a91b2d920c2a7233b400bc33cb28453a2987822d8392519b" dependencies = [ "fnv", "ident_case", "proc-macro2", "quote", - "strsim", + "strsim 0.9.3", "syn 1.0.109", ] [[package]] name = "darling_macro" -version = "0.14.4" +version = "0.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e" +checksum = "d9b5a2f4ac4969822c62224815d069952656cadc7084fdca9751e6d959189b72" dependencies = [ "darling_core", "quote", @@ -531,32 +531,26 @@ dependencies = [ [[package]] name = "derive_builder" -version = "0.12.0" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d67778784b508018359cbc8696edb3db78160bab2c2a28ba7f56ef6932997f8" -dependencies = [ - "derive_builder_macro", -] - -[[package]] -name = "derive_builder_core" -version = "0.12.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c11bdc11a0c47bc7d37d582b5285da6849c96681023680b906673c5707af7b0f" +checksum = "a2658621297f2cf68762a6f7dc0bb7e1ff2cfd6583daef8ee0fed6f7ec468ec0" dependencies = [ "darling", + "derive_builder_core", "proc-macro2", "quote", "syn 1.0.109", ] [[package]] -name = "derive_builder_macro" -version = "0.12.0" +name = "derive_builder_core" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ebcda35c7a396850a55ffeac740804b40ffec779b98fffbb1738f4033f0ee79e" +checksum = "2791ea3e372c8495c0bc2033991d76b512cd799d07491fbd6890124db9458bef" dependencies = [ - "derive_builder_core", + "darling", + "proc-macro2", + "quote", "syn 1.0.109", ] @@ -568,7 +562,15 @@ checksum = "8168378f4e5023e7218c89c891c0fd8ecdb5e5e4f18cb78f38cf245dd021e76f" dependencies = [ "block-buffer", "crypto-common", - "subtle", +] + +[[package]] +name = "dirs" +version = "3.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30baa043103c9d0c2a57cf537cc2f35623889dc0d405e6c3cccfadbc81c71309" +dependencies = [ + "dirs-sys", ] [[package]] @@ -685,6 +687,19 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28a80e3145d8ad11ba0995949bbcf48b9df2be62772b3d351ef017dff6ecb853" +[[package]] +name = "flume" +version = "0.10.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1657b4441c3403d9f7b3409e47575237dac27b1b5726df654a6ecbf92f0f7577" +dependencies = [ + "futures-core", + "futures-sink", + "nanorand", + "pin-project", + "spin", +] + [[package]] name = "fnv" version = "1.0.7" @@ -831,8 +846,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi 0.11.0+wasi-snapshot-preview1", + "wasm-bindgen", ] [[package]] @@ -885,6 +902,15 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" +[[package]] +name = "hermit-abi" +version = "0.1.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" +dependencies = [ + "libc", +] + [[package]] name = "hermit-abi" version = "0.2.6" @@ -900,15 +926,6 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286" -[[package]] -name = "hmac" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" -dependencies = [ - "digest", -] - [[package]] name = "http" version = "0.2.9" @@ -1120,15 +1137,6 @@ version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6" -[[package]] -name = "jobserver" -version = "0.1.26" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "936cfd212a0155903bcbc060e316fb6cc7cbf2e1907329391ebadc1fe0ce77c2" -dependencies = [ - "libc", -] - [[package]] name = "js-sys" version = "0.3.61" @@ -1331,33 +1339,21 @@ dependencies = [ "windows-sys 0.45.0", ] -[[package]] -name = "monostate" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0230b703f1ac35df1e24f6d0d2255472bcccaf657ecdfa4f1fcbcad1ad5bb98a" -dependencies = [ - "monostate-impl", - "serde", -] - -[[package]] -name = "monostate-impl" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8795add3e14028f11f8e848bd3294898a8294767b3776b6f733560d33bd2530b" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.12", -] - [[package]] name = "multimap" version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a" +[[package]] +name = "nanorand" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3" +dependencies = [ + "getrandom", +] + [[package]] name = "native-tls" version = "0.2.11" @@ -1464,12 +1460,6 @@ dependencies = [ "pkg-config", ] -[[package]] -name = "opaque-debug" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" - [[package]] name = "openssl" version = "0.10.48" @@ -1624,35 +1614,12 @@ dependencies = [ "windows-sys 0.45.0", ] -[[package]] -name = "password-hash" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7676374caaee8a325c9e7a2ae557f216c5563a171d6997b0ef8a65af35147700" -dependencies = [ - "base64ct", - "rand_core", - "subtle", -] - [[package]] name = "paste" version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9f746c4065a8fa3fe23974dd82f15431cc8d40779821001404d10d2e79ca7d79" -[[package]] -name = "pbkdf2" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83a0692ec44e4cf1ef28ca317f14f8f07da2d95ec3fa01f86e4467b725e60917" -dependencies = [ - "digest", - "hmac", - "password-hash", - "sha2", -] - [[package]] name = "percent-encoding" version = "2.2.0" @@ -2165,17 +2132,6 @@ dependencies = [ "serde", ] -[[package]] -name = "sha1" -version = "0.10.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3" -dependencies = [ - "cfg-if", - "cpufeatures", - "digest", -] - [[package]] name = "sha2" version = "0.10.6" @@ -2202,7 +2158,7 @@ version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ccc8076840c4da029af4f87e4e8daeb0fca6b87bbb02e10cb60b791450e11e4" dependencies = [ - "dirs", + "dirs 4.0.0", ] [[package]] @@ -2245,6 +2201,15 @@ dependencies = [ "winapi", ] +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", +] + [[package]] name = "spm_precompiled" version = "0.1.4" @@ -2263,6 +2228,18 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +[[package]] +name = "strsim" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a" + +[[package]] +name = "strsim" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6446ced80d6c486436db5c078dde11a9f73d42b57fb273121e160b84f63d894c" + [[package]] name = "strsim" version = "0.10.0" @@ -2279,12 +2256,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "subtle" -version = "2.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601" - [[package]] name = "syn" version = "1.0.109" @@ -2358,7 +2329,7 @@ dependencies = [ name = "text-generation-launcher" version = "0.4.3" dependencies = [ - "clap", + "clap 4.2.1", "ctrlc", "float_eq", "reqwest", @@ -2376,14 +2347,14 @@ dependencies = [ "async-stream", "axum", "axum-tracing-opentelemetry", - "clap", + "clap 4.2.1", + "flume", "futures", "metrics", "metrics-exporter-prometheus", "nohash-hasher", "opentelemetry", "opentelemetry-otlp", - "parking_lot", "rand", "reqwest", "serde", @@ -2392,7 +2363,6 @@ dependencies = [ "thiserror", "tokenizers", "tokio", - "tokio-stream", "tower-http", "tracing", "tracing-opentelemetry", @@ -2401,6 +2371,15 @@ dependencies = [ "utoipa-swagger-ui", ] +[[package]] +name = "textwrap" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060" +dependencies = [ + "unicode-width", +] + [[package]] name = "thiserror" version = "1.0.40" @@ -2433,20 +2412,14 @@ dependencies = [ [[package]] name = "time" -version = "0.3.20" +version = "0.1.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd0cbfecb4d19b5ea75bb31ad904eb5b9fa13f21079c3b92017ebdf4999a5890" +checksum = "ca8a50ef2360fbd1eeb0ecd46795a87a19024eb4b53c5dc916ca1fd95fe62438" dependencies = [ - "serde", - "time-core", + "libc", + "winapi", ] -[[package]] -name = "time-core" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e153e1f1acaef8acc537e68b44906d2db6436e2b35ac2c6b42640fff91f00fd" - [[package]] name = "tinyvec" version = "1.6.0" @@ -2465,13 +2438,14 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokenizers" version = "0.13.2" -source = "git+https://github.com/huggingface/tokenizers.git#3aaf4946b3c82e7a04db6ecde7c8bb4e474e54af" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4ff2dd291eac98dcea13e8cf7a0b28c373a90dc9210ccdab0fa9e69ee0cac69" dependencies = [ "aho-corasick", "cached-path", - "clap", + "clap 2.34.0", "derive_builder", - "dirs", + "dirs 3.0.2", "esaxx-rs", "getrandom", "indicatif 0.15.0", @@ -2479,7 +2453,6 @@ dependencies = [ "lazy_static", "log", "macro_rules_attribute", - "monostate", "onig", "paste", "rand", @@ -2901,7 +2874,7 @@ dependencies = [ "serde", "serde_json", "utoipa", - "zip", + "zip 0.6.4", ] [[package]] @@ -2916,6 +2889,12 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" +[[package]] +name = "vec_map" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bddf1187be692e79c5ffeab891132dfb0f236ed36a43c7ed39f1165ee20191" + [[package]] name = "version_check" version = "0.9.4" @@ -3171,52 +3150,37 @@ dependencies = [ "libc", ] +[[package]] +name = "zip" +version = "0.5.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93ab48844d61251bb3835145c521d88aa4031d7139e8485990f60ca911fa0815" +dependencies = [ + "byteorder", + "bzip2", + "crc32fast", + "flate2", + "thiserror", + "time", +] + [[package]] name = "zip" version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0445d0fbc924bb93539b4316c11afb121ea39296f99a3c4c9edad09e3658cdef" dependencies = [ - "aes", "byteorder", - "bzip2", - "constant_time_eq", "crc32fast", "crossbeam-utils", "flate2", - "hmac", - "pbkdf2", - "sha1", - "time", - "zstd", ] [[package]] -name = "zstd" -version = "0.11.2+zstd.1.5.2" +name = "zip-extensions" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4" +checksum = "a64c3c977bc3434ce2d4bcea8ad3c644672de0f2c402b72b9171ca80a8885d14" dependencies = [ - "zstd-safe", -] - -[[package]] -name = "zstd-safe" -version = "5.0.2+zstd.1.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d2a5585e04f9eea4b2a3d1eca508c4dee9592a89ef6f450c11719da0726f4db" -dependencies = [ - "libc", - "zstd-sys", -] - -[[package]] -name = "zstd-sys" -version = "2.0.7+zstd.1.5.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94509c3ba2fe55294d752b79842c530ccfab760192521df74a081a78d2b3c7f5" -dependencies = [ - "cc", - "libc", - "pkg-config", + "zip 0.5.13", ] diff --git a/benchmark/Cargo.toml b/benchmark/Cargo.toml index f2a82935..d3badcd8 100644 --- a/benchmark/Cargo.toml +++ b/benchmark/Cargo.toml @@ -27,8 +27,7 @@ serde = {version = "1.0.142", features = ["derive"]} serde_json = "1.0" text-generation-client = { path = "../router/client" } thiserror = "1.0.38" -#tokenizers = "0.13.2" -tokenizers = { git = "https://github.com/huggingface/tokenizers.git" } +tokenizers = "0.13.2" tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tui = {package = "ratatui", version = "0.20", default-features = false, features = ["crossterm"]} tracing = "0.1.37" diff --git a/router/Cargo.toml b/router/Cargo.toml index 801c647b..c971353b 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -18,22 +18,20 @@ axum = { version = "0.6.4", features = ["json"] } axum-tracing-opentelemetry = "0.9.0" text-generation-client = { path = "client" } clap = { version = "4.1.4", features = ["derive", "env"] } +flume = "0.10.14" futures = "0.3.26" metrics = "0.20.1" metrics-exporter-prometheus = { version = "0.11.0", features = [] } nohash-hasher = "0.2.0" opentelemetry = { version = "0.18.0", features = ["rt-tokio"] } opentelemetry-otlp = "0.11.0" -parking_lot = "0.12.1" rand = "0.8.5" reqwest = { version = "0.11.14", features = [] } serde = "1.0.152" serde_json = "1.0.93" thiserror = "1.0.38" -#tokenizers = "0.13.2" -tokenizers = { git = "https://github.com/huggingface/tokenizers.git" } +tokenizers = "0.13.2" tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } -tokio-stream = "0.1.11" tower-http = { version = "0.3.5", features = ["cors"] } tracing = "0.1.37" tracing-opentelemetry = "0.18.0" diff --git a/router/src/infer.rs b/router/src/infer.rs index 5eafc3e9..61cc7069 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -2,17 +2,17 @@ use crate::validation::{Validation, ValidationError}; use crate::{Entry, Queue, Token}; use crate::{GenerateRequest, PrefillToken}; +use flume::r#async::RecvStream; use futures::future::try_join_all; +use futures::stream::StreamExt; use nohash_hasher::IntMap; use std::sync::Arc; use text_generation_client::{ Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient, }; use thiserror::Error; -use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError}; +use tokio::sync::{Notify, Semaphore, TryAcquireError}; use tokio::time::Instant; -use tokio_stream::wrappers::UnboundedReceiverStream; -use tokio_stream::StreamExt; use tracing::{info_span, instrument, Instrument, Span}; /// Inference struct @@ -73,7 +73,7 @@ impl Infer { pub(crate) async fn generate_stream( &self, request: GenerateRequest, - ) -> Result>, InferError> { + ) -> Result>, InferError> { // Limit concurrent requests by acquiring a permit from the semaphore // This permit will live as long as Entry let permit = self @@ -90,7 +90,7 @@ impl Infer { let valid_request = self.validation.validate(request).await?; // MPSC channel to communicate with the background batching task - let (response_tx, response_rx) = mpsc::unbounded_channel(); + let (response_tx, response_rx) = flume::unbounded(); // Append the request to the queue self.queue.append(Entry { @@ -108,7 +108,8 @@ impl Infer { self.shared.batching_task.notify_one(); // Return stream - Ok(UnboundedReceiverStream::new(response_rx)) + Ok(response_rx.into_stream()) + // Ok(UnboundedReceiverStream::new(response_rx)) } /// Add a new request to the queue and return a InferResponse diff --git a/router/src/queue.rs b/router/src/queue.rs index 77f8461b..11eb7f59 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -4,8 +4,7 @@ use crate::validation::ValidGenerateRequest; use nohash_hasher::{BuildNoHashHasher, IntMap}; use std::cmp::min; use text_generation_client::{Batch, Request}; -use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; -use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit}; +use tokio::sync::{oneshot, OwnedSemaphorePermit}; use tokio::time::Instant; use tracing::{info_span, instrument, Span}; @@ -15,7 +14,7 @@ pub(crate) struct Entry { /// Request pub request: ValidGenerateRequest, /// Response sender to communicate between the Infer struct and the batching_task - pub response_tx: UnboundedSender>, + pub response_tx: flume::Sender>, /// Span that will live as long as entry pub span: Span, /// Temporary span used as a guard when logging inference, wait times... @@ -32,13 +31,13 @@ pub(crate) struct Entry { #[derive(Debug, Clone)] pub(crate) struct Queue { /// Channel to communicate with the background queue task - queue_sender: UnboundedSender, + queue_sender: flume::Sender, } impl Queue { pub(crate) fn new() -> Self { // Create channel - let (queue_sender, queue_receiver) = mpsc::unbounded_channel(); + let (queue_sender, queue_receiver) = flume::unbounded(); // Launch background queue task tokio::spawn(queue_task(queue_receiver)); @@ -82,10 +81,10 @@ impl Queue { } // Background task responsible of the queue state -async fn queue_task(mut receiver: UnboundedReceiver) { +async fn queue_task(receiver: flume::Receiver) { let mut state = State::new(); - while let Some(cmd) = receiver.recv().await { + while let Ok(cmd) = receiver.recv_async().await { match cmd { QueueCommand::Append(entry, span) => span.in_scope(|| state.append(entry)), QueueCommand::NextBatch { @@ -216,12 +215,12 @@ mod tests { use super::*; use std::sync::Arc; use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters}; - use tokio::sync::{mpsc, Semaphore}; + use tokio::sync::Semaphore; use tracing::info_span; fn default_entry() -> Entry { let semaphore = Arc::new(Semaphore::new(1)); - let (response_tx, _) = mpsc::unbounded_channel(); + let (response_tx, _) = flume::unbounded(); let permit = semaphore.try_acquire_owned().unwrap(); Entry { diff --git a/router/src/server.rs b/router/src/server.rs index 88c40565..55aca4b3 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -13,6 +13,7 @@ use axum::response::{IntoResponse, Response}; use axum::routing::{get, post}; use axum::{http, Json, Router}; use axum_tracing_opentelemetry::opentelemetry_tracing_layer; +use futures::stream::StreamExt; use futures::Stream; use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle}; use std::convert::Infallible; @@ -21,7 +22,6 @@ use text_generation_client::ShardedClient; use tokenizers::Tokenizer; use tokio::signal; use tokio::time::Instant; -use tokio_stream::StreamExt; use tower_http::cors::{AllowOrigin, CorsLayer}; use tracing::{info_span, instrument, Instrument}; use utoipa::OpenApi; diff --git a/router/src/validation.rs b/router/src/validation.rs index a0b8b98e..b781210c 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -1,23 +1,24 @@ use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput}; /// Payload validation logic use crate::{GenerateParameters, GenerateRequest}; -use rand::rngs::ThreadRng; -use rand::Rng; +use rand::{thread_rng, Rng}; use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters}; use thiserror::Error; use tokenizers::tokenizer::Tokenizer; use tokenizers::TruncationDirection; -use tokio::sync::{mpsc, oneshot}; +use tokio::sync::oneshot; use tracing::{instrument, Span}; /// Validation #[derive(Debug, Clone)] pub struct Validation { - /// maximum value for the best_of parameter - #[allow(dead_code)] + /// Validation parameters max_best_of: usize, - /// Channel to communicate with the background validation task - sender: mpsc::UnboundedSender, + max_stop_sequences: usize, + max_input_length: usize, + max_total_tokens: usize, + /// Channel to communicate with the background tokenization task + sender: Option>, } impl Validation { @@ -29,22 +30,81 @@ impl Validation { max_input_length: usize, max_total_tokens: usize, ) -> Self { - // Create channel - let (validation_sender, validation_receiver) = mpsc::unbounded_channel(); + // If we have a fast tokenizer + let sender = if let Some(tokenizer) = tokenizer { + // Create channel + let (validation_sender, validation_receiver) = flume::unbounded(); - // Launch background validation task - tokio::spawn(validation_task( - workers, - tokenizer, - max_stop_sequences, - max_input_length, - max_total_tokens, - validation_receiver, - )); + // Create workers + for _ in 0..workers { + let tokenizer_clone = tokenizer.clone(); + let receiver_clone = validation_receiver.clone(); + + // Spawn worker + tokio::task::spawn_blocking(move || { + tokenizer_worker(tokenizer_clone.into(), receiver_clone) + }); + } + Some(validation_sender) + } else { + None + }; Self { max_best_of, - sender: validation_sender, + sender, + max_stop_sequences, + max_input_length, + max_total_tokens, + } + } + + async fn validate_input( + &self, + inputs: String, + truncate: Option, + max_new_tokens: u32, + ) -> Result { + // If we have a fast tokenizer + if let Some(sender) = &self.sender { + // Create response channel + let (response_sender, response_receiver) = oneshot::channel(); + // Send request to the background validation task + // Unwrap is safe here + sender + .send(((inputs, truncate), response_sender, Span::current())) + .unwrap(); + + // Await on response channel + // Unwrap is safe here + let (inputs, input_length) = response_receiver.await.unwrap()?; + + // Get total tokens + let total_tokens = input_length + max_new_tokens as usize; + + // Validate MaxTotalTokens + if total_tokens > self.max_total_tokens { + return Err(ValidationError::MaxTotalTokens( + self.max_total_tokens, + input_length, + max_new_tokens, + )); + } + + // Validate InputLength + if input_length > self.max_input_length { + return Err(ValidationError::InputLength( + self.max_input_length, + input_length, + )); + } + + metrics::histogram!("tgi_request_input_length", input_length as f64); + Ok(inputs) + } + // Return inputs without validation + else { + Ok(inputs) } } @@ -54,16 +114,139 @@ impl Validation { &self, request: GenerateRequest, ) -> Result { - // Create response channel - let (sender, receiver) = oneshot::channel(); - // Send request to the background validation task - // Unwrap is safe here - self.sender - .send((request, sender, Span::current())) - .unwrap(); - // Await on response channel - // Unwrap is safe here - receiver.await.unwrap() + let GenerateParameters { + best_of, + temperature, + repetition_penalty, + top_k, + top_p, + typical_p, + do_sample, + max_new_tokens, + stop: stop_sequences, + truncate, + seed, + watermark, + .. + } = request.parameters; + + // sampling must be true when best_of > 1 + let best_of = best_of.unwrap_or(1); + let sampling = do_sample + || temperature.is_some() + || top_k.is_some() + || top_p.is_some() + || typical_p.is_some(); + + if best_of > 1 && !sampling { + return Err(BestOfSampling); + } + + let temperature = temperature.unwrap_or(1.0); + if temperature <= 0.0 { + return Err(ValidationError::Temperature); + } + + let repetition_penalty = repetition_penalty.unwrap_or(1.0); + if repetition_penalty <= 0.0 { + return Err(ValidationError::RepetitionPenalty); + } + + // Different because the proto default value is not a valid value + // for the user + let top_p = top_p + .map(|value| { + if value <= 0.0 || value >= 1.0 { + return Err(ValidationError::TopP); + } + Ok(value) + }) + .unwrap_or(Ok(1.0))?; + + let typical_p = typical_p + .map(|value| { + if value <= 0.0 || value >= 1.0 { + return Err(ValidationError::TypicalP); + } + Ok(value) + }) + .unwrap_or(Ok(1.0))?; + + let top_k: u32 = top_k + .map(|value| { + if value <= 0 { + return Err(ValidationError::TopK); + } + Ok(value as u32) + }) + .unwrap_or(Ok(0))?; + + if max_new_tokens == 0 { + return Err(ValidationError::MaxNewTokens); + } + + if stop_sequences.len() > self.max_stop_sequences { + return Err(ValidationError::StopSequence( + self.max_stop_sequences, + stop_sequences.len(), + )); + } + + // If seed is None, assign a random one + let seed = match seed { + None => thread_rng().gen(), + Some(seed) => { + if best_of > 1 { + return Err(BestOfSeed); + } + seed + } + }; + + // Check if inputs is empty + if request.inputs.is_empty() { + return Err(EmptyInput); + } + + // Check if truncate is strictly positive and less than max_input_length + let truncate = truncate + .map(|value| { + if value == 0 || value > self.max_input_length { + return Err(ValidationError::Truncate(self.max_input_length, value)); + } + Ok(Some(value)) + }) + .unwrap_or(Ok(None))?; + + // Validate inputs + let inputs = self + .validate_input(request.inputs, truncate, max_new_tokens) + .await?; + + let parameters = NextTokenChooserParameters { + temperature, + repetition_penalty, + top_k, + top_p, + typical_p, + do_sample, + seed, + watermark, + }; + let stopping_parameters = StoppingCriteriaParameters { + max_new_tokens, + stop_sequences, + ignore_eos_token: false, + }; + + metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64); + + Ok(ValidGenerateRequest { + inputs, + truncate: truncate.unwrap_or(self.max_input_length) as u32, + parameters, + stopping_parameters, + }) } /// Validate the best_of parameter @@ -81,264 +264,54 @@ impl Validation { } } -/// Validation task -/// Load balance the validation requests between multiple validation workers -async fn validation_task( - workers: usize, - tokenizer: Option, - max_stop_sequences: usize, - max_input_length: usize, - max_total_tokens: usize, - mut receiver: mpsc::UnboundedReceiver, -) { - let mut workers_senders = Vec::with_capacity(workers); - - // Create workers - for _ in 0..workers { - let tokenizer_clone: Option = tokenizer.clone().into(); - // Create channel to communicate with worker - let (worker_sender, worker_receiver) = mpsc::channel(workers); - workers_senders.push(worker_sender); - - // Spawn worker - tokio::task::spawn_blocking(move || { - validation_worker( - tokenizer_clone, - max_stop_sequences, - max_input_length, - max_total_tokens, - worker_receiver, - ) - }); - } - - loop { - // Load balance requests between workers - for sender in workers_senders.iter() { - if let Some(validation_request) = receiver.recv().await { - sender.send(validation_request).await.unwrap(); - } else { - return; - } - } - } -} - -/// Check the parameters inside the payload and get the number of tokens inside the input using -/// the tokenizer -fn validation_worker( - tokenizer: Option, - max_stop_sequences: usize, - max_input_length: usize, - max_total_tokens: usize, - mut receiver: mpsc::Receiver, -) { - // Seed rng - let mut rng = rand::thread_rng(); - +/// Start tokenization workers +fn tokenizer_worker(tokenizer: Tokenizer, receiver: flume::Receiver) { // Loop over requests - while let Some((request, response_tx, parent_span)) = receiver.blocking_recv() { + while let Ok(((inputs, truncate), response_tx, parent_span)) = receiver.recv() { parent_span.in_scope(|| { response_tx - .send( - validate( - request, - tokenizer.as_ref(), - max_stop_sequences, - max_input_length, - max_total_tokens, - &mut rng, - ) - .map_err(|err| { - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); - tracing::error!("{err}"); - err - }), - ) + .send(prepare_input(inputs, truncate, &tokenizer).map_err(|err| { + metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + tracing::error!("{err}"); + err + })) .unwrap_or(()) }) } } -fn validate( - request: GenerateRequest, - tokenizer: Option<&Tokenizer>, - max_stop_sequences: usize, - max_input_length: usize, - max_total_tokens: usize, - rng: &mut ThreadRng, -) -> Result { - let GenerateParameters { - best_of, - temperature, - repetition_penalty, - top_k, - top_p, - typical_p, - do_sample, - max_new_tokens, - stop: stop_sequences, - truncate, - seed, - watermark, - .. - } = request.parameters; +/// Get input length and optionally truncate it +fn prepare_input( + inputs: String, + truncate: Option, + tokenizer: &Tokenizer, +) -> Result<(String, usize), ValidationError> { + // Get the number of tokens in the input + let mut encoding = tokenizer + .encode(inputs.clone(), true) + .map_err(|err| ValidationError::Tokenizer(err.to_string()))?; - // sampling must be true when best_of > 1 - let best_of = best_of.unwrap_or(1); - let sampling = do_sample - || temperature.is_some() - || top_k.is_some() - || top_p.is_some() - || typical_p.is_some(); - - if best_of > 1 && !sampling { - return Err(BestOfSampling); - } - - let temperature = temperature.unwrap_or(1.0); - if temperature <= 0.0 { - return Err(ValidationError::Temperature); - } - - let repetition_penalty = repetition_penalty.unwrap_or(1.0); - if repetition_penalty <= 0.0 { - return Err(ValidationError::RepetitionPenalty); - } - - // Different because the proto default value is not a valid value - // for the user - let top_p = top_p - .map(|value| { - if value <= 0.0 || value >= 1.0 { - return Err(ValidationError::TopP); - } - Ok(value) - }) - .unwrap_or(Ok(1.0))?; - - let typical_p = typical_p - .map(|value| { - if value <= 0.0 || value >= 1.0 { - return Err(ValidationError::TypicalP); - } - Ok(value) - }) - .unwrap_or(Ok(1.0))?; - - let top_k: u32 = top_k - .map(|value| { - if value <= 0 { - return Err(ValidationError::TopK); - } - Ok(value as u32) - }) - .unwrap_or(Ok(0))?; - - if max_new_tokens == 0 { - return Err(ValidationError::MaxNewTokens); - } - - if stop_sequences.len() > max_stop_sequences { - return Err(ValidationError::StopSequence( - max_stop_sequences, - stop_sequences.len(), - )); - } - - // If seed is None, assign a random one - let seed = match seed { - None => rng.gen(), - Some(seed) => { - if best_of > 1 { - return Err(BestOfSeed); - } - seed - } - }; - - // Check if inputs is empty - if request.inputs.is_empty() { - return Err(EmptyInput); - } - - // Check if truncate is strictly positive and less than max_input_length - let truncate = truncate - .map(|value| { - if value == 0 || value > max_input_length { - return Err(ValidationError::Truncate(max_input_length, value)); - } - Ok(Some(value)) - }) - .unwrap_or(Ok(None))?; - - // If we have a fast tokenizer - let inputs = if let Some(tokenizer) = tokenizer { - // Get the number of tokens in the input - let mut encoding = tokenizer - .encode(request.inputs.clone(), true) - .map_err(|err| ValidationError::Tokenizer(err.to_string()))?; - - let (inputs, input_length) = if let Some(truncate) = truncate { + // Optionally truncate + let (inputs, input_length) = match truncate { + // Truncate is some and > encoding length + Some(truncate) if truncate > encoding.len() => { // truncate encoding and decode new inputs encoding.truncate(truncate, 0, TruncationDirection::Left); let inputs = tokenizer .decode(Vec::from(encoding.get_ids()), false) .map_err(|err| ValidationError::Tokenizer(err.to_string()))?; (inputs, encoding.len()) - } else { - (request.inputs, encoding.len()) - }; - - if input_length > max_input_length { - return Err(ValidationError::InputLength(max_input_length, input_length)); } - - let total_tokens = input_length + max_new_tokens as usize; - if total_tokens > max_total_tokens { - return Err(ValidationError::MaxTotalTokens( - max_total_tokens, - input_length, - max_new_tokens, - )); - } - - metrics::histogram!("tgi_request_input_length", input_length as f64); - inputs - } else { - request.inputs + // Nothing to do + _ => (inputs, encoding.len()), }; - // Return ValidGenerateRequest - let parameters = NextTokenChooserParameters { - temperature, - repetition_penalty, - top_k, - top_p, - typical_p, - do_sample, - seed, - watermark, - }; - let stopping_parameters = StoppingCriteriaParameters { - max_new_tokens, - stop_sequences, - ignore_eos_token: false, - }; - - metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64); - - Ok(ValidGenerateRequest { - inputs, - truncate: truncate.unwrap_or(max_input_length) as u32, - parameters, - stopping_parameters, - }) + Ok((inputs, input_length)) } -type ValidationRequest = ( - GenerateRequest, - oneshot::Sender>, +type TokenizerRequest = ( + (String, Option), + oneshot::Sender>, Span, ); diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index 2b8ef5f8..a9831cd7 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -24,6 +24,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( id=0, inputs="Test", + truncate=100, parameters=default_pb_parameters, stopping_parameters=default_pb_stop_parameters, ) diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index 76617b62..db68fc9c 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -25,6 +25,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( id=0, inputs="Test", + truncate=100, parameters=default_pb_parameters, stopping_parameters=default_pb_stop_parameters, ) diff --git a/server/tests/models/test_santacoder.py b/server/tests/models/test_santacoder.py index 753ff5fc..8cf66d47 100644 --- a/server/tests/models/test_santacoder.py +++ b/server/tests/models/test_santacoder.py @@ -15,6 +15,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( id=0, inputs="def", + truncate=100, parameters=default_pb_parameters, stopping_parameters=default_pb_stop_parameters, ) @@ -30,6 +31,7 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( id=0, inputs="defworld", + truncate=100, parameters=default_pb_parameters, stopping_parameters=default_pb_stop_parameters, ) diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index 2d86c44b..baf44579 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -28,6 +28,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( id=0, inputs="Test", + truncate=100, parameters=default_pb_parameters, stopping_parameters=default_pb_stop_parameters, )