diff --git a/.devcontainer/Dockerfile.trtllm b/.devcontainer/Dockerfile.trtllm new file mode 100644 index 00000000..e69de29b diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 00000000..e69de29b diff --git a/.dockerignore b/.dockerignore index c69283ec..1c641e7a 100644 --- a/.dockerignore +++ b/.dockerignore @@ -2,3 +2,5 @@ aml target server/transformers server/flash-attention +cmake-build-debug/ +cmake-build-release/ diff --git a/.github/workflows/autodocs.yaml b/.github/workflows/autodocs.yaml index e10b232c..a768f263 100644 --- a/.github/workflows/autodocs.yaml +++ b/.github/workflows/autodocs.yaml @@ -28,7 +28,7 @@ jobs: - name: Install router id: install-router - run: cargo install --path router/ + run: cargo install --path backends/v3/ - uses: actions/setup-node@v4 with: @@ -41,5 +41,5 @@ jobs: - name: Check that documentation is up-to-date run: | - npm install -g swagger-cli + npm install -g @redocly/cli python update_doc.py --check diff --git a/.gitignore b/.gitignore index e9ad1808..0de8b848 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,10 @@ target router/tokenizer.json *__pycache__* +backends/v3/src/client/pb +backends/client/src/v2/pb +backends/client/src/v3/pb + # ROCm auto-generated files *.hip server/exllamav2_kernels/exllamav2_kernels/hip/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bb3879b2..6f5e685e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,8 +13,8 @@ repos: - repo: https://github.com/doublify/pre-commit-rust rev: v1.0 hooks: - - id: fmt - id: cargo-check + - id: fmt - id: clippy - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.3.0 diff --git a/.redocly.lint-ignore.yaml b/.redocly.lint-ignore.yaml new file mode 100644 index 00000000..382c9ab6 --- /dev/null +++ b/.redocly.lint-ignore.yaml @@ -0,0 +1,79 @@ +# This file instructs Redocly's linter to ignore the rules contained for specific parts of your API. 
+# See https://redoc.ly/docs/cli/ for more information. +docs/openapi.json: + no-empty-servers: + - '#/openapi' + spec: + - >- + #/components/schemas/GenerateParameters/properties/best_of/exclusiveMinimum + - >- + #/components/schemas/GenerateParameters/properties/frequency_penalty/exclusiveMinimum + - '#/components/schemas/GenerateParameters/properties/grammar/nullable' + - >- + #/components/schemas/GenerateParameters/properties/repetition_penalty/exclusiveMinimum + - '#/components/schemas/GenerateParameters/properties/seed/exclusiveMinimum' + - >- + #/components/schemas/GenerateParameters/properties/temperature/exclusiveMinimum + - '#/components/schemas/GenerateParameters/properties/top_k/exclusiveMinimum' + - >- + #/components/schemas/GenerateParameters/properties/top_n_tokens/exclusiveMinimum + - '#/components/schemas/GenerateParameters/properties/top_p/exclusiveMinimum' + - >- + #/components/schemas/GenerateParameters/properties/typical_p/exclusiveMinimum + - '#/components/schemas/GenerateResponse/properties/details/nullable' + - '#/components/schemas/StreamResponse/properties/details/nullable' + - '#/components/schemas/ChatRequest/properties/response_format/nullable' + - '#/components/schemas/ChatRequest/properties/tool_choice/nullable' + - '#/components/schemas/ToolChoice/nullable' + - '#/components/schemas/ChatCompletionComplete/properties/logprobs/nullable' + - '#/components/schemas/ChatCompletionChoice/properties/logprobs/nullable' + no-invalid-media-type-examples: + - '#/paths/~1/post/responses/422/content/application~1json/example' + - '#/paths/~1/post/responses/424/content/application~1json/example' + - '#/paths/~1/post/responses/429/content/application~1json/example' + - '#/paths/~1/post/responses/500/content/application~1json/example' + - '#/paths/~1generate/post/responses/422/content/application~1json/example' + - '#/paths/~1generate/post/responses/424/content/application~1json/example' + - 
'#/paths/~1generate/post/responses/429/content/application~1json/example' + - '#/paths/~1generate/post/responses/500/content/application~1json/example' + - >- + #/paths/~1generate_stream/post/responses/422/content/text~1event-stream/example + - >- + #/paths/~1generate_stream/post/responses/424/content/text~1event-stream/example + - >- + #/paths/~1generate_stream/post/responses/429/content/text~1event-stream/example + - >- + #/paths/~1generate_stream/post/responses/500/content/text~1event-stream/example + - '#/paths/~1tokenize/post/responses/404/content/application~1json/example' + - >- + #/paths/~1v1~1chat~1completions/post/responses/422/content/application~1json/example + - >- + #/paths/~1v1~1chat~1completions/post/responses/424/content/application~1json/example + - >- + #/paths/~1v1~1chat~1completions/post/responses/429/content/application~1json/example + - >- + #/paths/~1v1~1chat~1completions/post/responses/500/content/application~1json/example + - >- + #/paths/~1v1~1completions/post/responses/422/content/application~1json/example + - >- + #/paths/~1v1~1completions/post/responses/424/content/application~1json/example + - >- + #/paths/~1v1~1completions/post/responses/429/content/application~1json/example + - >- + #/paths/~1v1~1completions/post/responses/500/content/application~1json/example + operation-4xx-response: + - '#/paths/~1health/get/responses' + - '#/paths/~1info/get/responses' + - '#/paths/~1metrics/get/responses' + no-unused-components: + - '#/components/schemas/Completion' + security-defined: + - '#/paths/~1/post' + - '#/paths/~1generate/post' + - '#/paths/~1generate_stream/post' + - '#/paths/~1health/get' + - '#/paths/~1info/get' + - '#/paths/~1metrics/get' + - '#/paths/~1tokenize/post' + - '#/paths/~1v1~1chat~1completions/post' + - '#/paths/~1v1~1completions/post' diff --git a/Cargo.lock b/Cargo.lock index 3522fc3f..92367d1e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -28,7 +28,7 @@ dependencies = [ "once_cell", "serde", "version_check", - 
"zerocopy", + "zerocopy 0.7.35", ] [[package]] @@ -48,9 +48,9 @@ checksum = "4aa90d7ce82d4be67b64039a3d588d38dbcc6736577de4a847025ce5b0c468d1" [[package]] name = "anstream" -version = "0.6.14" +version = "0.6.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "418c75fa768af9c03be99d17643f93f79bbba589895012a80e3452a19ddda15b" +checksum = "64e15c1ab1f89faffbf04a634d5e1962e9074f2741eef6d97f3c4e322426d526" dependencies = [ "anstyle", "anstyle-parse", @@ -63,33 +63,33 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.7" +version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "038dfcf04a5feb68e9c60b21c9625a54c2c0616e79b72b0fd87075a056ae1d1b" +checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" [[package]] name = "anstyle-parse" -version = "0.2.4" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c03a11a9034d92058ceb6ee011ce58af4a9bf61491aa7e1e59ecd24bd40d22d4" +checksum = "eb47de1e80c2b463c735db5b217a0ddc39d612e7ac9e2e96a5aed1f57616c1cb" dependencies = [ "utf8parse", ] [[package]] name = "anstyle-query" -version = "1.1.0" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad186efb764318d35165f1758e7dcef3b10628e26d41a44bc5550652e6804391" +checksum = "6d36fc52c7f6c869915e99412912f22093507da8d9e942ceaf66fe4b7c14422a" dependencies = [ "windows-sys 0.52.0", ] [[package]] name = "anstyle-wincon" -version = "3.0.3" +version = "3.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61a38449feb7068f52bb06c12759005cf459ee52bb4adc1d5a7c4322d716fb19" +checksum = "5bf74e1b6e971609db8ca7a9ce79fd5768ab6ae46441c572e46cf596f59e57f8" dependencies = [ "anstyle", "windows-sys 0.52.0", @@ -121,7 +121,7 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] 
[[package]] @@ -160,18 +160,18 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] name = "async-trait" -version = "0.1.80" +version = "0.1.81" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" +checksum = "6e0c28dcc82d7c8ead5cb13beb15405b57b8546e93215673ff8ca0349a028107" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -234,9 +234,9 @@ dependencies = [ [[package]] name = "aws-lc-rs" -version = "1.7.3" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf7d844e282b4b56750b2d4e893b2205581ded8709fddd2b6aa5418c150ca877" +checksum = "4ae74d9bd0a7530e8afd1770739ad34b36838829d6ad61818f9230f683f5ad77" dependencies = [ "aws-lc-sys", "mirai-annotations", @@ -246,9 +246,9 @@ dependencies = [ [[package]] name = "aws-lc-sys" -version = "0.18.0" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3a2c29203f6bf296d01141cc8bb9dbd5ecd4c27843f2ee0767bcd5985a927da" +checksum = "2e89b6941c2d1a7045538884d6e760ccfffdf8e1ffc2613d8efa74305e1f3752" dependencies = [ "bindgen", "cc", @@ -272,7 +272,7 @@ dependencies = [ "futures-util", "http 0.2.12", "http-body 0.4.6", - "hyper 0.14.29", + "hyper 0.14.30", "itoa", "matchit", "memchr", @@ -302,9 +302,9 @@ dependencies = [ "bytes", "futures-util", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "http-body-util", - "hyper 1.3.1", + "hyper 1.4.1", "hyper-util", "itoa", "matchit", @@ -352,7 +352,7 @@ dependencies = [ "bytes", "futures-util", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "http-body-util", "mime", "pin-project-lite", @@ -433,7 +433,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.68", + "syn 2.0.72", "which", ] @@ -472,9 +472,9 @@ checksum = 
"b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" [[package]] name = "bitstream-io" -version = "2.4.2" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "415f8399438eb5e4b2f73ed3152a3448b98149dda642a957ee704e1daa5cf1d8" +checksum = "3dcde5f311c85b8ca30c2e4198d4326bc342c76541590106f5fa4a50946ea499" [[package]] name = "block-buffer" @@ -487,9 +487,9 @@ dependencies = [ [[package]] name = "built" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6a6c0b39c38fd754ac338b00a88066436389c0f029da5d37d1e01091d9b7c17" +checksum = "236e6289eda5a812bc6b53c3b024039382a2895fbbeef2d748b2931546d392c4" [[package]] name = "bumpalo" @@ -523,9 +523,9 @@ checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" [[package]] name = "bytes" -version = "1.6.0" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" +checksum = "a12916984aab3fa6e39d655a33e09c0071eb36d6ab3aea5c2d78551f1df6d952" [[package]] name = "camino" @@ -567,13 +567,12 @@ checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53" [[package]] name = "cc" -version = "1.0.101" +version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac367972e516d45567c7eafc73d24e1c193dcf200a8d94e9db7b3d38b349572d" +checksum = "26a5c3fd7bfa1ce3897a3a3501d362b2d87b7f2583ebcb4a949ec25911025cbc" dependencies = [ "jobserver", "libc", - "once_cell", ] [[package]] @@ -620,9 +619,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.7" +version = "4.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5db83dced34638ad474f39f250d7fea9598bdd239eaced1bdf45d597da0f433f" +checksum = "35723e6a11662c2afb578bcf0b88bf6ea8e21282a953428f240574fcc3a2b5b3" dependencies = [ "clap_builder", "clap_derive", @@ 
-630,9 +629,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.7" +version = "4.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7e204572485eb3fbf28f871612191521df159bc3e15a9f5064c66dba3a8c05f" +checksum = "49eb96cbfa7cfa35017b7cd548c75b14c3118c98b423041d70562665e07fb0fa" dependencies = [ "anstream", "anstyle", @@ -642,21 +641,21 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.5" +version = "4.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c780290ccf4fb26629baa7a1081e68ced113f1d3ec302fa5948f1c381ebf06c6" +checksum = "5d029b67f89d30bbb547c89fd5161293c0aec155fc691d7924b64550662db93e" dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] name = "clap_lex" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b82cf0babdbd58558212896d1a4272303a57bdb245c2bf1147185fb45640e70" +checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" [[package]] name = "cmake" @@ -667,6 +666,16 @@ dependencies = [ "cc", ] +[[package]] +name = "codespan-reporting" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3538270d33cc669650c4b093848450d380def10c331d38c768e34cac80576e6e" +dependencies = [ + "termcolor", + "unicode-width", +] + [[package]] name = "color_quant" version = "1.1.0" @@ -675,9 +684,9 @@ checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" [[package]] name = "colorchoice" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422" +checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" [[package]] name = "console" @@ -769,7 +778,7 @@ dependencies = [ "bitflags 2.6.0", "crossterm_winapi", "libc", - "mio", + "mio 
0.8.11", "parking_lot", "signal-hook", "signal-hook-mio", @@ -833,10 +842,54 @@ dependencies = [ ] [[package]] -name = "darling" -version = "0.20.9" +name = "cxx" +version = "1.0.124" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83b2eb4d90d12bdda5ed17de686c2acb4c57914f8f921b8da7e112b5a36f3fe1" +checksum = "273dcfd3acd4e1e276af13ed2a43eea7001318823e7a726a6b3ed39b4acc0b82" +dependencies = [ + "cc", + "cxxbridge-flags", + "cxxbridge-macro", + "link-cplusplus", +] + +[[package]] +name = "cxx-build" +version = "1.0.124" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8b2766fbd92be34e9ed143898fce6c572dc009de39506ed6903e5a05b68914e" +dependencies = [ + "cc", + "codespan-reporting", + "once_cell", + "proc-macro2", + "quote", + "scratch", + "syn 2.0.72", +] + +[[package]] +name = "cxxbridge-flags" +version = "1.0.124" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "839fcd5e43464614ffaa989eaf1c139ef1f0c51672a1ed08023307fa1b909ccd" + +[[package]] +name = "cxxbridge-macro" +version = "1.0.124" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b2c1c1776b986979be68bb2285da855f8d8a35851a769fca8740df7c3d07877" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.72", +] + +[[package]] +name = "darling" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989" dependencies = [ "darling_core", "darling_macro", @@ -844,27 +897,27 @@ dependencies = [ [[package]] name = "darling_core" -version = "0.20.9" +version = "0.20.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "622687fe0bac72a04e5599029151f5796111b90f1baaa9b544d807a5e31cd120" +checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5" dependencies = [ "fnv", "ident_case", "proc-macro2", "quote", "strsim", - "syn 2.0.68", + "syn 2.0.72", ] 
[[package]] name = "darling_macro" -version = "0.20.9" +version = "0.20.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "733cabb43482b1a1b53eee8583c2b9e8684d592215ea83efd305dd31bc2f0178" +checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" dependencies = [ "darling_core", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -894,7 +947,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -904,7 +957,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "206868b8242f27cecce124c19fd88157fbd0dd334df2587f36417bafbc85097b" dependencies = [ "derive_builder_core", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -1179,7 +1232,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -1433,9 +1486,9 @@ dependencies = [ [[package]] name = "http-body" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", "http 1.1.0", @@ -1450,7 +1503,7 @@ dependencies = [ "bytes", "futures-util", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "pin-project-lite", ] @@ -1468,9 +1521,9 @@ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" [[package]] name = "hyper" -version = "0.14.29" +version = "0.14.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f361cde2f109281a220d4307746cdfd5ee3f410da58a70377762396775634b33" +checksum = "a152ddd61dfaec7273fe8419ab357f33aee0d914c5f4efbf0d96fa749eea5ec9" dependencies = [ "bytes", "futures-channel", @@ -1492,16 +1545,16 @@ dependencies = [ [[package]] name = "hyper" -version = "1.3.1" +version = "1.4.1" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe575dd17d0862a9a33781c8c4696a55c320909004a67a00fb286ba8b1bc496d" +checksum = "50dfd22e0e76d0f662d429a5f80fcaf3855009297eab6a0a9f8543834744ba05" dependencies = [ "bytes", "futures-channel", "futures-util", "h2 0.4.5", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "httparse", "httpdate", "itoa", @@ -1519,10 +1572,10 @@ checksum = "5ee4be2c948921a1a5320b629c4193916ed787a7f7f293fd3f7f5a6c9de74155" dependencies = [ "futures-util", "http 1.1.0", - "hyper 1.3.1", + "hyper 1.4.1", "hyper-util", "log", - "rustls 0.23.10", + "rustls 0.23.12", "rustls-native-certs", "rustls-pki-types", "tokio", @@ -1536,7 +1589,7 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1" dependencies = [ - "hyper 0.14.29", + "hyper 0.14.30", "pin-project-lite", "tokio", "tokio-io-timeout", @@ -1549,7 +1602,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" dependencies = [ "bytes", - "hyper 0.14.29", + "hyper 0.14.30", "native-tls", "tokio", "tokio-native-tls", @@ -1557,16 +1610,16 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b875924a60b96e5d7b9ae7b066540b1dd1cbd90d1828f54c92e02a283351c56" +checksum = "3ab92f4f49ee4fb4f997c784b7a2e0fa70050211e0b6a287f898c3c9785ca956" dependencies = [ "bytes", "futures-channel", "futures-util", "http 1.1.0", - "http-body 1.0.0", - "hyper 1.3.1", + "http-body 1.0.1", + "hyper 1.4.1", "pin-project-lite", "socket2", "tokio", @@ -1593,12 +1646,12 @@ dependencies = [ [[package]] name = "image" -version = "0.25.1" +version = "0.25.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"fd54d660e773627692c524beaad361aca785a4f9f5730ce91f42aabe5bce3d11" +checksum = "99314c8a2152b8ddb211f924cdae532d8c5e4c8bb54728e12fff1b0cd5963a10" dependencies = [ "bytemuck", - "byteorder", + "byteorder-lite", "color_quant", "exr", "gif", @@ -1616,12 +1669,12 @@ dependencies = [ [[package]] name = "image-webp" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d730b085583c4d789dfd07fdcf185be59501666a90c97c40162b37e4fdad272d" +checksum = "f79afb8cbee2ef20f59ccd477a218c12a93943d075b492015ecb1bb81f8ee904" dependencies = [ "byteorder-lite", - "thiserror", + "quick-error", ] [[package]] @@ -1700,7 +1753,7 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -1711,9 +1764,9 @@ checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" [[package]] name = "is_terminal_polyfill" -version = "1.70.0" +version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8478577c03552c21db0e2724ffb8986a5ce7af88107e6be5d2ee6e158c12800" +checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" [[package]] name = "iso8601" @@ -1759,9 +1812,9 @@ checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" [[package]] name = "jobserver" -version = "0.1.31" +version = "0.1.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2b099aaa34a9751c5bf0878add70444e1ed2dd73f347be99003d4577277de6e" +checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" dependencies = [ "libc", ] @@ -1848,12 +1901,12 @@ dependencies = [ [[package]] name = "libloading" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e310b3a6b5907f99202fcdb4960ff45b93735d7c7d96b760fcff8db2dc0e103d" +checksum = 
"4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4" dependencies = [ "cfg-if", - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -1872,6 +1925,15 @@ dependencies = [ "libc", ] +[[package]] +name = "link-cplusplus" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d240c6f7e1ba3a28b0249f774e6a9dd0175054b52dfbb61b16eb8505c3785c9" +dependencies = [ + "cc", +] + [[package]] name = "linux-raw-sys" version = "0.4.14" @@ -1890,9 +1952,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.21" +version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" +checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" [[package]] name = "loop9" @@ -1947,7 +2009,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ea1f30cedd69f0a2954655f7188c6a834246d2bcf1e315e2ac40c4b24dc9519" dependencies = [ "cfg-if", - "rayon", ] [[package]] @@ -1968,13 +2029,13 @@ dependencies = [ [[package]] name = "metrics-exporter-prometheus" -version = "0.15.1" +version = "0.15.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf0af7a0d7ced10c0151f870e5e3f3f8bc9ffc5992d32873566ca1f9169ae776" +checksum = "b4f0c8427b39666bf970460908b213ec09b3b350f20c0c2eabcbba51704a08e6" dependencies = [ "base64 0.22.1", "http-body-util", - "hyper 1.3.1", + "hyper 1.4.1", "hyper-rustls", "hyper-util", "indexmap 2.2.6", @@ -2010,9 +2071,9 @@ checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" [[package]] name = "mime_guess" -version = "2.0.4" +version = "2.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" dependencies = [ "mime", "unicase", @@ -2020,18 
+2081,18 @@ dependencies = [ [[package]] name = "minijinja" -version = "2.0.2" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e136ef580d7955019ab0a407b68d77c292a9976907e217900f3f76bc8f6dc1a4" +checksum = "45f7e8e35b6c7b169bf40b0176d2c79291ab8ee53290b84e0668ab21d841aa9d" dependencies = [ "serde", ] [[package]] name = "minijinja-contrib" -version = "2.0.2" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15ee37078c98d31e510d6a7af488031a2c3ccacdb76c5c4fc98ddfe6d0e9da07" +checksum = "6853ef2340c668281c5ea86b04da2ebb2fc9e98a7185a887591de4cac945d5b5" dependencies = [ "minijinja", "serde", @@ -2065,6 +2126,18 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "mio" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4569e456d394deccd22ce1c1913e6ea0e54519f577285001215d33557431afe4" +dependencies = [ + "hermit-abi", + "libc", + "wasi", + "windows-sys 0.52.0", +] + [[package]] name = "mirai-annotations" version = "1.12.0" @@ -2089,7 +2162,7 @@ checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -2155,7 +2228,7 @@ dependencies = [ "bytes", "futures", "hostname", - "hyper 0.14.29", + "hyper 0.14.30", "muxado", "once_cell", "parking_lot", @@ -2240,9 +2313,9 @@ dependencies = [ [[package]] name = "num-bigint" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c165a9ab64cf766f73521c0dd2cfdff64f488b8f0b3e621face3462d3db536d7" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" dependencies = [ "num-integer", "num-traits", @@ -2277,7 +2350,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -2348,9 
+2421,9 @@ checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" [[package]] name = "object" -version = "0.36.0" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "576dfe1fc8f9df304abb159d767a29d0476f7750fbf8aa7ad07816004a207434" +checksum = "3f203fa8daa7bb185f760ae12bd8e097f63d17041dcdcaf675ac54cdf863170e" dependencies = [ "memchr", ] @@ -2385,9 +2458,9 @@ dependencies = [ [[package]] name = "openssl" -version = "0.10.64" +version = "0.10.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f" +checksum = "9529f4786b70a3e8c61e11179af17ab6188ad8d0ded78c5529441ed39d4bd9c1" dependencies = [ "bitflags 2.6.0", "cfg-if", @@ -2406,7 +2479,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -2417,9 +2490,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.102" +version = "0.9.103" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c597637d56fbc83893a35eb0dd04b2b8e7a50c91e64e9493e398b5df4fb45fa2" +checksum = "7f9e8deee91df40a943c71b917e5874b951d32a802526c85721ce3b776c929d6" dependencies = [ "cc", "libc", @@ -2453,6 +2526,20 @@ dependencies = [ "urlencoding", ] +[[package]] +name = "opentelemetry" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b69a91d4893e713e06f724597ad630f1fa76057a5e1026c0ca67054a9032a76" +dependencies = [ + "futures-core", + "futures-sink", + "js-sys", + "once_cell", + "pin-project-lite", + "thiserror", +] + [[package]] name = "opentelemetry-otlp" version = "0.13.0" @@ -2546,7 +2633,27 @@ dependencies = [ "glob", "once_cell", "opentelemetry 0.21.0", - "ordered-float 4.2.0", + "ordered-float 4.2.2", + 
"percent-encoding", + "rand", + "thiserror", +] + +[[package]] +name = "opentelemetry_sdk" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae312d58eaa90a82d2e627fd86e075cf5230b3f11794e2ed74199ebbe572d4fd" +dependencies = [ + "async-trait", + "futures-channel", + "futures-executor", + "futures-util", + "glob", + "lazy_static", + "once_cell", + "opentelemetry 0.23.0", + "ordered-float 4.2.2", "percent-encoding", "rand", "thiserror", @@ -2569,9 +2676,9 @@ dependencies = [ [[package]] name = "ordered-float" -version = "4.2.0" +version = "4.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a76df7075c7d4d01fdcb46c912dd17fba5b60c78ea480b475f2b6ab6f666584e" +checksum = "4a91171844676f8c7990ce64959210cd2eaef32c2612c50f9fae9f8aaa6065a6" dependencies = [ "num-traits", ] @@ -2613,7 +2720,7 @@ dependencies = [ "libc", "redox_syscall", "smallvec", - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -2655,7 +2762,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -2691,9 +2798,9 @@ dependencies = [ [[package]] name = "portable-atomic" -version = "1.6.0" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" +checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265" [[package]] name = "powerfmt" @@ -2703,9 +2810,12 @@ checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" [[package]] name = "ppv-lite86" -version = "0.2.17" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +checksum = "dee4364d9f3b902ef14fab8a1ddffb783a1cb6b4bba3bfc1fa3922732c7de97f" +dependencies = [ + "zerocopy 0.6.6", +] 
[[package]] name = "prettyplease" @@ -2714,7 +2824,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5f12335488a2f3b0a83b14edad48dca9879ce89b2edd10e80237e4e852dd645e" dependencies = [ "proc-macro2", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -2766,7 +2876,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8021cf59c8ec9c432cfc2526ac6b8aa508ecaf29cd415f271b8406c1b851c3fd" dependencies = [ "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -2806,7 +2916,7 @@ dependencies = [ "prost 0.12.6", "prost-types", "regex", - "syn 2.0.68", + "syn 2.0.72", "tempfile", ] @@ -2833,7 +2943,7 @@ dependencies = [ "itertools 0.12.1", "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -2968,24 +3078,23 @@ dependencies = [ [[package]] name = "ravif" -version = "0.11.7" +version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67376f469e7e7840d0040bbf4b9b3334005bb167f814621326e4c7ab8cd6e944" +checksum = "5797d09f9bd33604689e87e8380df4951d4912f01b63f71205e2abd4ae25e6b6" dependencies = [ "avif-serialize", "imgref", "loop9", "quick-error", "rav1e", - "rayon", "rgb", ] [[package]] name = "raw-cpuid" -version = "11.0.2" +version = "11.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e29830cbb1290e404f24c73af91c5d8d631ce7e128691e9477556b540cd01ecd" +checksum = "cb9ee317cfe3fbd54b36a511efc1edd42e216903c9cd575e686dd68a2ba90d8d" dependencies = [ "bitflags 2.6.0", ] @@ -3023,9 +3132,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.2" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c82cf8cff14456045f55ec4241383baeff27af886adb72ffb2162f99911de0fd" +checksum = "2a908a6e00f1fdd0dfd9c0eb08ce85126f6d8bbda50017e74bc4a4b7d4a926a4" dependencies = [ "bitflags 2.6.0", ] @@ -3099,7 +3208,7 @@ dependencies = [ "h2 0.3.26", "http 0.2.12", "http-body 0.4.6", - "hyper 
0.14.29", + "hyper 0.14.30", "hyper-tls", "ipnet", "js-sys", @@ -3127,9 +3236,9 @@ dependencies = [ [[package]] name = "rgb" -version = "0.8.37" +version = "0.8.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05aaa8004b64fd573fc9d002f4e632d51ad4f026c2b5ba95fcb6c2f32c2c47d8" +checksum = "ade4539f42266ded9e755c605bdddf546242b2c961b03b06a7375260788a0523" dependencies = [ "bytemuck", ] @@ -3166,9 +3275,9 @@ dependencies = [ [[package]] name = "rust-embed" -version = "8.4.0" +version = "8.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19549741604902eb99a7ed0ee177a0663ee1eda51a29f71401f166e47e77806a" +checksum = "fa66af4a4fdd5e7ebc276f115e895611a34739a9c1c01028383d612d550953c0" dependencies = [ "rust-embed-impl", "rust-embed-utils", @@ -3177,22 +3286,22 @@ dependencies = [ [[package]] name = "rust-embed-impl" -version = "8.4.0" +version = "8.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb9f96e283ec64401f30d3df8ee2aaeb2561f34c824381efa24a35f79bf40ee4" +checksum = "6125dbc8867951125eec87294137f4e9c2c96566e61bf72c45095a7c77761478" dependencies = [ "proc-macro2", "quote", "rust-embed-utils", - "syn 2.0.68", + "syn 2.0.72", "walkdir", ] [[package]] name = "rust-embed-utils" -version = "8.4.0" +version = "8.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38c74a686185620830701348de757fd36bef4aa9680fd23c49fc539ddcc1af32" +checksum = "2e5347777e9aacb56039b0e1f28785929a8a3b709e87482e7442c72e7c12529d" dependencies = [ "sha2", "walkdir", @@ -3260,9 +3369,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.10" +version = "0.23.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05cff451f60db80f490f3c182b77c35260baace73209e9cdbbe526bfe3a4d402" +checksum = "c58f8c84392efc0a126acce10fa59ff7b3d2ac06ab451a33f2741989b806b044" dependencies = [ "aws-lc-rs", "log", @@ -3275,9 +3384,9 @@ dependencies = [ 
[[package]] name = "rustls-native-certs" -version = "0.7.0" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f1fb85efa936c42c6d5fc28d2629bb51e4b2f4b8a5211e297d599cc5a093792" +checksum = "a88d6d420651b496bdd98684116959239430022a115c1240e6c3993be0b15fba" dependencies = [ "openssl-probe", "rustls-pemfile 2.1.2", @@ -3313,9 +3422,9 @@ checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d" [[package]] name = "rustls-webpki" -version = "0.102.4" +version = "0.102.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff448f7e92e913c4b7d4c6d8e4540a1724b319b4152b8aef6d4cf8339712b33e" +checksum = "8e6b52d4fda176fd835fdc55a835d4a89b8499cad995885a21149d5ad62f852e" dependencies = [ "aws-lc-rs", "ring 0.17.8", @@ -3359,6 +3468,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "scratch" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3cf7c11c38cb994f3d40e8a8cde3bbd1f72a435e4c49e85d6553d8312306152" + [[package]] name = "sct" version = "0.7.1" @@ -3371,9 +3486,9 @@ dependencies = [ [[package]] name = "security-framework" -version = "2.11.0" +version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c627723fd09706bacdb5cf41499e95098555af3c3c29d014dc3c458ef6be11c0" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ "bitflags 2.6.0", "core-foundation", @@ -3384,9 +3499,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.11.0" +version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "317936bbbd05227752583946b9e66d7ce3b489f84e11a94a510b4437fef407d7" +checksum = "75da29fe9b9b08fe9d6b22b5b4bcbc75d8db3aa31e639aa56bb62e9d46bfceaf" dependencies = [ 
"core-foundation-sys", "libc", @@ -3403,31 +3518,32 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.203" +version = "1.0.204" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7253ab4de971e72fb7be983802300c30b5a7f0c2e56fab8abfc6a214307c0094" +checksum = "bc76f558e0cbb2a839d37354c575f1dc3fdc6546b5be373ba43d95f231bf7c12" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.203" +version = "1.0.204" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" +checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] name = "serde_json" -version = "1.0.120" +version = "1.0.121" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e0d21c9a8cae1235ad58a00c11cb40d4b1e5c784f1ef2c537876ed6ffd8b7c5" +checksum = "4ab380d7d9f22ef3f21ad3e6c1ebe8e4fc7a2000ccba2e4d71fc96f15b2cb609" dependencies = [ "itoa", + "memchr", "ryu", "serde", ] @@ -3444,9 +3560,9 @@ dependencies = [ [[package]] name = "serde_spanned" -version = "0.6.6" +version = "0.6.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79e674e01f999af37c49f70a6ede167a8a60b2503e56c5599532a65baa5969a0" +checksum = "eb5b1b31579f3811bf615c144393417496f152e12ac8b7663bf664f4a815306d" dependencies = [ "serde", ] @@ -3501,12 +3617,12 @@ dependencies = [ [[package]] name = "signal-hook-mio" -version = "0.2.3" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29ad2e15f37ec9a6cc544097b78a1ec90001e9f71b81338ca39f430adaca99af" +checksum = "34db1a06d485c9142248b7a054f034b349b212551f3dfd19c94d45a754a217cd" dependencies = [ "libc", - "mio", + "mio 0.8.11", "signal-hook", ] @@ -3626,7 +3742,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.68", + "syn 
2.0.72", ] [[package]] @@ -3648,9 +3764,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.68" +version = "2.0.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "901fa70d88b9d6c98022e23b4136f9f3e54e4662c3bc1bd1d84a42a9a0f0c1e9" +checksum = "dc4b9b9bf2add8093d3f2c0204471e951b2285580335de42f9d2534f3ae7a8af" dependencies = [ "proc-macro2", "quote", @@ -3744,9 +3860,9 @@ dependencies = [ [[package]] name = "target-lexicon" -version = "0.12.14" +version = "0.12.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1fc403891a21bcfb7c37834ba66a547a8f402146eba7265b5a6d88059c9ff2f" +checksum = "4873307b7c257eddcb50c9bedf158eb669578359fb28428bef438fec8e6ba7c2" [[package]] name = "tempfile" @@ -3760,6 +3876,37 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "termcolor" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "text-generation-backends-trtllm" +version = "2.2.1-dev0" +dependencies = [ + "async-stream", + "async-trait", + "clap", + "cmake", + "cxx", + "cxx-build", + "log", + "pkg-config", + "text-generation-router", + "thiserror", + "tokenizers", + "tokio", + "tokio-stream", + "tracing", + "tracing-opentelemetry 0.24.0", + "tracing-subscriber", +] + [[package]] name = "text-generation-benchmark" version = "2.2.1-dev0" @@ -3823,6 +3970,7 @@ name = "text-generation-router" version = "2.2.1-dev0" dependencies = [ "async-stream", + "async-trait", "axum 0.7.5", "axum-tracing-opentelemetry", "base64 0.22.1", @@ -3850,7 +3998,6 @@ dependencies = [ "serde", "serde_json", "sysinfo", - "text-generation-client", "thiserror", "tokenizers", "tokio", @@ -3859,30 +4006,79 @@ dependencies = [ "tracing", "tracing-opentelemetry 0.21.0", "tracing-subscriber", + "ureq", "utoipa", "utoipa-swagger-ui", "uuid", "vergen", ] 
+[[package]] +name = "text-generation-router-v3" +version = "2.2.1-dev0" +dependencies = [ + "async-stream", + "async-trait", + "axum 0.7.5", + "axum-tracing-opentelemetry", + "base64 0.22.1", + "clap", + "futures", + "futures-util", + "grpc-metadata", + "hf-hub", + "image", + "init-tracing-opentelemetry", + "jsonschema", + "metrics", + "metrics-exporter-prometheus", + "minijinja", + "minijinja-contrib", + "nohash-hasher", + "once_cell", + "opentelemetry 0.20.0", + "opentelemetry-otlp", + "prost 0.12.6", + "prost-build", + "rand", + "regex", + "reqwest", + "serde", + "serde_json", + "text-generation-router", + "thiserror", + "tokenizers", + "tokio", + "tokio-stream", + "tonic 0.10.2", + "tonic-build", + "tower", + "tower-http", + "tracing", + "tracing-opentelemetry 0.21.0", + "tracing-subscriber", + "utoipa", + "utoipa-swagger-ui", +] + [[package]] name = "thiserror" -version = "1.0.61" +version = "1.0.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709" +checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.61" +version = "1.0.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" +checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -3941,9 +4137,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.6.1" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c55115c6fbe2d2bef26eb09ad74bde02d8255476fc0c7b515ef09fbb35742d82" +checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" dependencies = [ "tinyvec_macros", ] @@ -3989,21 +4185,20 @@ dependencies = [ [[package]] name = "tokio" 
-version = "1.38.0" +version = "1.39.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba4f4a02a7a80d6f274636f0aa95c7e383b912d41fe721a31f29e29698585a4a" +checksum = "daa4fb1bc778bd6f04cbfc4bb2d06a7396a8f299dc33ea1900cedaa316f467b1" dependencies = [ "backtrace", "bytes", "libc", - "mio", - "num_cpus", + "mio 1.0.1", "parking_lot", "pin-project-lite", "signal-hook-registry", "socket2", "tokio-macros", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -4018,13 +4213,13 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.3.0" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" +checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -4054,7 +4249,7 @@ version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" dependencies = [ - "rustls 0.23.10", + "rustls 0.23.12", "rustls-pki-types", "tokio", ] @@ -4086,9 +4281,9 @@ dependencies = [ [[package]] name = "toml" -version = "0.8.14" +version = "0.8.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f49eb2ab21d2f26bd6db7bf383edc527a7ebaee412d17af4d40fdccd442f335" +checksum = "81967dd0dd2c1ab0bc3468bd7caecc32b8a4aa47d0c8c695d8c2b2108168d62c" dependencies = [ "serde", "serde_spanned", @@ -4098,18 +4293,18 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.6.6" +version = "0.6.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4badfd56924ae69bcc9039335b2e017639ce3f9b001c393c1b2d1ef846ce2cbf" +checksum = "f8fb9f64314842840f1d940ac544da178732128f1c78c21772e876579e0da1db" dependencies = [ "serde", ] [[package]] name = "toml_edit" -version = "0.22.14" +version = "0.22.17" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f21c7aaf97f1bd9ca9d4f9e73b0a6c74bd5afef56f2bc931943a6e1c37e04e38" +checksum = "8d9f8729f5aea9562aac1cc0441f5d6de3cff1ee0c5d67293eeca5eb36ee7c16" dependencies = [ "indexmap 2.2.6", "serde", @@ -4133,7 +4328,7 @@ dependencies = [ "h2 0.3.26", "http 0.2.12", "http-body 0.4.6", - "hyper 0.14.29", + "hyper 0.14.30", "hyper-timeout", "percent-encoding", "pin-project", @@ -4160,7 +4355,7 @@ dependencies = [ "h2 0.3.26", "http 0.2.12", "http-body 0.4.6", - "hyper 0.14.29", + "hyper 0.14.30", "hyper-timeout", "percent-encoding", "pin-project", @@ -4183,7 +4378,7 @@ dependencies = [ "proc-macro2", "prost-build", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -4215,7 +4410,7 @@ dependencies = [ "bitflags 2.6.0", "bytes", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "http-body-util", "pin-project-lite", "tower-layer", @@ -4254,7 +4449,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -4320,7 +4515,25 @@ dependencies = [ "tracing-core", "tracing-log 0.2.0", "tracing-subscriber", - "web-time", + "web-time 0.2.4", +] + +[[package]] +name = "tracing-opentelemetry" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f68803492bf28ab40aeccaecc7021096bd256baf7ca77c3d425d89b35a7be4e4" +dependencies = [ + "js-sys", + "once_cell", + "opentelemetry 0.23.0", + "opentelemetry_sdk 0.23.0", + "smallvec", + "tracing", + "tracing-core", + "tracing-log 0.2.0", + "tracing-subscriber", + "web-time 1.1.0", ] [[package]] @@ -4512,7 +4725,7 @@ dependencies = [ "proc-macro2", "quote", "regex", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -4550,7 +4763,7 @@ checksum = "ee1cd046f83ea2c4e920d6ee9f7c3537ef928d75dce5d84a87c2c5d6b3999a3a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -4578,9 
+4791,9 @@ checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" [[package]] name = "vergen" -version = "8.3.1" +version = "8.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e27d6bdd219887a9eadd19e1c34f32e47fa332301184935c6d9bca26f3cca525" +checksum = "2990d9ea5967266ea0ccf413a4aa5c42a93dbcfda9cb49a97de6931726b12566" dependencies = [ "anyhow", "cargo_metadata", @@ -4600,9 +4813,9 @@ checksum = "852e951cb7832cb45cb1169900d19760cfa39b82bc0ea9c0e5a14ae88411c98b" [[package]] name = "version_check" -version = "0.9.4" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" [[package]] name = "walkdir" @@ -4650,7 +4863,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", "wasm-bindgen-shared", ] @@ -4684,7 +4897,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -4715,6 +4928,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "webpki" version = "0.22.4" @@ -4790,7 +5013,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be" dependencies = [ "windows-core", - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -4799,7 +5022,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -4826,7 +5049,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -4861,18 +5084,18 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ - "windows_aarch64_gnullvm 0.52.5", - "windows_aarch64_msvc 0.52.5", - "windows_i686_gnu 0.52.5", + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", "windows_i686_gnullvm", - "windows_i686_msvc 0.52.5", - "windows_x86_64_gnu 0.52.5", - "windows_x86_64_gnullvm 0.52.5", - "windows_x86_64_msvc 0.52.5", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", ] [[package]] @@ -4889,9 +5112,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" [[package]] name = "windows_aarch64_msvc" @@ -4907,9 +5130,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" [[package]] name = "windows_i686_gnu" @@ -4925,15 +5148,15 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" [[package]] name = "windows_i686_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] name = "windows_i686_msvc" @@ -4949,9 +5172,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] name = "windows_x86_64_gnu" @@ -4967,9 +5190,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] name = "windows_x86_64_gnullvm" @@ -4985,9 +5208,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.5" +version = 
"0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" [[package]] name = "windows_x86_64_msvc" @@ -5003,15 +5226,15 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.6.13" +version = "0.6.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59b5e5f6c299a3c7890b876a2a587f3115162487e704907d9b6cd29473052ba1" +checksum = "b480ae9340fc261e6be3e95a1ba86d54ae3f9171132a73ce8d4bbaf68339507c" dependencies = [ "memchr", ] @@ -5028,22 +5251,43 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.7.34" +version = "0.6.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae87e3fcd617500e5d106f0380cf7b77f3c6092aae37191433159dda23cfb087" +checksum = "854e949ac82d619ee9a14c66a1b674ac730422372ccb759ce0c39cabcf2bf8e6" dependencies = [ - "zerocopy-derive", + "byteorder", + "zerocopy-derive 0.6.6", +] + +[[package]] +name = "zerocopy" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +dependencies = [ + "zerocopy-derive 0.7.35", ] [[package]] name = "zerocopy-derive" -version = "0.7.34" +version = "0.6.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" +checksum = "125139de3f6b9d625c39e2efdd73d41bdac468ccd556556440e322be0e1bbd91" dependencies = [ "proc-macro2", 
"quote", - "syn 2.0.68", + "syn 2.0.72", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.72", ] [[package]] @@ -5063,7 +5307,7 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.72", ] [[package]] @@ -5095,9 +5339,9 @@ dependencies = [ [[package]] name = "zune-jpeg" -version = "0.4.11" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec866b44a2a1fd6133d363f073ca1b179f438f99e7e5bfb1e33f7181facfe448" +checksum = "16099418600b4d8f028622f73ff6e3deaabdff330fb9a2a131dea781ee8b0768" dependencies = [ "zune-core", ] diff --git a/Cargo.toml b/Cargo.toml index fbda40ba..8bf75b90 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,10 +1,19 @@ [workspace] members = [ - "benchmark", - "router", - "router/client", - "router/grpc-metadata", - "launcher" + "benchmark", + "backends/v3", + "backends/grpc-metadata", + "backends/trtllm", + "backends/client", + "launcher" +] +default-members = [ + "benchmark", + "backends/v3", + "backends/grpc-metadata", + # "backends/trtllm", + "backends/client", + "launcher" ] resolver = "2" @@ -18,6 +27,8 @@ homepage = "https://github.com/huggingface/text-generation-inference" base64 = "0.22.0" tokenizers = { version = "0.19.1", features = ["http"] } hf-hub = { version = "0.3.1", features = ["tokio"] } +metrics = { version = "0.23.0" } +metrics-exporter-prometheus = { version = "0.15.1", features = [] } [profile.release] incremental = true diff --git a/Dockerfile b/Dockerfile index 52393a76..0d57e38d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,6 +11,7 @@ COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto COPY benchmark benchmark COPY router router +COPY backends backends COPY 
launcher launcher RUN cargo chef prepare --recipe-path recipe.json @@ -33,6 +34,7 @@ COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto COPY benchmark benchmark COPY router router +COPY backends backends COPY launcher launcher RUN cargo build --profile release-opt diff --git a/Dockerfile.trtllm b/Dockerfile.trtllm new file mode 100644 index 00000000..4543ae80 --- /dev/null +++ b/Dockerfile.trtllm @@ -0,0 +1,23 @@ +# All the tooling for CUDA +FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04 AS cuda-builder + +WORKDIR /usr/src/tgi/backends/trtllm +RUN apt update && apt install -y cmake git git-lfs gcc g++ ninja-build libopenmpi-dev python3-dev python3-pip wget + +COPY . /usr/src/tgi +RUN chmod +x scripts/install_tensorrt.sh && scripts/install_tensorrt.sh +RUN cmake -G Ninja -B build -DTRT_LIB_DIR=/usr/local/tensorrt/lib -DTRT_INCLUDE_DIR=/usr/local/tensorrt/include . +RUN cmake --build build --parallel -t tgi_trtllm_backend_impl + +# All the tooling for Rust +FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef +WORKDIR /usr/src + +# Include CUDA related libraries and tools to the Rust based image +COPY --from=cuda-builder /usr/local/cuda /usr/local/cuda +COPY --from=cuda-builder /usr/local/tensorrt /usr/local/tensorrt +COPY --from=cuda-builder /usr/src/tgi/backends/trtllm/build /usr/local/tgi/trtllm/build +ENV PATH=/usr/local/cuda/bin:$PATH +ENV LD_LIBRARY_PATH=/usr/local/tensorrt/lib:$LD_LIBRARY_PATH + +RUN apt update && apt install -y cmake git gcc g++ ninja-build libopenmpi3 diff --git a/Dockerfile_amd b/Dockerfile_amd index 0aebeee5..51231638 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -11,6 +11,7 @@ COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto COPY benchmark benchmark COPY router router +COPY backends backends COPY launcher launcher RUN cargo chef prepare --recipe-path recipe.json @@ -33,6 +34,7 @@ COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto COPY benchmark benchmark COPY router router +COPY backends 
backends COPY launcher launcher RUN cargo build --profile release-opt diff --git a/Dockerfile_intel b/Dockerfile_intel index 6a803a32..d20f0a01 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -12,6 +12,7 @@ COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto COPY benchmark benchmark COPY router router +COPY backends backends COPY launcher launcher RUN cargo chef prepare --recipe-path recipe.json @@ -34,6 +35,7 @@ COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto COPY benchmark benchmark COPY router router +COPY backends backends COPY launcher launcher RUN cargo build --profile release-opt diff --git a/Makefile b/Makefile index a1399b6d..3068a06f 100644 --- a/Makefile +++ b/Makefile @@ -5,13 +5,13 @@ install-server-cpu: cd server && make install-server install-router: - cd router && cargo install --path . + cargo install --path backends/v3/ install-launcher: - cd launcher && cargo install --path . + cargo install --path launcher/ install-benchmark: - cd benchmark && cargo install --path . 
+ cargo install --path benchmark/ install: install-server install-router install-launcher diff --git a/router/client/Cargo.toml b/backends/client/Cargo.toml similarity index 100% rename from router/client/Cargo.toml rename to backends/client/Cargo.toml diff --git a/router/client/build.rs b/backends/client/build.rs similarity index 100% rename from router/client/build.rs rename to backends/client/build.rs diff --git a/router/client/src/lib.rs b/backends/client/src/lib.rs similarity index 100% rename from router/client/src/lib.rs rename to backends/client/src/lib.rs diff --git a/router/client/src/v2/client.rs b/backends/client/src/v2/client.rs similarity index 100% rename from router/client/src/v2/client.rs rename to backends/client/src/v2/client.rs diff --git a/router/client/src/v2/mod.rs b/backends/client/src/v2/mod.rs similarity index 100% rename from router/client/src/v2/mod.rs rename to backends/client/src/v2/mod.rs diff --git a/router/client/src/v2/sharded_client.rs b/backends/client/src/v2/sharded_client.rs similarity index 100% rename from router/client/src/v2/sharded_client.rs rename to backends/client/src/v2/sharded_client.rs diff --git a/router/client/src/v3/client.rs b/backends/client/src/v3/client.rs similarity index 100% rename from router/client/src/v3/client.rs rename to backends/client/src/v3/client.rs diff --git a/router/client/src/v3/mod.rs b/backends/client/src/v3/mod.rs similarity index 100% rename from router/client/src/v3/mod.rs rename to backends/client/src/v3/mod.rs diff --git a/router/client/src/v3/sharded_client.rs b/backends/client/src/v3/sharded_client.rs similarity index 100% rename from router/client/src/v3/sharded_client.rs rename to backends/client/src/v3/sharded_client.rs diff --git a/router/grpc-metadata/Cargo.toml b/backends/grpc-metadata/Cargo.toml similarity index 100% rename from router/grpc-metadata/Cargo.toml rename to backends/grpc-metadata/Cargo.toml diff --git a/router/grpc-metadata/src/lib.rs 
b/backends/grpc-metadata/src/lib.rs similarity index 100% rename from router/grpc-metadata/src/lib.rs rename to backends/grpc-metadata/src/lib.rs diff --git a/backends/trtllm/CMakeLists.txt b/backends/trtllm/CMakeLists.txt new file mode 100644 index 00000000..425b2d7b --- /dev/null +++ b/backends/trtllm/CMakeLists.txt @@ -0,0 +1,63 @@ +cmake_minimum_required(VERSION 3.20) + +project(tgi-trtllm-backend VERSION 1.0.0) +set(CMAKE_CXX_STANDARD 20) + +include(FetchContent) +include(ExternalProject) + +option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF) +option(TGI_TRTLLM_BACKEND_BUILD_EXAMPLES "Enable building the examples suite" OFF) +set(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST "89-real" CACHE STRING "List of CUDA architectures to support") +set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING "Path where TensorRT libraries and headers are located") +set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include" CACHE STRING "Path where TensorRT headers are located") +set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE STRING "Path where TensorRT libraries are located") + +# We are using nvidia-ml to query at runtime device information to enable some architecture-specific features +find_package(CUDAToolkit 12.5 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml) + +#### External dependencies #### +include(cmake/fmt.cmake) +include(cmake/json.cmake) +include(cmake/spdlog.cmake) +include(cmake/trtllm.cmake) + +# Let's build TRTLLM as part of CMake +add_subdirectory("${trtllm_SOURCE_DIR}/cpp" "${trtllm_SOURCE_DIR}/..") + +# Tell CMake to need try to override the RPATH for executorWorker as it has not information on how to do so +set_target_properties(executorWorker PROPERTIES SKIP_BUILD_RPATH TRUE) + +# TGI TRTLLM Backend definition +add_library(tgi_trtllm_backend_impl STATIC include/backend.h lib/backend.cpp include/hardware.h) +include_directories(${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR}) 
+target_include_directories(tgi_trtllm_backend_impl PRIVATE + $ + $ +) +target_include_directories(tgi_trtllm_backend_impl PUBLIC "${trtllm_SOURCE_DIR}/cpp/include") +target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper CUDA::cudart CUDA::nvml) +target_link_libraries(tgi_trtllm_backend_impl PUBLIC nlohmann_json::nlohmann_json spdlog::spdlog fmt::fmt) + +# This install all the artifacts in CMAKE_INSTALL_PREFIX under include/ lib/ bin/ to make easy to link / find it back +install(TARGETS tgi_trtllm_backend_impl tensorrt_llm nvinfer_plugin_tensorrt_llm decoder_attention executorWorker) +install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB) + +#### Unit Tests #### +if (${TGI_TRTLLM_BACKEND_BUILD_TESTS}) + message(STATUS "Building tests") + FetchContent_Declare( + Catch2 + GIT_REPOSITORY https://github.com/catchorg/Catch2 + GIT_TAG v3.6.0 + ) + FetchContent_MakeAvailable(Catch2) + + # add_executable(tgi_trtllm_backend_tests tests/infer_test.cpp) + # target_link_libraries(tgi_trtllm_backend_tests PRIVATE tgi_trtllm_backend_impl Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog fmt::fmt CUDA::cudart CUDA::nvml) + + list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras) + include(CTest) + include(Catch) + # catch_discover_tests(tgi_trtllm_backend_tests) +endif () diff --git a/backends/trtllm/Cargo.toml b/backends/trtllm/Cargo.toml new file mode 100644 index 00000000..7079d3d1 --- /dev/null +++ b/backends/trtllm/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "text-generation-backends-trtllm" +version.workspace = true +edition.workspace = true +authors.workspace = true +homepage.workspace = true + +[dependencies] +async-trait = "0.1" +async-stream = "0.3" +cxx = "1.0" +text-generation-router = { path = "../../router" } +tokenizers = { version = "0.19", features = ["hf-hub"] } +tokio = { version = "1.38", features = ["rt", "rt-multi-thread", 
"parking_lot", "signal", "sync"] } +tokio-stream = "0.1.15" +clap = { version = "4.5", features = ["derive"] } +thiserror = "1.0.62" +tracing = "0.1" +tracing-opentelemetry = "0.24" +tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] } +log = { version = "0.4", features = [] } + +[build-dependencies] +cmake = "0.1" +cxx-build = { version = "1.0", features = ["parallel"] } +pkg-config = "0.3" diff --git a/backends/trtllm/Dockerfile b/backends/trtllm/Dockerfile new file mode 100644 index 00000000..60ad03f7 --- /dev/null +++ b/backends/trtllm/Dockerfile @@ -0,0 +1,100 @@ +ARG CUDA_ARCH_LIST="75-real;80-real;86-real;89-real;90-real" +ARG OMPI_VERSION="4.1.6" + +# Build dependencies resolver stage +FROM lukemathwalker/cargo-chef:latest AS chef +WORKDIR /usr/src/text-generation-inference + +FROM chef AS planner +COPY . . +RUN cargo chef prepare --recipe-path recipe.json + +# CUDA dependent dependencies resolver stage +FROM nvidia/cuda:12.5.1-cudnn-devel-ubuntu22.04 AS cuda-builder + +RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ + --mount=type=cache,target=/var/lib/apt,sharing=locked \ + apt update && apt install -y \ + build-essential \ + cmake \ + curl \ + gcc \ + g++ \ + git \ + git-lfs \ + libssl-dev \ + ninja-build \ + pkg-config \ + python3 \ + python3-setuptools \ + tar \ + wget + +ENV TGI_INSTALL_PREFIX=/usr/local/tgi +ENV TENSORRT_INSTALL_PREFIX=/usr/local/tensorrt + +# Install OpenMPI +FROM cuda-builder AS mpi-builder +ARG OMPI_VERSION + +ENV OMPI_TARBALL_FILENAME="openmpi-$OMPI_VERSION.tar.bz2" +RUN wget "https://download.open-mpi.org/release/open-mpi/v4.1/$OMPI_TARBALL_FILENAME" -P /opt/src && \ + mkdir /usr/src/mpi && \ + tar -xf "/opt/src/$OMPI_TARBALL_FILENAME" -C /usr/src/mpi --strip-components=1 && \ + cd /usr/src/mpi && \ + ./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --without-slurm && \ + make -j all && \ + make install && \ + rm -rf "/opt/src/$OMPI_TARBALL_FILENAME" + +# Install TensorRT +FROM 
cuda-builder AS trt-builder +COPY backends/trtllm/scripts/install_tensorrt.sh /opt/install_tensorrt.sh +RUN chmod +x /opt/install_tensorrt.sh && \ + /opt/install_tensorrt.sh + +# Build Backend +FROM cuda-builder AS tgi-builder +WORKDIR /usr/src/text-generation-inference + +# Install Rust +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y && \ + chmod -R a+w /root/.rustup && \ + chmod -R a+w /root/.cargo + +ENV PATH="/root/.cargo/bin:$PATH" +RUN cargo install cargo-chef + +# Cache dependencies +COPY --from=planner /usr/src/text-generation-inference/recipe.json . +RUN cargo chef cook --release --recipe-path recipe.json + +# Build actual TGI +ARG CUDA_ARCH_LIST +ENV CMAKE_PREFIX_PATH="/usr/local/mpi:/usr/local/tensorrt:$CMAKE_PREFIX_PATH" +ENV LD_LIBRARY_PATH="/usr/local/mpi/lib:$LD_LIBRARY_PATH" +ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig:$PKG_CONFIG_PATH" + +COPY . . +COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt +COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi +RUN mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$TGI_INSTALL_PREFIX/lib" && \ + CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release --bin text-generation-backends-trtllm + +FROM nvidia/cuda:12.5.1-cudnn-runtime-ubuntu22.04 AS runtime +WORKDIR /usr/local/tgi/bin + +ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH" + +COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi +COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt +COPY --from=tgi-builder /usr/local/tgi /usr/local/tgi +COPY --from=tgi-builder /usr/src/text-generation-inference/target/release/text-generation-backends-trtllm /usr/local/tgi/bin/text-generation-launcher + +FROM runtime + +LABEL co.huggingface.vendor="Hugging Face Inc." 
+LABEL org.opencontainers.image.authors="hardware@hf.co" + +ENTRYPOINT ["./text-generation-launcher"] +CMD ["--executor-worker", "/usr/local/tgi/bin/executorWorker"] diff --git a/backends/trtllm/README.md b/backends/trtllm/README.md new file mode 100644 index 00000000..94064504 --- /dev/null +++ b/backends/trtllm/README.md @@ -0,0 +1,46 @@ +# Text Generation Inference - TensorRT-LLM Backend Implementation + +## Description + +This folder provides the sources of the TensorRT-LLM backend implementation powered by TensorRT-LLM Executor new API + +## Simplified Request Sequence + +```mermaid +sequenceDiagram + actor User + participant TextGenerationInference.HttpServer + participant TextGenerationInference.TensorRtLlmBackend + participant TextGenerationInference.TensorRtLlmWorkerThread + participant TensorRtLlm.Executor + participant Nvidia.Gpu + User ->> TextGenerationInference.HttpServer: POST /generate + TextGenerationInference.HttpServer ->> TextGenerationInference.TensorRtLlmBackend: Validate and forward inputs & parameters + TextGenerationInference.TensorRtLlmBackend ->> TextGenerationInference.TensorRtLlmWorkerThread: Allocate a new context and spawn a new thread to handle the request + TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Submit the request to the In-Flight Batcher + activate Nvidia.Gpu + TensorRtLlm.Executor ->> Nvidia.Gpu: Add the request to the poll for execution + TensorRtLlm.Executor -->> TextGenerationInference.TensorRtLlmWorkerThread: Response with an unique request identifier + rect rgb(10, 92, 54) + loop every 100us + rect rgb(15, 81, 50) + alt Acquire lock to query executor + TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Poll request number of new token(s) generated + else There are new generated tokens + TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Retrieve newly generated tokens + TensorRtLlm.Executor -->> TextGenerationInference.TensorRtLlmWorkerThread: 
Return decoded token information and potential error (omitted) + rect rgb(11, 110, 79) + alt Generated token is final + TensorRtLlm.Executor ->> Nvidia.Gpu: Remove request from the scheduler and from the GPU + TextGenerationInference.TensorRtLlmWorkerThread -->> User: Stream the remaining decoded tokens and flush the connection + else Generated token is not final + TextGenerationInference.TensorRtLlmWorkerThread -->> User: Stream token back to the user as they get decoded + end + end + end + end + deactivate Nvidia.Gpu + end + end + +``` diff --git a/backends/trtllm/build.rs b/backends/trtllm/build.rs new file mode 100644 index 00000000..08638262 --- /dev/null +++ b/backends/trtllm/build.rs @@ -0,0 +1,150 @@ +use cxx_build::CFG; +use pkg_config; +use std::env; +use std::env::consts::ARCH; +use std::path::{absolute, PathBuf}; + +const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"]; +const CUDA_ARCH_LIST: Option<&str> = option_env!("CUDA_ARCH_LIST"); +const CUDA_REQUIRED_VERSION: &str = "12.5"; +const MPI_REQUIRED_VERSION: &str = "4.1"; +const INSTALL_PREFIX: Option<&str> = option_env!("CMAKE_INSTALL_PREFIX"); +const TENSORRT_ROOT_DIR: Option<&str> = option_env!("TENSORRT_ROOT_DIR"); +const NCCL_ROOT_DIR: Option<&str> = option_env!("NCCL_ROOT_DIR"); + +// Dependencies +const BACKEND_DEPS: [&str; 2] = ["tgi_trtllm_backend_impl", "tgi_trtllm_backend"]; +const CUDA_TRANSITIVE_DEPS: [&str; 4] = ["cuda", "cudart", "cublas", "nvidia-ml"]; +const TENSORRT_LLM_TRANSITIVE_DEPS: [(&str, &str); 5] = [ + ("dylib", "tensorrt_llm"), + ("static", "tensorrt_llm_executor_static"), + ("dylib", "tensorrt_llm_nvrtc_wrapper"), + ("dylib", "nvinfer_plugin_tensorrt_llm"), + ("dylib", "decoder_attention"), +]; + +macro_rules! 
probe { + ($name: expr, $version: expr) => { + if let Err(_) = pkg_config::probe_library($name) { + pkg_config::probe_library(&format!("{}-{}", $name, $version)) + .expect(&format!("Failed to locate {}", $name)); + } + }; +} + +fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf, PathBuf) { + // Build the backend implementation through CMake + let install_path = INSTALL_PREFIX.unwrap_or("/usr/local/tgi"); + let tensorrt_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt"); + let cuda_arch_list = CUDA_ARCH_LIST.unwrap_or("90-real"); // Hopper by default + + let mut install_path = PathBuf::from(install_path); + if !install_path.is_absolute() { + install_path = absolute(out_dir).expect("cannot happen").join(install_path); + } + + let _ = cmake::Config::new(".") + .uses_cxx11() + .generator("Ninja") + .profile(match is_debug { + true => "Debug", + false => "Release", + }) + .env("OPT_LEVEL", opt_level) + .define("CMAKE_INSTALL_PREFIX", &install_path) + .define("CMAKE_CUDA_COMPILER", "/usr/local/cuda/bin/nvcc") + .define("TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST", cuda_arch_list) + .define("TGI_TRTLLM_BACKEND_TRT_ROOT", tensorrt_path) + .build(); + + // Additional transitive CMake dependencies + let deps_folder = out_dir.join("build").join("_deps"); + for dependency in ADDITIONAL_BACKEND_LINK_LIBRARIES { + let dep_name = match is_debug { + true => format!("{}d", dependency), + false => String::from(dependency), + }; + let dep_path = deps_folder.join(format!("{}-build", dependency)); + println!("cargo:rustc-link-search={}", dep_path.display()); + println!("cargo:rustc-link-lib=static={}", dep_name); + } + + // Emit linkage information from the artifacts we just built + let install_lib_path = install_path.join("lib"); + + println!( + r"cargo:warning=Adding link search path: {}", + install_lib_path.display() + ); + println!(r"cargo:rustc-link-search={}", install_lib_path.display()); + + (PathBuf::from(install_path), deps_folder) +} + +fn 
build_ffi_layer(deps_folder: &PathBuf) { + CFG.include_prefix = "backends/trtllm"; + cxx_build::bridge("src/lib.rs") + .static_flag(true) + .include(deps_folder.join("fmt-src").join("include")) + .include(deps_folder.join("spdlog-src").join("include")) + .include(deps_folder.join("json-src").join("include")) + .include(deps_folder.join("trtllm-src").join("cpp").join("include")) + .include("/usr/local/cuda/include") + .include("/usr/local/tensorrt/include") + .file("src/ffi.cpp") + .std("c++20") + .compile("tgi_trtllm_backend"); + + println!("cargo:rerun-if-changed=CMakeLists.txt"); + println!("cargo:rerun-if-changed=include/backend.h"); + println!("cargo:rerun-if-changed=lib/backend.cpp"); + println!("cargo:rerun-if-changed=include/ffi.h"); + println!("cargo:rerun-if-changed=src/ffi.cpp"); +} + +fn main() { + // Misc variables + let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); + let build_profile = env::var("PROFILE").unwrap(); + let (is_debug, opt_level) = match build_profile.as_ref() { + "debug" => (true, "0"), + _ => (false, "3"), + }; + + // Build the backend + let (_backend_path, deps_folder) = build_backend(is_debug, opt_level, &out_dir); + + // Build the FFI layer calling the backend above + build_ffi_layer(&deps_folder); + + // Emit linkage search path + probe!("ompi", MPI_REQUIRED_VERSION); + + // Probe CUDA & co. 
with pkg-config + CUDA_TRANSITIVE_DEPS.iter().for_each(|name| { + probe!(name, CUDA_REQUIRED_VERSION); + }); + + // NCCL is slightly trickier because it might not have a pkgconfig installed + let nccl_library_path_default = format!("/usr/local/{}-linux-gnu", ARCH); + let nccl_library_path = NCCL_ROOT_DIR.unwrap_or(&nccl_library_path_default); + println!(r"cargo:rustc-link-search=native={}", nccl_library_path); + println!("cargo:rustc-link-lib=dylib=nccl"); + + // TensorRT + let tensort_library_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt/lib"); + println!(r"cargo:rustc-link-search=native={}", tensort_library_path); + println!("cargo:rustc-link-lib=dylib=nvinfer"); + + // TensorRT-LLM + TENSORRT_LLM_TRANSITIVE_DEPS + .iter() + .for_each(|(link_type, name)| { + println!("cargo:rustc-link-lib={}={}", link_type, name); + }); + + // Backend + BACKEND_DEPS.iter().for_each(|name| { + println!("cargo:rustc-link-lib=static={}", name); + }); +} diff --git a/backends/trtllm/cmake/fmt.cmake b/backends/trtllm/cmake/fmt.cmake new file mode 100644 index 00000000..f94a9c56 --- /dev/null +++ b/backends/trtllm/cmake/fmt.cmake @@ -0,0 +1,6 @@ +FetchContent_Declare( + fmt + GIT_REPOSITORY https://github.com/fmtlib/fmt + GIT_TAG 11.0.1 +) +FetchContent_MakeAvailable(fmt) diff --git a/backends/trtllm/cmake/json.cmake b/backends/trtllm/cmake/json.cmake new file mode 100644 index 00000000..29e5753b --- /dev/null +++ b/backends/trtllm/cmake/json.cmake @@ -0,0 +1,5 @@ +fetchcontent_declare( + json + URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz +) +fetchcontent_makeavailable(json) diff --git a/backends/trtllm/cmake/spdlog.cmake b/backends/trtllm/cmake/spdlog.cmake new file mode 100644 index 00000000..c4ee5c97 --- /dev/null +++ b/backends/trtllm/cmake/spdlog.cmake @@ -0,0 +1,17 @@ +set(SPDLOG_USE_FMT ON) +set(SPDLOG_BUILD_SHARED OFF) +set(SPDLOG_FMT_EXTERNAL ON) + +# Define the level at which SPDLOG_ compilation level is defined +if 
(${CMAKE_BUILD_TYPE} STREQUAL "Debug") + add_compile_definitions(SPDLOG_ACTIVE_LEVEL=SPDLOG_LEVEL_DEBUG) +else () + add_compile_definitions(SPDLOG_ACTIVE_LEVEL=SPDLOG_LEVEL_INFO) +endif () + +fetchcontent_declare( + spdlog + GIT_REPOSITORY https://github.com/gabime/spdlog.git + GIT_TAG v1.14.1 +) +fetchcontent_makeavailable(spdlog) diff --git a/backends/trtllm/cmake/trtllm.cmake b/backends/trtllm/cmake/trtllm.cmake new file mode 100644 index 00000000..e59ad4cf --- /dev/null +++ b/backends/trtllm/cmake/trtllm.cmake @@ -0,0 +1,42 @@ +set(TRT_INCLUDE_DIR ${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR}) +set(TRT_LIB_DIR ${TGI_TRTLLM_BACKEND_TRT_LIB_DIR}) + +set(USE_CXX11_ABI ON) +set(BUILD_PYT OFF) +set(BUILD_PYBIND OFF) +set(BUILD_MICRO_BENCHMARKS OFF) +set(BUILD_BENCHMARKS OFF) +set(BUILD_TESTS OFF) +set(CMAKE_CUDA_ARCHITECTURES ${TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST}) + +message(STATUS "Building for CUDA Architectures: ${CMAKE_CUDA_ARCHITECTURES}") + +if (${CMAKE_BUILD_TYPE} STREQUAL "Debug") + set(FAST_BUILD ON) + set(NVTX_DISABLE OFF) +else () + set(FAST_BUILD OFF) + set(FAST_MATH ON) + set(NVTX_DISABLE ON) +endif () + +fetchcontent_declare( + trtllm + GIT_REPOSITORY https://github.com/NVIDIA/TensorRT-LLM.git + GIT_TAG a681853d3803ee5893307e812530b5e7004bb6e1 + GIT_SHALLOW FALSE +) +fetchcontent_makeavailable(trtllm) + +message(STATUS "Found TensorRT-LLM: ${trtllm_SOURCE_DIR}") +execute_process(COMMAND git lfs install WORKING_DIRECTORY "${trtllm_SOURCE_DIR}/") +execute_process(COMMAND git lfs pull WORKING_DIRECTORY "${trtllm_SOURCE_DIR}/") + +# TRTLLM use a JIT based *precompiled* library to generate some specific kernels, we are generating the path to this one here +set(TRTLLM_NVRTC_LIBRARY_NAME "${CMAKE_SHARED_LIBRARY_PREFIX}tensorrt_llm_nvrtc_wrapper${CMAKE_SHARED_LIBRARY_SUFFIX}" CACHE INTERNAL "nvrtc wrapper library name") +set(TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH
"${trtllm_SOURCE_DIR}/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/${CMAKE_LIBRARY_ARCHITECTURE}/${TRTLLM_NVRTC_LIBRARY_NAME}" + CACHE INTERNAL "nvrtc wrapper library path") + +# The same Executor Static library +set(TRTLLM_EXECUTOR_STATIC_LIBRARY_NAME "${CMAKE_SHARED_LIBRARY_PREFIX}tensorrt_llm_executor_static${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE INTERNAL "executor_static library name") +set(TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH "${trtllm_SOURCE_DIR}/cpp/tensorrt_llm/executor/${CMAKE_LIBRARY_ARCHITECTURE}/${TRTLLM_EXECUTOR_STATIC_LIBRARY_NAME}" CACHE INTERNAL "executor_static library path") diff --git a/backends/trtllm/cmake/utils/detect_cuda_arch.cu b/backends/trtllm/cmake/utils/detect_cuda_arch.cu new file mode 100644 index 00000000..e69de29b diff --git a/backends/trtllm/include/backend.h b/backends/trtllm/include/backend.h new file mode 100644 index 00000000..7990e76b --- /dev/null +++ b/backends/trtllm/include/backend.h @@ -0,0 +1,121 @@ +// +// Created by Morgan Funtowicz on 6/30/24. +// + +#ifndef TGI_TRTLLM_BACKEND_H +#define TGI_TRTLLM_BACKEND_H + +#include +#include +#include +#include + +#include + +#include +#include +#include + +using json = nlohmann::json; +namespace tle = tensorrt_llm::executor; + +namespace huggingface::tgi::backends { + using RequestId = tle::IdType; + using TokenId = tle::TokenIdType; + + /** + * Initialize all the components required by TRTLLM. 
+ * It is required to call this function before attempting to load any engine + */ + void InitializeBackend(); + + /** + * + * @param config TensorRT-LLM configuration object + * @param workerPath Path to the "executorWorker" provided by TensorRT-LLM when using orchestrator mode + * @return + */ + tle::ExecutorConfig GetExecutorConfig(const json &config, const std::string &workerPath); + + /** + * Get the sampling configuration from the parameters provided by TGI + * @param topK + * @param topP + * @param temperature + * @param repetition_penalty + * @param frequency_penalty + * @param seed + * @return + */ + tle::SamplingConfig GetSamplingConfig( + uint32_t topK, + float_t topP, + float_t temperature, + float_t repetition_penalty, + float_t frequency_penalty, + uint64_t seed + ); + + /** + * + */ + class TensorRtLlmBackend { + private: + const json config; + tle::Executor executor; + + public: + explicit TensorRtLlmBackend( + const std::filesystem::path &engineFolder, + const std::filesystem::path &executorWorker + ); + + /** + * Indicate if the backend is ready to accept incoming request + * @return true if ready, false otherwise + */ + [[nodiscard]] bool IsReady() const; + + /** + * Query the executor for the number of token available for pulling + * @return + */ + [[nodiscard]] size_t NumResponsesReady() const; + + /** + * Submit a new generation task to the executor + * @param tokens + * @param topK + * @param topP + * @param temperature + * @param repetition_penalty + * @param frequency_penalty + * @param seed + * @return Request id related to this generation for reference + */ + [[nodiscard]] RequestId Submit( + const std::vector &tokens, + int32_t topK, + float_t topP, + float_t temperature, + float_t repetition_penalty, + float_t frequency_penalty, + uint64_t seed + ); + + /** + * + * @param requestId The request id to poll the generation results + * @return + */ + std::vector Poll(RequestId requestId); + + /** + * Stop the underlying executor + */ + void 
Shutdown(); + }; +} + + +#endif //TGI_TRTLLM_BACKEND_H diff --git a/backends/trtllm/include/ffi.h b/backends/trtllm/include/ffi.h new file mode 100644 index 00000000..fe0be9fc --- /dev/null +++ b/backends/trtllm/include/ffi.h @@ -0,0 +1,75 @@ +// +// Created by mfuntowicz on 7/11/24. +// + +#ifndef TGI_TRTLLM_BACKEND_FFI_H +#define TGI_TRTLLM_BACKEND_FFI_H + +#include +#include "backend.h" + +namespace huggingface::tgi::backends { + class TensorRtLlmBackendImpl; +} + +#include "backends/trtllm/src/lib.rs.h" + + +namespace huggingface::tgi::backends { + +// struct GenerationContext; + + class TensorRtLlmBackendImpl : public TensorRtLlmBackend { + public: + /*** + * + * @param engineFolder + * @param executorWorker + */ + TensorRtLlmBackendImpl(const std::string_view &engineFolder, const std::string_view &executorWorker); + + /*** + * + * @return + */ + bool IsReady() const; + + /*** + * + * @param tokens + * @param topK + * @param topP + * @param temperature + * @param repetition_penalty + * @param frequency_penalty + * @param seed + * @return + */ + [[nodiscard("returned request id should be used to refer to the request's generation result later on")]] + uint64_t + Submit(rust::Slice tokens, int32_t topK, float_t topP, float_t temperature, + float_t repetition_penalty, float_t frequency_penalty, uint64_t seed); + + /*** + * + * @param requestId + * @param ctx + * @param callback + * @return + */ + size_t StreamTokens( + const RequestId requestId, + huggingface::tgi::backends::GenerationContext *ctx, + rust::Fn callback); + }; + + /*** + * + * @param engineFolder + * @return + */ + std::unique_ptr CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker); +} + +#endif //TGI_TRTLLM_BACKEND_FFI_H diff --git a/backends/trtllm/include/hardware.h b/backends/trtllm/include/hardware.h new file mode 100644 index 00000000..da0bf4f3 --- /dev/null +++ b/backends/trtllm/include/hardware.h @@ -0,0 +1,59 @@ +// +// Created by mfuntowicz on 7/23/24. 
+// + +#ifndef TGI_TRTLLM_BACKEND_HARDWARE_H +#define TGI_TRTLLM_BACKEND_HARDWARE_H + +#include +#include +#include +#include +#include + +namespace huggingface::hardware::cuda { + +#define AMPERE_SM_MAJOR 8 +#define HOPPER_SM_MAJOR 9 + + /** + * Store information about the version of the CUDA Compute Capabilities detected on the device + */ + struct CudaComputeCapabilities { + int32_t major; + int32_t minor; + + [[nodiscard]] constexpr bool isPostAmpere() const { return major >= AMPERE_SM_MAJOR; } + + [[nodiscard]] constexpr bool isPostHopper() const { return major >= HOPPER_SM_MAJOR; } + }; + + CudaComputeCapabilities GetCudaComputeCapabilities() { + // Get the compute capabilities of the current hardware + nvmlDevice_t device; + CudaComputeCapabilities capabilities{0, 0}; + if (nvmlDeviceGetHandleByIndex_v2(0, &device) == NVML_SUCCESS) { + SPDLOG_DEBUG("Successfully acquired nvmlDevice_t = 0"); + if (nvmlDeviceGetCudaComputeCapability(device, &capabilities.major, &capabilities.minor) == NVML_SUCCESS) { + SPDLOG_INFO("Detected sm_{:d}{:d} compute capabilities", capabilities.major, capabilities.minor); + } + } + + return capabilities; + } + + /** + * Return the number of GPU detected. 
If no GPU is detected, return std::nullopt + * @return + */ + std::optional GetNumDevices() { + uint32_t numGpus = 0; + if (nvmlDeviceGetCount_v2(&numGpus) == NVML_SUCCESS) { + return std::optional(numGpus); + } else { + return std::nullopt; + } + } +} + +#endif //TGI_TRTLLM_BACKEND_HARDWARE_H diff --git a/backends/trtllm/lib/backend.cpp b/backends/trtllm/lib/backend.cpp new file mode 100644 index 00000000..c066a6d6 --- /dev/null +++ b/backends/trtllm/lib/backend.cpp @@ -0,0 +1,146 @@ +#include + +#include +#include +#include + +#include "backend.h" +#include "hardware.h" + +void huggingface::tgi::backends::InitializeBackend() { + SPDLOG_INFO("Initializing Backend..."); + nvmlInit_v2(); + initTrtLlmPlugins(); + + const auto numGpus = huggingface::hardware::cuda::GetNumDevices(); + if (numGpus.has_value()) { + SPDLOG_INFO("Detected {:d} Nvidia GPU(s)", numGpus.value()); + } else { + SPDLOG_WARN("Failed to detect Nvidia GPU(s) on the system"); + } +} + +[[nodiscard]] +tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &config, const std::string &workerPath) { + tle::ExecutorConfig execConfig(1); + + // Retrieve the compute capabilities to enable some options at runtime + const auto computeCapabilities = huggingface::hardware::cuda::GetCudaComputeCapabilities(); + + // Single engine (TP = PP = 1) -> using leader mode (no MPI involved) + if (config["/pretrained_config/mapping/world_size"_json_pointer].get() == 1) { + SPDLOG_INFO("Detected single engine deployment, using leader mode"); + execConfig.setParallelConfig(tle::ParallelConfig( + tle::CommunicationType::kMPI, + tle::CommunicationMode::kLEADER, + std::nullopt, + std::nullopt, + std::nullopt + )); + } else { // Multiple engines -> using orchestrator mode (MPI involved) + SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode"); + execConfig.setParallelConfig(tle::ParallelConfig( + tle::CommunicationType::kMPI, + tle::CommunicationMode::kORCHESTRATOR, + std::nullopt, 
+ std::nullopt, + tle::OrchestratorConfig(true, workerPath, nullptr, true) + )); + } + + // Define some configuration variables + execConfig.setKvCacheConfig(tle::KvCacheConfig(true)); + execConfig.setEnableChunkedContext(computeCapabilities.isPostAmpere()); + return execConfig; +} + +tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig( + uint32_t topK, + float_t topP, + float_t temperature, + float_t repetition_penalty, + float_t frequency_penalty, + uint64_t seed) { + return tle::SamplingConfig( + 1, // TGI only use a single beam + topK, + topP, + std::nullopt, + std::nullopt, + std::nullopt, + seed, + temperature, + temperature, + std::nullopt, + repetition_penalty, + std::nullopt, + frequency_penalty + ); +} + +huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend( + const std::filesystem::path &enginesFolder, + const std::filesystem::path &executorWorker +) : + config(json::parse(std::ifstream(enginesFolder / "config.json"))), + executor( + enginesFolder, + tensorrt_llm::executor::ModelType::kDECODER_ONLY, + GetExecutorConfig(config, executorWorker.string() + )) { + SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref()); +} + +bool huggingface::tgi::backends::TensorRtLlmBackend::IsReady() const { + return executor.canEnqueueRequests(); +} + +[[nodiscard("Returned number of requests needs to be consumed")]] +size_t huggingface::tgi::backends::TensorRtLlmBackend::NumResponsesReady() const { + return executor.getNumResponsesReady(); +} + +[[nodiscard("Returned request id needs to be provided back to gather generated tokens")]] +tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit( + const std::vector &tokens, + const int32_t topK, + const float_t topP, + const float_t temperature, + const float_t repetition_penalty, + const float_t frequency_penalty, + const uint64_t seed +) { +#ifdef NDEBUG + SPDLOG_DEBUG( + FMT_STRING("Submitting inference over {:d} tokens to the executor ({:d} already 
in-flight)"), + tokens.size(), + executor.getLatestIterationStats().back().numActiveRequests + ); +#else + SPDLOG_DEBUG( + FMT_STRING("Submitting inference [{}] to the executor ({:d} already in-flight)"), + fmt::join(tokens, ", "), + executor.getLatestIterationStats().front().numActiveRequests + ); +#endif + + const auto maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get(); + const auto maxNewTokens = static_cast(std::max(1ul, maxNumTokens - tokens.size())); + + const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed); + const auto output = tle::OutputConfig(true, false, false, true, false); + return executor.enqueueRequest( + tle::Request{tokens, maxNewTokens, true, sampling, output}); +} + +[[nodiscard("Generated tokens result must be used")]] +std::vector huggingface::tgi::backends::TensorRtLlmBackend::Poll(const tle::IdType requestId) { + SPDLOG_DEBUG(FMT_STRING("Polling status for request {:d}"), requestId); + return executor.awaitResponses(requestId); +} + + +void huggingface::tgi::backends::TensorRtLlmBackend::Shutdown() { + SPDLOG_INFO("Shutting down executor"); + executor.shutdown(); +} diff --git a/backends/trtllm/scripts/install_tensorrt.sh b/backends/trtllm/scripts/install_tensorrt.sh new file mode 100755 index 00000000..e0e2dd17 --- /dev/null +++ b/backends/trtllm/scripts/install_tensorrt.sh @@ -0,0 +1,111 @@ +#!/bin/bash + +set -ex + +TRT_VER="10.2.0.19" +CUDA_VER="12.5" +CUDNN_VER="9.2.1.18-1" +NCCL_VER="2.22.3-1+cuda12.5" +CUBLAS_VER="12.5.3.2-1" +NVRTC_VER="12.5.82-1" + +for i in "$@"; do + case $i in + --TRT_VER=?*) TRT_VER="${i#*=}";; + --CUDA_VER=?*) CUDA_VER="${i#*=}";; + --CUDNN_VER=?*) CUDNN_VER="${i#*=}";; + --NCCL_VER=?*) NCCL_VER="${i#*=}";; + --CUBLAS_VER=?*) CUBLAS_VER="${i#*=}";; + *) ;; + esac + shift +done + +NVCC_VERSION_OUTPUT=$(nvcc --version) +if [[ $(echo $NVCC_VERSION_OUTPUT | grep -oP "\d+\.\d+" | head -n 1) != ${CUDA_VER} ]]; then + echo "The version of 
pre-installed CUDA is not equal to ${CUDA_VER}." + exit 1 +fi + +install_ubuntu_requirements() { + apt-get update && apt-get install -y --no-install-recommends gnupg2 curl ca-certificates + ARCH=$(uname -m) + if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi + if [ "$ARCH" = "aarch64" ];then ARCH="sbsa";fi + curl -fsSLO https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/${ARCH}/cuda-keyring_1.0-1_all.deb + dpkg -i cuda-keyring_1.0-1_all.deb + + apt-get update + if [[ $(apt list --installed | grep libcudnn9) ]]; then + apt-get remove --purge -y --allow-change-held-packages libcudnn9* + fi + if [[ $(apt list --installed | grep libnccl) ]]; then + apt-get remove --purge -y --allow-change-held-packages libnccl* + fi + if [[ $(apt list --installed | grep libcublas) ]]; then + apt-get remove --purge -y --allow-change-held-packages libcublas* + fi + if [[ $(apt list --installed | grep cuda-nvrtc-dev) ]]; then + apt-get remove --purge -y --allow-change-held-packages cuda-nvrtc-dev* + fi + CUBLAS_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g') + apt-get install -y --no-install-recommends libcudnn9-cuda-12=${CUDNN_VER} libcudnn9-dev-cuda-12=${CUDNN_VER} + apt-get install -y --no-install-recommends libnccl2=${NCCL_VER} libnccl-dev=${NCCL_VER} + apt-get install -y --no-install-recommends libcublas-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER} libcublas-dev-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER} + # NVRTC static library doesn't exist in NGC PyTorch container. 
+ NVRTC_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g') + apt-get install -y --no-install-recommends cuda-nvrtc-dev-${NVRTC_CUDA_VERSION}=${NVRTC_VER} + apt-get clean + rm -rf /var/lib/apt/lists/* +} + +install_centos_requirements() { + CUBLAS_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g') + yum -y update + yum -y install epel-release + yum remove -y libnccl* && yum -y install libnccl-${NCCL_VER} libnccl-devel-${NCCL_VER} + yum remove -y libcublas* && yum -y install libcublas-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER} libcublas-devel-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER} + yum clean all +} + +install_tensorrt() { + #PY_VERSION=$(python3 -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))') + #PARSED_PY_VERSION=$(echo "${PY_VERSION//./}") + TRT_CUDA_VERSION="12.5" + + if [ -z "$RELEASE_URL_TRT" ];then + ARCH=${TRT_TARGETARCH} + if [ -z "$ARCH" ];then ARCH=$(uname -m);fi + if [ "$ARCH" = "arm64" ];then ARCH="aarch64";fi + if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi + if [ "$ARCH" = "x86_64" ];then DIR_NAME="x64-agnostic"; else DIR_NAME=${ARCH};fi + if [ "$ARCH" = "aarch64" ];then OS1="Ubuntu22_04" && OS2="Ubuntu-22.04" && OS="ubuntu-22.04"; else OS1="Linux" && OS2="Linux" && OS="linux";fi + RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.2.0/tars/TensorRT-${TRT_VER}.${OS2}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz + fi + wget --no-verbose ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar + tar -xf /tmp/TensorRT.tar -C /usr/local/ + mv /usr/local/TensorRT-${TRT_VER} /usr/local/tensorrt + # pip3 install /usr/local/tensorrt/python/tensorrt-*-cp${PARSED_PY_VERSION}-*.whl + rm -rf /tmp/TensorRT.tar +} + +# Install base packages depending on the base OS +ID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"') +case "$ID" in + debian) + install_ubuntu_requirements + install_tensorrt + ;; + ubuntu) + install_ubuntu_requirements + install_tensorrt + ;; + centos) + install_centos_requirements + install_tensorrt + ;; + *) + echo 
"Unable to determine OS..." + exit 1 + ;; +esac diff --git a/backends/trtllm/src/backend.rs b/backends/trtllm/src/backend.rs new file mode 100644 index 00000000..b26d06a6 --- /dev/null +++ b/backends/trtllm/src/backend.rs @@ -0,0 +1,329 @@ +use std::future::Future; +use std::path::Path; +use std::pin::{pin, Pin}; +use std::str::FromStr; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, OnceLock}; +use std::task::{Context, Poll}; +use std::time::Duration; + +use async_trait::async_trait; +use cxx::UniquePtr; +use log::{error, warn}; +use tokenizers::Tokenizer; +use tokio::sync::mpsc::{unbounded_channel, UnboundedSender}; +use tokio::sync::RwLock; +use tokio::time::{sleep, Instant}; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tokio_stream::{Stream, StreamExt}; +use tracing::{instrument, span, Level}; + +use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; +use text_generation_router::validation::ValidationError::UnsupportedModality; +use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidationError}; +use text_generation_router::{FinishReason, Token}; + +use crate::errors::TensorRtLlmBackendError; +use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl}; + +// Value used to poll the state of the generation stream +static POLLING_INTERVAL_US: OnceLock = OnceLock::new(); + +type InferResult = Result; + +pub(crate) struct Generation { + executor: Arc>>, + done: Arc, +} + +/// Holds the user provided input to be executed along with a channel allowing +/// to bubble up all the generated tokens for that tokens the to end stream. 
+pub struct GenerationContext { + sender: UnboundedSender>, + tokenizer: Arc, + tokens: Vec, + done: Arc, + queued: Instant, + start: Option, +} + +impl Stream for Generation { + type Item = usize; + + fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll> { + let interval = POLLING_INTERVAL_US.get_or_init(|| { + u64::from_str(option_env!("TRTLLM_BACKEND_POLLING_INTERVAL_US").unwrap_or("100")) + .expect("Invalid value provided for envvar POLLING_INTERVAL_US") + }); + + if !self.done.load(Ordering::Relaxed) { + let backend = pin!(self.executor.read()); + let status = match backend.poll(ctx) { + Poll::Ready(executor_r) => { + let ready = executor_r.num_responses_ready(); + if ready == 0 { + Poll::Pending + } else { + Poll::Ready(Some(ready)) + } + } + Poll::Pending => Poll::Pending, + }; + + let waker = ctx.waker().clone(); + tokio::spawn(async { + sleep(Duration::from_micros(*interval)).await; + waker.wake(); + }); + + status + } else { + Poll::Ready(None) // end of stream + } + } + + fn size_hint(&self) -> (usize, Option) { + (1, None) + } +} + +unsafe impl Send for TensorRtLlmBackendImpl {} +unsafe impl Sync for TensorRtLlmBackendImpl {} + +/// Implements the logic to execute generation with TensorRT-LLM executor API in background +pub struct TensorRtLlmBackend { + tokenizer: Arc, + + // Backing the backend behind a RwLock to allow concurrent read access to retrieve + // the number of available tokens (read only) in the Generation stream + backend: Arc>>, +} + +impl TensorRtLlmBackend { + pub fn new + Send + 'static, PP: AsRef + Send + 'static>( + tokenizer: Tokenizer, + engine_folder: P, + executor_worker_path: PP, + ) -> Result { + Ok(TensorRtLlmBackend { + tokenizer: Arc::new(tokenizer), + backend: Arc::new(RwLock::new(create_tensorrt_llm_backend( + engine_folder.as_ref().to_str().unwrap(), + executor_worker_path.as_ref().to_str().unwrap(), + ))), + }) + } + + fn validate(request: &ValidGenerateRequest) -> InferResult<&String> { + if 
request.top_n_tokens > 1 { + return Err(InferError::ValidationError( + ValidationError::TopNTokensDisabled, + )); + } + + // TODO: Is it really needed? How can it be validated before? + if request.parameters.grammar.is_some() { + return Err(InferError::ValidationError(ValidationError::Grammar)); + } + + match request.inputs.len() { + 0 => Err(InferError::ValidationError(ValidationError::EmptyInput)), + 2.. => Err(InferError::GenerationError( + "TensorRT-LLM backend don't support multi-chunk".into(), + )), + 1 => match request.inputs.first().expect("Single item-chunk") { + Chunk::Text(text) => Ok(text), + Chunk::Image(_) => Err(InferError::ValidationError(UnsupportedModality("image"))), + }, + } + } + + fn generate( + &self, + sender: UnboundedSender>, + tokens: Vec, + top_k: u32, + top_p: f32, + temperature: f32, + repetition_penalty: f32, + frequency_penalty: f32, + seed: u64, + ) { + let tokenizer = Arc::clone(&self.tokenizer); + let executor = Arc::clone(&self.backend); + + // Let's push this in async context + tokio::spawn(async move { + // Define the generation state + let mut generation = Generation { + executor: executor.clone(), + done: Arc::new(AtomicBool::new(false)), + }; + + // Define the context over the generation + // TODO(asap): Do we really need so many shared-ownership? + let ctx = Box::new(GenerationContext { + sender: sender.clone(), + tokenizer, + tokens: vec![], + done: Arc::clone(&generation.done), + start: None, + queued: Instant::now(), + }); + + // We are leaking the context on-purpose to avoid the box being dropped while there are + // still computation ongoing + // TODO(asap): Can we achieve the same with an Arc> without the need to go unsafe? 
+ let ctx_ = Box::leak(ctx); + + // Submit the request to the batcher + let request_id = span!(Level::DEBUG, "submit") + .in_scope(|| async { + let mut handle = executor.write().await; + let request_id = handle.pin_mut().submit( + &tokens, + top_k as i32, + top_p, + temperature, + repetition_penalty, + frequency_penalty, + seed, + ); + + request_id + }) + .await; + + while let Some(_) = generation.next().await { + let mut executor_w = executor.write().await; + let executor = executor_w.pin_mut(); + + span!(Level::DEBUG, "decode") + .in_scope(|| async { + unsafe { + executor.stream_tokens( + request_id, + ctx_, + |ctx: *mut GenerationContext, step: GenerationStep| { + let inner_ctx = &mut *ctx; + + // Update the timestamp at which the request started effectively + // Can be a bit off, would need to be before the callback, let's see + inner_ctx.start.get_or_insert(Instant::now()); + inner_ctx.done.store(step.is_final, Ordering::Relaxed); + + // Ensure we are not running into errors + let parcel = if !step.has_error { + // Insert the latest generated token to the tracker + inner_ctx.tokens.push(step.token_id); + + // Decode the token + let text = inner_ctx + .tokenizer + .decode(&[step.token_id], true) + .expect("Failed to decode token"); + + let special = inner_ctx + .tokenizer + .get_added_vocabulary() + .is_special_token(&text); + + // Create the structure holding the token + let token = Token { + id: step.token_id, + text, + logprob: step.log_prob, + special, + }; + + if step.is_final { + let generated_text = inner_ctx + .tokenizer + .decode(&inner_ctx.tokens, true) + .expect("Failed to decode generated_tokens"); + + Ok(InferStreamResponse::End { + token, + top_tokens: vec![], + generated_text: GeneratedText { + text: generated_text, + generated_tokens: inner_ctx.tokens.len() as u32, + finish_reason: FinishReason::EndOfSequenceToken, + seed: None, + }, + start: inner_ctx.start.unwrap_or(Instant::now()), + queued: inner_ctx.queued, + }) + } else { + 
Ok(InferStreamResponse::Intermediate { + token, + top_tokens: vec![], + }) + } + } else { + error!("Error caught while decoding: {}", &step.error_msg); + Err(InferError::GenerationError(step.error_msg)) + }; + + // Send the parcel to the client + inner_ctx + .sender + .send(parcel) + .expect("Failed to sent msg through the channel"); + }, + ); + } + }) + .await; + } + + // "Properly" free the shared context... + // TODO: clean that piece of sh** asap + unsafe { + let _ = Box::from_raw(ctx_); + } + }); + } +} + +#[async_trait] +impl Backend for TensorRtLlmBackend { + #[instrument(skip_all)] + fn schedule( + &self, + request: ValidGenerateRequest, + ) -> InferResult>> { + // Let's add a few more validation + let input = TensorRtLlmBackend::validate(&request)?; + + // Channel to stream the generated token as they come from the worker thread back to the transport layer + let (sender, receiver) = unbounded_channel(); + + // Unpack parameters + let params = &request.parameters; + + // Preprocess the inputs to send to TRTLLM backend + let encoding = self + .tokenizer + .encode(input.as_str(), true) + .map_err(|e| InferError::GenerationError(e.to_string()))?; + + // Generate the response + self.generate( + sender, + Vec::from(encoding.get_ids()), + params.top_k, + params.top_p, + params.temperature, + params.repetition_penalty, + params.frequency_penalty, + params.seed, + ); + + Ok(UnboundedReceiverStream::new(receiver)) + } + + async fn health(&self, _current_health: bool) -> bool { + true + } +} diff --git a/backends/trtllm/src/errors.rs b/backends/trtllm/src/errors.rs new file mode 100644 index 00000000..a672d2a4 --- /dev/null +++ b/backends/trtllm/src/errors.rs @@ -0,0 +1,15 @@ +use thiserror::Error; + +use text_generation_router::server; + +#[derive(Debug, Error)] +pub enum TensorRtLlmBackendError { + #[error("Tokenizer error: {0}")] + Tokenizer(String), + #[error("Argument validation error: {0}")] + ArgumentValidation(String), + #[error("WebServer error: {0}")] + 
WebServer(#[from] server::WebServerError), + #[error("Tokio runtime failed to start: {0}")] + Tokio(#[from] std::io::Error), +} diff --git a/backends/trtllm/src/ffi.cpp b/backends/trtllm/src/ffi.cpp new file mode 100644 index 00000000..d6317a68 --- /dev/null +++ b/backends/trtllm/src/ffi.cpp @@ -0,0 +1,84 @@ +// +// Created by mfuntowicz on 6/30/24. +// +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include "backends/trtllm/include/ffi.h" + + +huggingface::tgi::backends::TensorRtLlmBackendImpl::TensorRtLlmBackendImpl( + const std::string_view &engineFolder, + const std::string_view &executorWorker +) : TensorRtLlmBackend(engineFolder, executorWorker) {} + + +bool huggingface::tgi::backends::TensorRtLlmBackendImpl::IsReady() const { + return TensorRtLlmBackend::IsReady(); +} + +uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit( + rust::Slice tokens, int32_t topK, float_t topP, float_t temperature, float_t repetition_penalty, + float_t frequency_penalty, uint64_t seed) { + + // This will copy all the items from the initial slice + std::vector tokens_(std::make_move_iterator(tokens.begin()), std::make_move_iterator(tokens.end())); + return TensorRtLlmBackend::Submit( + std::move(tokens_), topK, topP, temperature, repetition_penalty, frequency_penalty, seed); +} + +size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens( + const uint64_t requestId, + huggingface::tgi::backends::GenerationContext *ctx, + rust::Fn callback) { + + size_t numTokens = 0; + for (const auto &item: Poll(requestId)) { + GenerationStep step; + if (!item.hasError()) { + SPDLOG_DEBUG("\tStreamTokens -> Decoding token..."); + const auto decoded = item.getResult(); + + const auto token = decoded.outputTokenIds[0][0]; + const auto isFinal = decoded.isFinal; + const auto logProb = decoded.logProbs.value()[0][0]; + + ++numTokens; + + SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, 
isFinal); + step = huggingface::tgi::backends::GenerationStep{ + static_cast(token), logProb, isFinal, false, std::move(std::string()) + }; + SPDLOG_DEBUG("\tStreamTokens -> Post callback"); + } else { + // TODO : Return rest::Result with error + const auto what = item.getErrorMsg(); + SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", what); + step = huggingface::tgi::backends::GenerationStep{ + std::numeric_limits::max(), 0.0, true, true, std::move(what) + }; + } + + callback(std::move(ctx), std::move(step)); + } + + return numTokens; +} + +std::unique_ptr +huggingface::tgi::backends::CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker) { + // Unconditionally call this to initialize and discover TRTLLM plugins + InitializeBackend(); + + const auto enginePath = std::string_view(engineFolder.begin(), engineFolder.end()); + const auto executorPath = std::string_view(executorWorker.begin(), executorWorker.end()); + return std::make_unique(std::move(enginePath), std::move(executorPath)); +} diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs new file mode 100644 index 00000000..1a804f88 --- /dev/null +++ b/backends/trtllm/src/lib.rs @@ -0,0 +1,78 @@ +pub use backend::{GenerationContext, TensorRtLlmBackend}; + +mod backend; +pub mod errors; + +#[cxx::bridge(namespace = "huggingface::tgi::backends")] +mod ffi { + + /// Struct used as shared type between rust and C++ to represent the result + /// of a single decoding iteration + pub struct GenerationStep { + token_id: u32, + log_prob: f32, + is_final: bool, + has_error: bool, + error_msg: String, + } + + extern "Rust" { + type GenerationContext; + } + + unsafe extern "C++" { + include!("backends/trtllm/src/ffi.cpp"); + + /// Represent an instance of the underlying TensorRT-LLM backend + type TensorRtLlmBackendImpl; + + /// Create an instance backed behind a std::unique_ptr to manage the lifespan of the backend + /// + /// # Arguments + /// + /// * `engine_folder`: Path to 
the folder containing all the TRTLLM engines + /// * `executor_worker`: Path to the TRTLLM executor worker + /// + /// returns: + /// + /// # Examples + /// + /// ``` + /// + /// ``` + #[rust_name = "create_tensorrt_llm_backend"] + fn CreateTensorRtLlmBackend( + engine_folder: &str, + executor_worker: &str, + ) -> UniquePtr; + + // #[rust_name = "is_ready"] + // fn IsReady(self: &TensorRtLlmBackendImpl) -> bool; + + #[rust_name = "num_responses_ready"] + fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize; + + #[rust_name = "submit"] + fn Submit( + self: Pin<&mut TensorRtLlmBackendImpl>, + tokens: &[u32], + top_k: i32, + top_p: f32, + temperature: f32, + repetition_penalty: f32, + frequency_penalty: f32, + seed: u64, + ) -> u64; + + #[rust_name = "stream_tokens"] + unsafe fn StreamTokens( + self: Pin<&mut TensorRtLlmBackendImpl>, + request_id: u64, + ctx: *mut GenerationContext, + cb: unsafe fn(*mut GenerationContext, GenerationStep), + ) -> usize; + + // #[rust_name = "shutdown"] + // fn Shutdown(self: Pin<&mut TensorRtLlmBackendImpl>); + } +} diff --git a/backends/trtllm/src/main.rs b/backends/trtllm/src/main.rs new file mode 100644 index 00000000..6d6ee146 --- /dev/null +++ b/backends/trtllm/src/main.rs @@ -0,0 +1,166 @@ +use std::collections::HashMap; +use std::path::PathBuf; + +use clap::Parser; +use tokenizers::{FromPretrainedParameters, Tokenizer}; + +use text_generation_backends_trtllm::errors::TensorRtLlmBackendError; +use text_generation_backends_trtllm::TensorRtLlmBackend; +use text_generation_router::server; + +/// App Configuration +#[derive(Parser, Debug)] +#[clap(author, version, about, long_about = None)] +struct Args { + #[clap(default_value = "128", long, env)] + max_concurrent_requests: usize, + #[clap(default_value = "2", long, env)] + max_best_of: usize, + #[clap(default_value = "4", long, env)] + max_stop_sequences: usize, + #[clap(default_value = "5", long, env)] + max_top_n_tokens: u32, + #[clap(default_value = "1024", long, env)] + 
max_input_tokens: usize, + #[clap(default_value = "2048", long, env)] + max_total_tokens: usize, + #[clap(default_value = "4096", long, env)] + max_batch_prefill_tokens: u32, + #[clap(long, env)] + max_batch_total_tokens: Option, + #[clap(default_value = "0.0.0.0", long, env)] + hostname: String, + #[clap(default_value = "3000", long, short, env)] + port: u16, + #[clap(long, env, required = true)] + tokenizer_name: String, + #[clap(long, env)] + tokenizer_config_path: Option, + #[clap(long, env)] + revision: Option, + #[clap(long, env)] + model_id: String, + #[clap(default_value = "2", long, env)] + validation_workers: usize, + #[clap(long, env)] + json_output: bool, + #[clap(long, env)] + otlp_endpoint: Option, + #[clap(default_value = "text-generation-inference.router", long, env)] + otlp_service_name: String, + #[clap(long, env)] + cors_allow_origin: Option>, + #[clap(long, env, default_value_t = false)] + messages_api_enabled: bool, + #[clap(default_value = "4", long, env)] + max_client_batch_size: usize, + #[clap(long, env)] + auth_token: Option, + #[clap(long, env, help = "Path to the TensorRT-LLM Orchestrator worker")] + executor_worker: PathBuf, +} + +#[tokio::main] +async fn main() -> Result<(), TensorRtLlmBackendError> { + // Get args + let args = Args::parse(); + // Pattern match configuration + let Args { + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + max_batch_prefill_tokens, + max_batch_total_tokens, + hostname, + port, + tokenizer_name, + tokenizer_config_path, + revision, + model_id, + validation_workers, + json_output, + otlp_endpoint, + otlp_service_name, + cors_allow_origin, + messages_api_enabled, + max_client_batch_size, + auth_token, + executor_worker, + } = args; + + // Launch Tokio runtime + text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output); + + // Validate args + if max_input_tokens >= max_total_tokens { + return 
Err(TensorRtLlmBackendError::ArgumentValidation( + "`max_input_tokens` must be < `max_total_tokens`".to_string(), + )); + } + if max_input_tokens as u32 > max_batch_prefill_tokens { + return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}"))); + } + + if validation_workers == 0 { + return Err(TensorRtLlmBackendError::ArgumentValidation( + "`validation_workers` must be > 0".to_string(), + )); + } + + if let Some(ref max_batch_total_tokens) = max_batch_total_tokens { + if max_batch_prefill_tokens > *max_batch_total_tokens { + return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); + } + if max_total_tokens as u32 > *max_batch_total_tokens { + return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. 
Given: {max_total_tokens} and {max_batch_total_tokens}"))); + } + } + + if !executor_worker.exists() { + return Err(TensorRtLlmBackendError::ArgumentValidation(format!( + "`executor_work` specified path doesn't exists: {}", + executor_worker.display() + ))); + } + + // Run server + let tokenizer = Tokenizer::from_pretrained( + tokenizer_name.clone(), + Some(FromPretrainedParameters { + revision: revision.clone().unwrap_or(String::from("main")), + user_agent: HashMap::new(), + auth_token, + }), + ) + .map_err(|e| TensorRtLlmBackendError::Tokenizer(e.to_string()))?; + + let backend = TensorRtLlmBackend::new(tokenizer, model_id, executor_worker)?; + server::run( + backend, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + validation_workers, + None, + tokenizer_name, + tokenizer_config_path, + revision, + hostname, + port, + cors_allow_origin, + false, + None, + None, + messages_api_enabled, + true, + max_client_batch_size, + ) + .await?; + Ok(()) +} diff --git a/backends/trtllm/tests/infer_test.cpp b/backends/trtllm/tests/infer_test.cpp new file mode 100644 index 00000000..8520065a --- /dev/null +++ b/backends/trtllm/tests/infer_test.cpp @@ -0,0 +1,14 @@ +// +// Created by mfuntowicz on 7/2/24. 
+// +#include +#include +#include "../include/backend.h" + +TEST_CASE("Load TRTLLM Engine on the TGI Backend", "[trtllm][engine][load]") { + const auto engines = std::filesystem::path("/home/mfuntowicz/.cache/huggingface/assets/trtllm/0.11.0.dev2024062500/meta-llama--Meta-Llama-3-8B-Instruct/4090/engines/"); + const auto executor = std::filesystem::path("/home/mfuntowicz/Workspace/text-generation-inference/backends/trtllm/cmake-build-debug/cmake-build-debug/_deps/trtllm-src/cpp/tensorrt_llm/executor_worker/executorWorker"); + + spdlog::info("Loading config from: {}", absolute(engines).string()); + huggingface::tgi::backends::TensorRtLlmBackend backend(engines, executor); +} diff --git a/backends/v3/Cargo.toml b/backends/v3/Cargo.toml new file mode 100644 index 00000000..5d9a140b --- /dev/null +++ b/backends/v3/Cargo.toml @@ -0,0 +1,66 @@ +[package] +name = "text-generation-router-v3" +description = "Text Generation Webserver" +version.workspace = true +edition.workspace = true +authors.workspace = true +homepage.workspace = true + +[lib] +path = "src/lib.rs" + +[[bin]] +name = "text-generation-router" +path = "src/main.rs" + +[dependencies] +async-trait = "0.1.74" +async-stream = "0.3.5" +axum = { version = "0.7", features = ["json"] } +axum-tracing-opentelemetry = "0.16" +text-generation-router = { path = "../../router" } +clap = { version = "4.4.5", features = ["derive", "env"] } +grpc-metadata = { path = "../grpc-metadata" } +futures = "0.3.28" +hf-hub = { workspace = true } +jsonschema = { version = "0.17.1", features = ["draft202012"] } +metrics = { workspace = true } +metrics-exporter-prometheus = { workspace = true } +nohash-hasher = "0.2.0" +opentelemetry = { version = "0.20.0", features = ["rt-tokio"] } +opentelemetry-otlp = "0.13.0" +rand = "0.8.5" +reqwest = { version = "0.11.20", features = [] } +serde = "1.0.188" +serde_json = "1.0.107" +thiserror = "1.0.48" +tokenizers = { workspace = true} +tokio = { version = "1.32.0", features = ["rt", 
"rt-multi-thread", "parking_lot", "signal", "sync"] } +tokio-stream = "0.1.14" +tower-http = { version = "0.5.1", features = ["cors"] } +tracing = "0.1.37" +tracing-opentelemetry = "0.21.0" +tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] } +utoipa = { version = "4.2.0", features = ["axum_extras"] } +utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } +init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] } +minijinja = { version = "2.0.2" } +minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } +futures-util = "0.3.30" +regex = "1.10.3" +once_cell = "1.19.0" +image = "0.25.1" +base64 = { workspace = true } +prost = "^0.12" +tonic = "^0.10" +tower = "^0.4" + +[build-dependencies] +tonic-build = "0.10.1" +prost-build = "0.12.1" + +[features] +default = ["ngrok"] +ngrok = ["text-generation-router/ngrok"] +google = ["text-generation-router/google"] +kserve = ["text-generation-router/kserve"] diff --git a/backends/v3/build.rs b/backends/v3/build.rs new file mode 100644 index 00000000..6d702d14 --- /dev/null +++ b/backends/v3/build.rs @@ -0,0 +1,19 @@ +use std::fs; + +fn main() -> Result<(), Box> { + println!("cargo:rerun-if-changed=../../proto/"); + + fs::create_dir_all("src/client/pb").unwrap_or(()); + let mut config = prost_build::Config::new(); + config.protoc_arg("--experimental_allow_proto3_optional"); + + tonic_build::configure() + .build_client(true) + .build_server(false) + .out_dir("src/client/pb") + .include_file("mod.rs") + .compile_with_config(config, &["../../proto/v3/generate.proto"], &["../../proto"]) + .unwrap_or_else(|e| panic!("protobuf compilation failed: {e}")); + + Ok(()) +} diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs new file mode 100644 index 00000000..49e2bc8f --- /dev/null +++ b/backends/v3/src/backend.rs @@ -0,0 +1,501 @@ +use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient}; +/// Batching and 
inference logic +use crate::queue::{Entry, Queue}; +use async_trait::async_trait; +use nohash_hasher::IntMap; +use std::sync::Arc; +use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; +use text_generation_router::validation::ValidGenerateRequest; +use text_generation_router::{FinishReason, PrefillToken, Token}; +use tokio::sync::mpsc::error::SendError; +use tokio::sync::{mpsc, Notify}; +use tokio::time::Instant; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tracing::{info_span, instrument, Instrument, Span}; + +pub struct BackendV3 { + /// Request queue + queue: Queue, + /// Notify batcher on queue appends + batching_task_notifier: Arc, + /// Client clone, used for health checks to skip the queue + client: ShardedClient, +} + +impl BackendV3 { + #[allow(clippy::too_many_arguments)] + pub(crate) fn new( + client: ShardedClient, + waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, + max_batch_total_tokens: u32, + max_waiting_tokens: usize, + max_batch_size: Option, + requires_padding: bool, + window_size: Option, + speculate: u32, + ) -> Self { + let queue = Queue::new( + requires_padding, + 16, + window_size, + speculate, + max_batch_total_tokens, + ); + let batching_task_notifier = Arc::new(Notify::new()); + + // Spawn batching background task that contains all the inference logic + tokio::spawn(batching_task( + client.clone(), + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + queue.clone(), + batching_task_notifier.clone(), + )); + + Self { + queue, + batching_task_notifier, + client, + } + } +} + +#[async_trait] +impl Backend for BackendV3 { + #[instrument(skip_all)] + fn schedule( + &self, + request: ValidGenerateRequest, + ) -> Result>, InferError> { + // MPSC channel to communicate with the background batching task + let (response_tx, response_rx) = mpsc::unbounded_channel(); + + // Append the request to the queue + 
self.queue.append(Entry { + request, + response_tx, + span: Span::current(), + temp_span: None, + queue_time: Instant::now(), + batch_time: None, + block_allocation: None, + }); + + // Notify the background task that we have a new entry in the queue that needs + // to be batched + self.batching_task_notifier.notify_one(); + + // Return stream + Ok(UnboundedReceiverStream::new(response_rx)) + } + + async fn health(&self, current_health: bool) -> bool { + if current_health { + // Generation is healthy, we only check that the shards can allocate on device + self.client.device_health().await + } else { + self.client.model_health().await + } + .is_ok() + } +} + +/// Batching logic +/// Will be launched in a background Tokio task +/// +/// Batches requests and sends them to the inference server +#[allow(clippy::too_many_arguments)] +pub(crate) async fn batching_task( + mut client: ShardedClient, + waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, + max_batch_total_tokens: u32, + max_waiting_tokens: usize, + max_batch_size: Option, + queue: Queue, + notifier: Arc, +) { + // Infinite loop + loop { + // Wait for a notification from the Infer struct + notifier.notified().await; + + // Get the next batch from the queue + // This batch might be smaller than the maximum batch size if there are not enough requests + // waiting in the queue + while let Some((mut entries, batch, span)) = queue + .next_batch( + None, + max_batch_size, + max_batch_prefill_tokens, + max_batch_total_tokens, + ) + .await + { + let mut cached_batch = prefill(&mut client, batch, &mut entries) + .instrument(span) + .await; + let mut waiting_tokens = 1; + + // We loop until we do not receive any cached batch from the inference server (== until + // all requests have met their stopping criteria) + while let Some(batch) = cached_batch { + // Get current batch info + let batch_size = batch.size; + let batch_max_tokens = batch.max_tokens; + let mut batches = vec![batch]; + 
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64); + metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64); + + let min_size = if waiting_tokens >= max_waiting_tokens { + // If we didn't onboard any new requests since >= max_waiting_tokens, we try + // to add a new batch even though its size might be small + None + } else { + // Minimum batch size + Some((batch_size as f32 * waiting_served_ratio).floor() as usize) + }; + + let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); + let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize); + + // Try to get a new batch + if let Some((mut new_entries, new_batch, span)) = queue + .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget) + .await + { + // Tracking metrics + if min_size.is_some() { + metrics::counter!("tgi_batch_concat", "reason" => "backpressure") + .increment(1); + } else { + metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded") + .increment(1); + } + + entries.iter_mut().for_each(|(_, entry)| { + // Create a new span to add the info that this entry is waiting + // because a new batch is being computed + let entry_waiting_span = info_span!(parent: &entry.span, "waiting"); + // Add relationships + span.follows_from(&entry_waiting_span); + entry_waiting_span.follows_from(&span); + // Update entry + entry.temp_span = Some(entry_waiting_span); + }); + + // Generate one token for this new batch to have the attention past in cache + let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries) + .instrument(span) + .await; + // Reset waiting counter + waiting_tokens = 1; + // Extend current batch with the new batch + if let Some(new_cached_batch) = new_cached_batch { + entries.extend(new_entries); + batches.push(new_cached_batch); + } + } + + // Create span for this batch to add context to inference calls + let next_batch_size = entries.len(); + let next_batch_span = + info_span!(parent: None, 
"batch", batch_size = next_batch_size); + entries.iter_mut().for_each(|(_, entry)| { + // Create a new span to link the batch back to this entry + let entry_batch_span = info_span!(parent: &entry.span, "infer"); + // Add relationships + next_batch_span.follows_from(&entry_batch_span); + entry_batch_span.follows_from(&next_batch_span); + // Update entry + entry.temp_span = Some(entry_batch_span); + }); + + cached_batch = decode(&mut client, batches, &mut entries) + .instrument(next_batch_span) + .await; + waiting_tokens += 1; + } + metrics::gauge!("tgi_batch_current_size").set(0.0); + metrics::gauge!("tgi_batch_current_max_tokens").set(0.0); + } + } +} + +#[instrument(skip_all)] +async fn prefill( + client: &mut ShardedClient, + batch: Batch, + entries: &mut IntMap, +) -> Option { + let start_time = Instant::now(); + let batch_id = batch.id; + metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1); + + match client.prefill(batch).await { + Ok((generations, next_batch, timings)) => { + let start_filtering_time = Instant::now(); + // Send generated tokens and filter stopped entries + filter_send_generations(generations, entries); + + // Filter next batch and remove requests that were stopped + let next_batch = filter_batch(client, next_batch, entries).await; + + metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill") + .record(timings.forward.as_secs_f64()); + metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill") + .record(timings.decode.as_secs_f64()); + metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill") + .record(start_filtering_time.elapsed().as_secs_f64()); + metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill") + .record(start_time.elapsed().as_secs_f64()); + metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1); + next_batch + } + // If we have an error, we discard the whole batch + Err(err) => { + let _ = 
client.clear_cache(Some(batch_id)).await; + send_errors(err, entries); + metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1); + None + } + } +} + +#[instrument(skip_all)] +async fn decode( + client: &mut ShardedClient, + batches: Vec, + entries: &mut IntMap, +) -> Option { + let start_time = Instant::now(); + let batch_ids: Vec = batches.iter().map(|b| b.id).collect(); + metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1); + + match client.decode(batches).await { + Ok((generations, next_batch, timings)) => { + let start_filtering_time = Instant::now(); + // Send generated tokens and filter stopped entries + filter_send_generations(generations, entries); + + // Filter next batch and remove requests that were stopped + let next_batch = filter_batch(client, next_batch, entries).await; + + if let Some(concat_duration) = timings.concat { + metrics::histogram!("tgi_batch_concat_duration", "method" => "decode") + .record(concat_duration.as_secs_f64()); + } + metrics::histogram!("tgi_batch_forward_duration", "method" => "decode") + .record(timings.forward.as_secs_f64()); + metrics::histogram!("tgi_batch_decode_duration", "method" => "decode") + .record(timings.decode.as_secs_f64()); + metrics::histogram!("tgi_batch_filter_duration", "method" => "decode") + .record(start_filtering_time.elapsed().as_secs_f64()); + metrics::histogram!("tgi_batch_inference_duration", "method" => "decode") + .record(start_time.elapsed().as_secs_f64()); + metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1); + next_batch + } + // If we have an error, we discard the whole batch + Err(err) => { + for id in batch_ids { + let _ = client.clear_cache(Some(id)).await; + } + send_errors(err, entries); + metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1); + None + } + } +} + +/// Filter a `batch` and remove all requests not present in `entries` +#[instrument(skip_all)] +async fn 
filter_batch( + client: &mut ShardedClient, + next_batch: Option, + entries: &IntMap, +) -> Option { + let mut batch = next_batch?; + + // No need to filter + if batch.size as usize == entries.len() { + return Some(batch); + } + + let id = batch.id; + + // Retain only requests that are still in entries + batch.request_ids.retain(|id| entries.contains_key(id)); + + if batch.request_ids.is_empty() { + // All requests have been filtered out + // Next batch is now empty + // Clear it from the Python shards cache + // We unwrap here as we need to panic since we cannot recover if this method fails + client.clear_cache(Some(id)).await.unwrap(); + None + } else { + // Filter Python shard cache + // We unwrap here as we need to panic since we cannot recover if this method fails + client.filter_batch(id, batch.request_ids).await.unwrap() + } +} + +/// Send one or multiple `InferStreamResponse` to Infer for all `entries` +/// and filter entries +#[instrument(skip_all)] +fn filter_send_generations(generations: Vec, entries: &mut IntMap) { + generations.into_iter().for_each(|generation| { + let id = generation.request_id; + // Get entry + // We can `expect` here as the request id should always be in the entries + let entry = entries + .get(&id) + .expect("ID not found in entries. This is a bug."); + + // Create and enter a span to link this function back to the entry + let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. 
This is a bug."), "send_generation", generation = ?generation).entered(); + // Send generation responses back to the infer task + // If the receive an error from the Flume channel, it means that the client dropped the + // request and we need to stop generating hence why we unwrap_or(true) + let stopped = send_responses(generation, entry).map_err(|err| { + tracing::error!("Entry response channel error."); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); + err + }).unwrap_or(true); + if stopped { + entries.remove(&id).expect("ID not found in entries. This is a bug."); + } + }); +} + +/// Send responses through the `entry` response channel +fn send_responses( + generation: Generation, + entry: &Entry, +) -> Result>>> { + // Return directly if the channel is disconnected + if entry.response_tx.is_closed() { + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); + return Ok(true); + } + + let mut stopped = false; + + if let Some(prefill_tokens) = generation.prefill_tokens { + // Create Token objects + // We do that here instead of in the Python code as Rust for loops are faster + let prefill_tokens = prefill_tokens + .ids + .into_iter() + .zip(prefill_tokens.logprobs) + .zip(prefill_tokens.texts) + .map(|((id, logprob), text)| PrefillToken { id, text, logprob }) + .collect(); + + // Send message + entry + .response_tx + .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?; + } + + // Create last Token + let tokens_ = generation.tokens.expect("Non empty tokens in generation"); + let n = tokens_.ids.len(); + metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64); + let mut iterator = tokens_ + .ids + .into_iter() + .zip(tokens_.logprobs) + .zip(tokens_.texts) + .zip(tokens_.is_special) + .enumerate() + .peekable(); + while let Some((i, (((id, logprob), text), special))) = iterator.next() { + let token = Token { + id, + text, + logprob, + special, + }; + let top_tokens = if let Some(top_tokens_) = 
generation.top_tokens.get(i) { + top_tokens_ + .ids + .iter() + .zip(top_tokens_.logprobs.iter()) + .zip(top_tokens_.texts.iter()) + .zip(top_tokens_.is_special.iter()) + .map(|(((&id, &logprob), text), &special)| Token { + id, + text: text.to_string(), + logprob, + special, + }) + .collect() + } else { + vec![] + }; + match (&generation.generated_text, iterator.peek()) { + (Some(generated_text), None) => { + // Generation has ended + stopped = true; + // Send message + entry.response_tx.send(Ok(InferStreamResponse::End { + token, + top_tokens, + generated_text: GeneratedText::from(generated_text.clone()), + queued: entry.queue_time, + start: entry.batch_time.unwrap(), + }))?; + } + _ => { + // Send message + entry + .response_tx + .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?; + } + } + } + + Ok(stopped) +} + +/// Send errors to Infer for all `entries` +#[instrument(skip_all)] +fn send_errors(error: ClientError, entries: &mut IntMap) { + entries.drain().for_each(|(_, entry)| { + // Create and enter a span to link this function back to the entry + let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered(); + let err = InferError::GenerationError(error.to_string()); + metrics::counter!("tgi_request_failure", "err" => "generation").increment(1); + tracing::error!("{err}"); + + // unwrap_or is valid here as we don't care if the receiver is gone. 
+ entry + .response_tx + .send(Err(err)) + .unwrap_or(()); + }); +} + +impl From for GeneratedText { + fn from(value: crate::client::GeneratedText) -> Self { + let v3_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap(); + let finish_reason = match v3_finish_reason { + crate::client::FinishReason::Length => FinishReason::Length, + crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken, + crate::client::FinishReason::StopSequence => FinishReason::StopSequence, + }; + + Self { + text: value.text, + generated_tokens: value.generated_tokens, + finish_reason, + seed: value.seed, + } + } +} diff --git a/router/src/infer/v3/block_allocator.rs b/backends/v3/src/block_allocator.rs similarity index 100% rename from router/src/infer/v3/block_allocator.rs rename to backends/v3/src/block_allocator.rs diff --git a/backends/v3/src/client/grpc_client.rs b/backends/v3/src/client/grpc_client.rs new file mode 100644 index 00000000..c407687b --- /dev/null +++ b/backends/v3/src/client/grpc_client.rs @@ -0,0 +1,284 @@ +/// Single shard Client +use crate::client::{pb, Chunk}; +use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64}; +use base64::engine::general_purpose::STANDARD; +use base64::Engine; +use grpc_metadata::InjectTelemetryContext; +use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient; +use pb::generate::v3::*; +use std::cmp::min; +use std::time::Duration; +use tonic::transport::{Channel, Uri}; +use tracing::instrument; + +/// Text Generation Inference gRPC client +#[derive(Debug, Clone)] +pub struct Client { + stub: TextGenerationServiceClient, +} + +impl Client { + /// Returns a client connected to the given url + #[allow(dead_code)] + pub async fn connect(uri: Uri) -> Result { + let channel = Channel::builder(uri).connect().await?; + + Ok(Self { + stub: TextGenerationServiceClient::new(channel), + }) + } + + /// Returns a client connected to the given unix socket + pub async fn 
connect_uds(path: String) -> Result { + let channel = Channel::from_shared("http://[::]:50051".to_string()) + .unwrap() + .connect_with_connector(tower::service_fn(move |_: Uri| { + tokio::net::UnixStream::connect(path.clone()) + })) + .await?; + + Ok(Self { + stub: TextGenerationServiceClient::new(channel), + }) + } + + /// Returns a list of uris or unix sockets of all shards + #[instrument(skip(self))] + pub async fn service_discovery(&mut self) -> Result> { + let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context(); + let response = self.stub.service_discovery(request).await.map_err(|_| { + ClientError::Connection("Server does not support v3 interface".to_string()) + })?; + let urls = response + .into_inner() + .urls + .into_iter() + // Remove unix socket prefix + .map(|url| match url.strip_prefix("unix://") { + None => url, + Some(stripped_url) => stripped_url.to_string(), + }) + .collect(); + Ok(urls) + } + + /// Get model info + #[instrument(skip(self))] + pub async fn info(&mut self) -> Result { + let request = tonic::Request::new(InfoRequest {}).inject_context(); + let response = self.stub.info(request).await?.into_inner(); + Ok(response) + } + + /// Get model health + #[instrument(skip(self))] + pub async fn health(&mut self) -> Result { + let request = tonic::Request::new(HealthRequest {}).inject_context(); + let response = self.stub.health(request).await?.into_inner(); + Ok(response) + } + + /// Clear the past generations cache + #[instrument(skip(self))] + pub async fn clear_cache(&mut self, batch_id: Option) -> Result<()> { + let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context(); + self.stub.clear_cache(request).await?; + Ok(()) + } + + /// Filter a cached batch + #[instrument(skip(self))] + pub async fn filter_batch( + &mut self, + batch_id: u64, + request_ids: Vec, + ) -> Result> { + let request = tonic::Request::new(FilterBatchRequest { + batch_id, + request_ids, + }) + .inject_context(); + let 
filtered_batch = self.stub.filter_batch(request).await?.into_inner(); + Ok(filtered_batch.batch) + } + + /// Warmup on a max size batch + /// + /// Returns the maximum amount of tokens supported by the hardware + #[instrument(skip_all)] + pub async fn warmup( + &mut self, + max_input_length: u32, + max_prefill_tokens: u32, + max_total_tokens: u32, + max_batch_size: Option, + ) -> Result> { + let mut n_tokens = 0; + let mut requests = Vec::new(); + // Create requests + while n_tokens < max_prefill_tokens { + let truncate = min(max_input_length, max_prefill_tokens - n_tokens); + + let mut input_chunks = Vec::new(); + input_chunks + .push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into()); + if n_tokens == 0 { + input_chunks.push( + Chunk::Image(Image { + // Safe unwrap, because we control the data. + data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(), + mimetype: "image/jpeg;base64".to_string(), + }) + .into(), + ); + } + + // Send stringly-typed inputs for compatibility for backends that haven't + // been updated to support chunks. + + let mut inputs = String::new(); + inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize)); + if n_tokens == 0 { + // 1 request is enough to test vision heads. + // Sending images on other queries messes up easily with truncation. 
+ inputs.push_str(&format!( + "![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})", + )); + } + + requests.push(Request { + id: 0, + inputs, + input_chunks: Some(Input { + chunks: input_chunks, + }), + // We truncate the input on the server side to be sure that it has the correct size + truncate, + // Blocks and slots will be set on the server side if we use paged attention + blocks: vec![], + slots: vec![], + // Set sampling parameters to also take these ops into account in the max memory + parameters: Some(NextTokenChooserParameters { + temperature: 0.9, + top_k: 10, + top_p: 0.9, + typical_p: 0.9, + do_sample: false, + seed: 0, + repetition_penalty: 1.2, + frequency_penalty: 0.1, + watermark: true, + grammar: String::new(), + grammar_type: GrammarType::None as i32, + }), + stopping_parameters: Some(StoppingCriteriaParameters { + max_new_tokens: max_total_tokens - truncate, + stop_sequences: vec![], + ignore_eos_token: true, + }), + prefill_logprobs: true, + top_n_tokens: 20, + adapter_id: None, + }); + n_tokens += max_input_length; + + // Check max_batch_size + if Some(requests.len()) == max_batch_size { + break; + } + } + + let batch = Batch { + id: 0, + size: requests.len() as u32, + requests, + max_tokens: max_input_length, + max_blocks: 0, + }; + + let request = tonic::Request::new(WarmupRequest { + batch: Some(batch), + max_input_length, + max_prefill_tokens, + max_total_tokens, + }) + .inject_context(); + let response = self.stub.warmup(request).await?.into_inner(); + Ok(response.max_supported_total_tokens) + } + + /// Generate one token for each request in the given batch + /// + /// Returns Generation for each request in batch + /// and the next cached batch + #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))] + pub async fn prefill( + &mut self, + batch: Batch, + ) -> Result<(Vec, Option, PrefillTimings)> { + let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context(); + let response = 
self.stub.prefill(request).await?.into_inner(); + Ok(( + response.generations, + response.batch, + PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns), + )) + } + + /// Generate one token for each request in the given cached batches + /// + /// Returns Generation for each request in batches + /// and the next cached batch + #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::()))] + pub async fn decode( + &mut self, + batches: Vec, + ) -> Result<(Vec, Option, DecodeTimings)> { + let request = tonic::Request::new(DecodeRequest { batches }).inject_context(); + let response = self.stub.decode(request).await?.into_inner(); + Ok(( + response.generations, + response.batch, + DecodeTimings::new( + response.concat_ns, + response.forward_ns, + response.decode_ns, + response.total_ns, + ), + )) + } +} + +pub struct PrefillTimings { + pub forward: Duration, + pub decode: Duration, + pub total: Duration, +} + +impl PrefillTimings { + fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self { + Self { + forward: Duration::from_nanos(forward_ns), + decode: Duration::from_nanos(decode_ns), + total: Duration::from_nanos(total_ns), + } + } +} + +pub struct DecodeTimings { + pub concat: Option, + pub forward: Duration, + pub decode: Duration, + pub total: Duration, +} + +impl DecodeTimings { + fn new(concat_ns: Option, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self { + Self { + concat: concat_ns.map(Duration::from_nanos), + forward: Duration::from_nanos(forward_ns), + decode: Duration::from_nanos(decode_ns), + total: Duration::from_nanos(total_ns), + } + } +} diff --git a/backends/v3/src/client/mod.rs b/backends/v3/src/client/mod.rs new file mode 100644 index 00000000..755431f4 --- /dev/null +++ b/backends/v3/src/client/mod.rs @@ -0,0 +1,76 @@ +//! 
Text Generation gRPC client library + +use async_trait::async_trait; +use thiserror::Error; +use tonic::transport; +use tonic::Status; + +#[allow(clippy::derive_partial_eq_without_eq)] +mod pb; + +mod grpc_client; +mod sharded_client; + +pub use grpc_client::Client; +pub use pb::generate::v3::{ + input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, + HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request, + StoppingCriteriaParameters, +}; +pub use sharded_client::ShardedClient; + +#[async_trait] +pub trait Health { + /// Check if a generate server is healthy by asking it to allocate a tensor on device + async fn device_health(&self) -> Result<()>; + + /// Check if a generate server is healthy by doing a forward pass. + /// EXPENSIVE + async fn model_health(&self) -> Result<()>; +} + +#[derive(Debug)] +pub struct ShardInfo { + pub requires_padding: bool, + pub dtype: String, + pub device_type: String, + pub window_size: Option, + pub speculate: u32, +} + +#[derive(Error, Debug, Clone)] +pub enum ClientError { + #[error("Could not connect to Text Generation server: {0}")] + Connection(String), + #[error("Server error: {0}")] + Generation(String), + #[error("Sharded results are empty")] + EmptyResults, +} + +impl From for ClientError { + fn from(err: Status) -> Self { + let err = Self::Generation(err.message().to_string()); + tracing::error!("{err}"); + err + } +} + +impl From for ClientError { + fn from(err: transport::Error) -> Self { + let err = Self::Connection(err.to_string()); + tracing::error!("{err}"); + err + } +} + +// Small convenience re-wrapping of `Chunk`. 
+impl From for InputChunk { + fn from(chunk: Chunk) -> Self { + InputChunk { chunk: Some(chunk) } + } +} + +static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII="; + +pub type Result = std::result::Result; diff --git a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs new file mode 100644 index 00000000..afb13cdc --- /dev/null +++ b/backends/v3/src/client/sharded_client.rs @@ -0,0 +1,260 @@ +use crate::client::{ClientError, Result}; +/// Multi shard Client +use crate::client::{Health, ShardInfo}; + +use crate::client::grpc_client::{DecodeTimings, PrefillTimings}; +use crate::client::{ + Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse, + NextTokenChooserParameters, Request, StoppingCriteriaParameters, +}; +use crate::client::{Chunk, InfoResponse, Input}; +use async_trait::async_trait; +use futures::future::join_all; +use tonic::transport::Uri; +use tracing::instrument; + +#[derive(Debug, Clone)] +/// Text Generation Inference gRPC multi client +pub struct ShardedClient { + clients: Vec, +} + +impl ShardedClient { + fn new(clients: Vec) -> Self { + Self { clients } + } + + /// Create a new ShardedClient from a master client. 
The master client will communicate with + /// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method. + async fn from_master_client(mut master_client: Client) -> Result { + // Get all uris/unix sockets from the master client + let uris = master_client.service_discovery().await?; + let futures = uris.into_iter().map(Client::connect_uds); + let clients: Result> = join_all(futures).await.into_iter().collect(); + Ok(Self::new(clients?)) + } + + /// Returns a client connected to the given uri + #[allow(dead_code)] + pub async fn connect(uri: Uri) -> Result { + let master_client = Client::connect(uri).await?; + Self::from_master_client(master_client).await + } + + /// Returns a client connected to the given unix socket + pub async fn connect_uds(path: String) -> Result { + let master_client = Client::connect_uds(path).await?; + Self::from_master_client(master_client).await + } + + /// Get the model info + #[instrument(skip(self))] + pub async fn info(&mut self) -> Result { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| client.info()) + .collect(); + join_all(futures).await.pop().unwrap().map(ShardInfo::from) + } + + /// GRPC health check + #[instrument(skip(self))] + pub async fn health(&mut self) -> Result { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| client.health()) + .collect(); + join_all(futures).await.pop().unwrap() + } + + /// Clear the past generations cache + #[instrument(skip(self))] + pub async fn clear_cache(&mut self, batch_id: Option) -> Result<()> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| client.clear_cache(batch_id)) + .collect(); + join_all(futures).await.into_iter().collect() + } + + /// Filter a cached batch + #[instrument(skip(self))] + pub async fn filter_batch( + &mut self, + batch_id: u64, + request_ids: Vec, + ) -> Result> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.filter_batch(batch_id, 
request_ids.clone()))) + .collect(); + // all shards return the same message + join_all(futures).await.pop().unwrap() + } + + /// Warmup on a max size batch + /// + /// Returns the maximum amount of tokens supported by the hardware + #[instrument(skip(self))] + pub async fn warmup( + &mut self, + max_input_length: u32, + max_prefill_tokens: u32, + max_total_tokens: u32, + max_batch_size: Option, + ) -> Result> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| { + Box::pin(client.warmup( + max_input_length, + max_prefill_tokens, + max_total_tokens, + max_batch_size, + )) + }) + .collect(); + // Take the minimum value + let results = join_all(futures) + .await + .into_iter() + .collect::>>>()?; + Ok(results.into_iter().flatten().min()) + } + + /// Generate one token for each request in the given batch + /// + /// Returns Generation for each request in batch + /// and the next cached batch + #[instrument(skip_all, fields(id = & batch.id, size = & batch.size))] + pub async fn prefill( + &mut self, + batch: Batch, + ) -> Result<(Vec, Option, PrefillTimings)> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.prefill(batch.clone()))) + .collect(); + #[allow(clippy::type_complexity)] + let results: Result, Option, PrefillTimings)>> = + join_all(futures).await.into_iter().collect(); + let mut results = results?; + + let (mut generations, next_batch, mut timings) = + results.pop().ok_or(ClientError::EmptyResults)?; + + // Merge generations from different model shards + for (mut shard_generations, _, shard_timings) in results.into_iter() { + generations.append(&mut shard_generations); + // Return the timings of the slowest shard + if shard_timings.total > timings.total { + timings = shard_timings; + } + } + Ok((generations, next_batch, timings)) + } + + /// Generate one token for each request in the given cached batches + /// + /// Returns Generation for each request in batches + /// and the next cached batch + 
#[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))] + pub async fn decode( + &mut self, + batches: Vec, + ) -> Result<(Vec, Option, DecodeTimings)> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.decode(batches.clone()))) + .collect(); + #[allow(clippy::type_complexity)] + let results: Result, Option, DecodeTimings)>> = + join_all(futures).await.into_iter().collect(); + let mut results = results?; + + let (mut generations, next_batch, mut timings) = + results.pop().ok_or(ClientError::EmptyResults)?; + + // Merge generations from different model shards + for (mut shard_generations, _, shard_timings) in results.into_iter() { + generations.append(&mut shard_generations); + // Return the timings of the slowest shard + if shard_timings.total > timings.total { + timings = shard_timings; + } + } + Ok((generations, next_batch, timings)) + } +} + +impl From for ShardInfo { + fn from(value: InfoResponse) -> Self { + Self { + requires_padding: value.requires_padding, + dtype: value.dtype, + device_type: value.device_type, + window_size: value.window_size, + speculate: value.speculate, + } + } +} + +#[async_trait] +impl Health for ShardedClient { + async fn device_health(&self) -> Result<()> { + self.clone().health().await?; + Ok(()) + } + + async fn model_health(&self) -> Result<()> { + // Dummy batch of 1 token and 1 generated token + let liveness_request = Request { + id: u64::MAX, + inputs: "liveness".to_string(), + input_chunks: Some(Input { + chunks: vec![Chunk::Text("liveness".into()).into()], + }), + truncate: 10, + prefill_logprobs: false, + parameters: Some(NextTokenChooserParameters { + temperature: 1.0, + top_k: 0, + top_p: 1.0, + typical_p: 1.0, + do_sample: false, + seed: 0, + repetition_penalty: 1.0, + frequency_penalty: 0.0, + watermark: false, + grammar: String::new(), + grammar_type: GrammarType::None as i32, + }), + stopping_parameters: Some(StoppingCriteriaParameters { + 
max_new_tokens: 1, + stop_sequences: vec![], + ignore_eos_token: false, + }), + top_n_tokens: 0, + // Block 0 is reserved for health checks + blocks: vec![0], + slots: (0..16).collect(), + adapter_id: None, + }; + let batch = Batch { + id: u64::MAX, + requests: vec![liveness_request], + size: 1, + max_tokens: 2, + max_blocks: 1, + }; + self.clone().prefill(batch).await?; + Ok(()) + } +} diff --git a/backends/v3/src/lib.rs b/backends/v3/src/lib.rs new file mode 100644 index 00000000..a6f89169 --- /dev/null +++ b/backends/v3/src/lib.rs @@ -0,0 +1,142 @@ +mod backend; +mod block_allocator; +mod client; +mod queue; + +use crate::client::{ClientError, ShardedClient}; +pub(crate) use backend::BackendV3; +use serde::Serialize; +use thiserror::Error; +use utoipa::ToSchema; + +#[derive(Clone, Debug, Serialize, ToSchema)] +pub struct BackendInfo { + /// Mandatory + #[schema(example = "cuda")] + pub model_device_type: String, + #[schema(example = "torch.float16")] + pub model_dtype: String, + + /// Backend parameters + #[schema(example = "1")] + pub speculate: usize, + #[schema(example = "1.2")] + pub waiting_served_ratio: f32, + #[schema(example = "32000")] + pub max_batch_total_tokens: u32, + #[schema(example = "20")] + pub max_waiting_tokens: usize, + #[schema(nullable = true, example = "null")] + pub max_batch_size: Option, +} + +#[allow(clippy::too_many_arguments)] +pub async fn connect_backend( + max_input_tokens: usize, + max_total_tokens: usize, + master_shard_uds_path: String, + waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, + max_batch_total_tokens: Option, + max_waiting_tokens: usize, + max_batch_size: Option, +) -> Result<(BackendV3, BackendInfo), V3Error> { + // Helper function + let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option| { + match max_supported_batch_total_tokens { + // Older models do not support automatic max-batch-total-tokens + None => { + let max_batch_total_tokens = max_batch_total_tokens + 
.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))); + tracing::warn!("Model does not support automatic max batch total tokens"); + Ok(max_batch_total_tokens) + } + // Flash attention models return their max supported total tokens + Some(max_supported_batch_total_tokens) => { + // Warn if user added his own max-batch-total-tokens as we will ignore it + if max_batch_total_tokens.is_some() { + tracing::warn!( + "`--max-batch-total-tokens` is deprecated for Flash \ + Attention models." + ); + tracing::warn!( + "Inferred max batch total tokens: {max_supported_batch_total_tokens}" + ); + } + if max_total_tokens as u32 > max_supported_batch_total_tokens { + return Err(V3Error::NotEnoughMemory(max_total_tokens)); + } + + Ok(max_supported_batch_total_tokens) + } + } + }; + + let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path) + .await + .map_err(V3Error::Connection)?; + + // server is running on v3 + // Clear the cache; useful if the webserver rebooted + sharded_client + .clear_cache(None) + .await + .map_err(V3Error::Cache)?; + // Get info from the shard + let shard_info = sharded_client.info().await.map_err(V3Error::Info)?; + + // Warmup model + tracing::info!("Warming up model"); + let max_batch_total_tokens = check_max_batch_total_tokens( + sharded_client + .warmup( + max_input_tokens as u32, + max_batch_prefill_tokens, + max_total_tokens as u32, + max_batch_size, + ) + .await + .map_err(V3Error::Warmup)?, + )?; + tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}"); + + let backend_info = BackendInfo { + waiting_served_ratio, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + model_device_type: shard_info.device_type.clone(), + model_dtype: shard_info.dtype.clone(), + speculate: shard_info.speculate as usize, + }; + + let backend = BackendV3::new( + sharded_client, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + 
max_batch_size, + shard_info.requires_padding, + shard_info.window_size, + shard_info.speculate, + ); + + tracing::info!("Using backend V3"); + + Ok((backend, backend_info)) +} + +#[derive(Debug, Error)] +pub enum V3Error { + #[error("Unable to clear the Python model shards cache: {0}")] + Cache(ClientError), + #[error("Unable to connect to the Python model shards: {0}")] + Connection(ClientError), + #[error("Unable to get the Python model shards info: {0}")] + Info(ClientError), + #[error("Unable to warmup the Python model shards: {0}")] + Warmup(ClientError), + #[error("Not enough memory to handle `max_total_tokens={0}`")] + NotEnoughMemory(usize), +} diff --git a/backends/v3/src/main.rs b/backends/v3/src/main.rs new file mode 100644 index 00000000..ef10514f --- /dev/null +++ b/backends/v3/src/main.rs @@ -0,0 +1,208 @@ +use clap::{Parser, Subcommand}; +use text_generation_router::server; +use text_generation_router_v3::{connect_backend, V3Error}; +use thiserror::Error; + +/// App Configuration +#[derive(Parser, Debug)] +#[clap(author, version, about, long_about = None)] +struct Args { + #[command(subcommand)] + command: Option, + + #[clap(default_value = "128", long, env)] + max_concurrent_requests: usize, + #[clap(default_value = "2", long, env)] + max_best_of: usize, + #[clap(default_value = "4", long, env)] + max_stop_sequences: usize, + #[clap(default_value = "5", long, env)] + max_top_n_tokens: u32, + #[clap(default_value = "1024", long, env)] + max_input_tokens: usize, + #[clap(default_value = "2048", long, env)] + max_total_tokens: usize, + #[clap(default_value = "1.2", long, env)] + waiting_served_ratio: f32, + #[clap(default_value = "4096", long, env)] + max_batch_prefill_tokens: u32, + #[clap(long, env)] + max_batch_total_tokens: Option, + #[clap(default_value = "20", long, env)] + max_waiting_tokens: usize, + #[clap(long, env)] + max_batch_size: Option, + #[clap(default_value = "0.0.0.0", long, env)] + hostname: String, + #[clap(default_value = "3000", 
long, short, env)] + port: u16, + #[clap(default_value = "/tmp/text-generation-server-0", long, env)] + master_shard_uds_path: String, + #[clap(default_value = "bigscience/bloom", long, env)] + tokenizer_name: String, + #[clap(long, env)] + tokenizer_config_path: Option, + #[clap(long, env)] + revision: Option, + #[clap(default_value = "2", long, env)] + validation_workers: usize, + #[clap(long, env)] + api_key: Option, + #[clap(long, env)] + json_output: bool, + #[clap(long, env)] + otlp_endpoint: Option, + #[clap(default_value = "text-generation-inference.router", long, env)] + otlp_service_name: String, + #[clap(long, env)] + cors_allow_origin: Option>, + #[clap(long, env)] + ngrok: bool, + #[clap(long, env)] + ngrok_authtoken: Option, + #[clap(long, env)] + ngrok_edge: Option, + #[clap(long, env, default_value_t = false)] + messages_api_enabled: bool, + #[clap(long, env, default_value_t = false)] + disable_grammar_support: bool, + #[clap(default_value = "4", long, env)] + max_client_batch_size: usize, + #[clap(long, env, default_value_t)] + disable_usage_stats: bool, + #[clap(long, env, default_value_t)] + disable_crash_reports: bool, +} + +#[derive(Debug, Subcommand)] +enum Commands { + PrintSchema, +} + +#[tokio::main] +async fn main() -> Result<(), RouterError> { + // Get args + let args = Args::parse(); + // Pattern match configuration + let Args { + command, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + hostname, + port, + master_shard_uds_path, + tokenizer_name, + tokenizer_config_path, + revision, + validation_workers, + api_key, + json_output, + otlp_endpoint, + otlp_service_name, + cors_allow_origin, + ngrok, + ngrok_authtoken, + ngrok_edge, + messages_api_enabled, + disable_grammar_support, + disable_usage_stats, + disable_crash_reports, + 
max_client_batch_size, + } = args; + + if let Some(Commands::PrintSchema) = command { + use utoipa::OpenApi; + let api_doc = text_generation_router::server::ApiDoc::openapi(); + let api_doc = serde_json::to_string_pretty(&api_doc).unwrap(); + println!("{}", api_doc); + std::process::exit(0); + }; + text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output); + + // Validate args + if max_input_tokens >= max_total_tokens { + return Err(RouterError::ArgumentValidation( + "`max_input_tokens` must be < `max_total_tokens`".to_string(), + )); + } + if max_input_tokens as u32 > max_batch_prefill_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}"))); + } + + if validation_workers == 0 { + return Err(RouterError::ArgumentValidation( + "`validation_workers` must be > 0".to_string(), + )); + } + + if let Some(ref max_batch_total_tokens) = max_batch_total_tokens { + if max_batch_prefill_tokens > *max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); + } + if max_total_tokens as u32 > *max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. 
Given: {max_total_tokens} and {max_batch_total_tokens}"))); + } + } + + let (backend, _backend_info) = connect_backend( + max_input_tokens, + max_total_tokens, + master_shard_uds_path, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + ) + .await?; + + // Run server + server::run( + backend, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + validation_workers, + api_key, + tokenizer_name, + tokenizer_config_path, + revision, + hostname, + port, + cors_allow_origin, + ngrok, + ngrok_authtoken, + ngrok_edge, + messages_api_enabled, + disable_grammar_support, + max_client_batch_size, + disable_usage_stats, + disable_crash_reports, + ) + .await?; + Ok(()) +} + +#[derive(Debug, Error)] +enum RouterError { + #[error("Argument validation error: {0}")] + ArgumentValidation(String), + #[error("Backend failed: {0}")] + Backend(#[from] V3Error), + #[error("WebServer error: {0}")] + WebServer(#[from] server::WebServerError), + #[error("Tokio runtime failed to start: {0}")] + Tokio(#[from] std::io::Error), +} diff --git a/router/src/infer/v3/queue.rs b/backends/v3/src/queue.rs similarity index 95% rename from router/src/infer/v3/queue.rs rename to backends/v3/src/queue.rs index 894d9cab..9427bd60 100644 --- a/router/src/infer/v3/queue.rs +++ b/backends/v3/src/queue.rs @@ -1,17 +1,17 @@ -use crate::infer::v3::block_allocator::{BlockAllocation, BlockAllocator}; -use crate::infer::InferError; -use crate::infer::InferStreamResponse; -use crate::validation::{ - ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters, +use crate::block_allocator::{BlockAllocation, BlockAllocator}; +use crate::client; +use crate::client::{ + Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters, }; use nohash_hasher::{BuildNoHashHasher, IntMap}; use std::cmp::{max, min}; use std::collections::VecDeque; -use 
text_generation_client::v3::{ - Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters, +use text_generation_router::infer::InferError; +use text_generation_router::infer::InferStreamResponse; +use text_generation_router::validation::{ + Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters, + ValidStoppingParameters, }; -use text_generation_client::ChunksToString; -use text_generation_client::Input; use tokio::sync::{mpsc, oneshot}; use tokio::time::Instant; use tracing::{info_span, instrument, Instrument, Span}; @@ -337,8 +337,22 @@ impl State { batch_requests.push(Request { id, prefill_logprobs: entry.request.decoder_input_details, - input_chunks: Some(Input { - chunks: entry.request.inputs.clone(), + input_chunks: Some(client::Input { + chunks: entry + .request + .inputs + .clone() + .into_iter() + .map(|c| client::InputChunk { + chunk: Some(match c { + Chunk::Text(text) => client::Chunk::Text(text), + Chunk::Image(image) => client::Chunk::Image(client::Image { + data: image.data, + mimetype: image.mimetype, + }), + }), + }) + .collect(), }), inputs: entry.request.inputs.chunks_to_string(), truncate: entry.request.truncate, diff --git a/benchmark/Cargo.toml b/benchmark/Cargo.toml index 756460e0..f82659c9 100644 --- a/benchmark/Cargo.toml +++ b/benchmark/Cargo.toml @@ -21,7 +21,7 @@ float-ord = "0.3.2" serde = {version = "1.0.188", features = ["derive"]} serde_json = "1.0" tabled = "0.14.0" -text-generation-client = { path = "../router/client" } +text-generation-client = { path = "../backends/client" } thiserror = "1.0.48" tokenizers = { workspace = true } tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] } diff --git a/docs/openapi.json b/docs/openapi.json index db163ca0..ed9b0b96 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -1580,16 +1580,11 @@ "type": "object", "required": [ "model_id", - "model_dtype", - "model_device_type", 
"max_concurrent_requests", "max_best_of", "max_stop_sequences", "max_input_tokens", "max_total_tokens", - "waiting_served_ratio", - "max_batch_total_tokens", - "max_waiting_tokens", "validation_workers", "max_client_batch_size", "router", @@ -1601,18 +1596,6 @@ "example": "null", "nullable": true }, - "max_batch_size": { - "type": "integer", - "example": "null", - "nullable": true, - "minimum": 0 - }, - "max_batch_total_tokens": { - "type": "integer", - "format": "int32", - "example": "32000", - "minimum": 0 - }, "max_best_of": { "type": "integer", "example": "2", @@ -1644,19 +1627,6 @@ "example": "2048", "minimum": 0 }, - "max_waiting_tokens": { - "type": "integer", - "example": "20", - "minimum": 0 - }, - "model_device_type": { - "type": "string", - "example": "cuda" - }, - "model_dtype": { - "type": "string", - "example": "torch.float16" - }, "model_id": { "type": "string", "description": "Model info", @@ -1690,11 +1660,6 @@ "version": { "type": "string", "example": "0.5.0" - }, - "waiting_served_ratio": { - "type": "number", - "format": "float", - "example": "1.2" } } }, diff --git a/router/Cargo.toml b/router/Cargo.toml index 0fc700a0..1be74546 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -7,25 +7,18 @@ edition.workspace = true authors.workspace = true homepage.workspace = true -[lib] -path = "src/lib.rs" - -[[bin]] -name = "text-generation-router" -path = "src/main.rs" - [dependencies] +async-trait = "0.1.74" async-stream = "0.3.5" axum = { version = "0.7", features = ["json"] } axum-tracing-opentelemetry = "0.16" -text-generation-client = { path = "client" } clap = { version = "4.4.5", features = ["derive", "env"] } futures = "0.3.28" hf-hub = { workspace = true } itertools = "0.10" jsonschema = { version = "0.17.1", features = ["draft202012"] } -metrics = "0.23.0" -metrics-exporter-prometheus = { version = "0.15.1", features = [] } +metrics = { workspace = true } +metrics-exporter-prometheus = { workspace = true } nohash-hasher = "0.2.0" 
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] } opentelemetry-otlp = "0.13.0" @@ -55,6 +48,7 @@ base64 = { workspace = true } sysinfo = "0.30.13" uuid = { version = "1.9.1", default-features = false, features = ["v4", "fast-rng", "macro-diagnostics"] } csv = "1.3.0" +ureq = "=2.9" [build-dependencies] diff --git a/router/client/src/v2/pb/.gitignore b/router/client/src/v2/pb/.gitignore deleted file mode 100644 index 72e8ffc0..00000000 --- a/router/client/src/v2/pb/.gitignore +++ /dev/null @@ -1 +0,0 @@ -* diff --git a/router/client/src/v3/pb/.gitignore b/router/client/src/v3/pb/.gitignore deleted file mode 100644 index 72e8ffc0..00000000 --- a/router/client/src/v3/pb/.gitignore +++ /dev/null @@ -1 +0,0 @@ -* diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/chat_template.rs similarity index 67% rename from router/src/infer/v3/scheduler.rs rename to router/src/infer/chat_template.rs index 26cd9584..24a00352 100644 --- a/router/src/infer/v3/scheduler.rs +++ b/router/src/infer/chat_template.rs @@ -1,528 +1,85 @@ -/// Batching and inference logic -use crate::infer::v3::queue::{Entry, Queue}; -use crate::infer::{ - GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler, +use crate::infer::InferError; +use crate::{ + ChatTemplateInputs, GrammarType, Message, MessageChunk, TextMessage, TokenizerConfigToken, }; -use crate::validation::ValidGenerateRequest; -use crate::{FinishReason, PrefillToken, Token}; -use nohash_hasher::IntMap; -use std::sync::{ - atomic::{AtomicBool, Ordering}, - Arc, -}; -use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient}; -use text_generation_client::ClientError; -use tokio::sync::mpsc::error::SendError; -use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit}; -use tokio::time::Instant; -use tokio_stream::wrappers::UnboundedReceiverStream; -use tracing::{info_span, instrument, Instrument, Span}; +use minijinja::{Environment, ErrorKind, Template}; +use 
minijinja_contrib::pycompat; -pub(crate) struct SchedulerV3 { - /// Request queue - queue: Queue, - /// Notify batcher on queue appends - batching_task_notifier: Arc, +/// Raise a exception (custom function) used in the chat templates +pub(crate) fn raise_exception(err_text: String) -> Result { + Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text)) } -impl SchedulerV3 { - #[allow(clippy::too_many_arguments)] +#[derive(Clone)] +pub(crate) struct ChatTemplate { + template: Template<'static, 'static>, + bos_token: Option, + eos_token: Option, + use_default_tool_template: bool, +} + +impl ChatTemplate { pub(crate) fn new( - client: ShardedClient, - waiting_served_ratio: f32, - max_batch_prefill_tokens: u32, - max_batch_total_tokens: u32, - max_waiting_tokens: usize, - max_batch_size: Option, - requires_padding: bool, - window_size: Option, - speculate: u32, - generation_health: Arc, + template: String, + bos_token: Option, + eos_token: Option, ) -> Self { - let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") { - matches!(flashdecoding.to_lowercase().as_str(), "1" | "true") - } else { - false - }; - let block_size = if flashdecoding { 256 } else { 16 }; - let queue = Queue::new( - requires_padding, - block_size, - window_size, - speculate, - max_batch_total_tokens, - ); - let batching_task_notifier = Arc::new(Notify::new()); + let mut env = Box::new(Environment::new()); + // enable things like .strip() or .capitalize() + env.set_unknown_method_callback(pycompat::unknown_method_callback); + let template_str = template.into_boxed_str(); + env.add_function("raise_exception", raise_exception); - // Spawn batching background task that contains all the inference logic - tokio::spawn(batching_task( - client, - waiting_served_ratio, - max_batch_prefill_tokens, - max_batch_total_tokens, - max_waiting_tokens, - max_batch_size, - queue.clone(), - batching_task_notifier.clone(), - generation_health, - )); + // check if contains the tools variable 
within the template + let use_default_tool_template = + !template_str.as_ref().replace(' ', "").contains("{{tools}}"); + // leaking env and template_str as read-only, static resources for performance. + let template = Box::leak(env) + .template_from_str(Box::leak(template_str)) + .unwrap(); Self { - queue, - batching_task_notifier, + template, + bos_token: bos_token.map(|token| token.as_str().to_string()), + eos_token: eos_token.map(|token| token.as_str().to_string()), + use_default_tool_template, } } -} -impl Scheduler for SchedulerV3 { - #[instrument(skip_all)] - fn schedule( + pub(crate) fn apply( &self, - request: ValidGenerateRequest, - permit: OwnedSemaphorePermit, - ) -> Result { - // MPSC channel to communicate with the background batching task - let (response_tx, response_rx) = mpsc::unbounded_channel(); - let input_length = request.input_length; - - // Append the request to the queue - self.queue.append(Entry { - request, - response_tx, - span: Span::current(), - temp_span: None, - queue_time: Instant::now(), - batch_time: None, - block_allocation: None, - }); - - // Notify the background task that we have a new entry in the queue that needs - // to be batched - self.batching_task_notifier.notify_one(); - - // Return stream - Ok(( - permit, - input_length, - UnboundedReceiverStream::new(response_rx), - )) - } -} - -/// Batching logic -/// Will be launched in a background Tokio task -/// -/// Batches requests and sends them to the inference server -#[allow(clippy::too_many_arguments)] -pub(crate) async fn batching_task( - mut client: ShardedClient, - waiting_served_ratio: f32, - max_batch_prefill_tokens: u32, - max_batch_total_tokens: u32, - max_waiting_tokens: usize, - max_batch_size: Option, - queue: Queue, - notifier: Arc, - generation_health: Arc, -) { - // Infinite loop - loop { - // Wait for a notification from the Infer struct - notifier.notified().await; - - // Get the next batch from the queue - // This batch might be smaller than the maximum 
batch size if there are not enough requests - // waiting in the queue - while let Some((mut entries, batch, span)) = queue - .next_batch( - None, - max_batch_size, - max_batch_prefill_tokens, - max_batch_total_tokens, - ) - .await - { - let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health) - .instrument(span) - .await; - let mut waiting_tokens = 1; - - // We loop until we do not receive any cached batch from the inference server (== until - // all requests have met their stopping criteria) - while let Some(batch) = cached_batch { - // Get current batch info - let batch_size = batch.size; - let batch_max_tokens = batch.max_tokens; - let mut batches = vec![batch]; - metrics::gauge!("tgi_batch_current_size").set(batch_size as f64); - metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64); - - let min_size = if waiting_tokens >= max_waiting_tokens { - // If we didn't onboard any new requests since >= max_waiting_tokens, we try - // to add a new batch even though its size might be small - None - } else { - // Minimum batch size - Some((batch_size as f32 * waiting_served_ratio).floor() as usize) - }; - - let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); - let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize); - - // Try to get a new batch - if let Some((mut new_entries, new_batch, span)) = queue - .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget) - .await - { - // Tracking metrics - if min_size.is_some() { - metrics::counter!("tgi_batch_concat", "reason" => "backpressure") - .increment(1); - } else { - metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded") - .increment(1); - } - - entries.iter_mut().for_each(|(_, entry)| { - // Create a new span to add the info that this entry is waiting - // because a new batch is being computed - let entry_waiting_span = info_span!(parent: &entry.span, "waiting"); - // Add relationships - 
span.follows_from(&entry_waiting_span); - entry_waiting_span.follows_from(&span); - // Update entry - entry.temp_span = Some(entry_waiting_span); + mut messages: Vec, + grammar_with_prompt: Option<(GrammarType, String)>, + ) -> Result { + if self.use_default_tool_template { + if let Some(last_message) = messages.last_mut() { + if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt { + last_message.content.push(MessageChunk::Text { + text: format!("\n---\n{}\n{}", tool_prompt, tools), }); - - // Generate one token for this new batch to have the attention past in cache - let new_cached_batch = - prefill(&mut client, new_batch, &mut new_entries, &generation_health) - .instrument(span) - .await; - // Reset waiting counter - waiting_tokens = 1; - // Extend current batch with the new batch - if let Some(new_cached_batch) = new_cached_batch { - entries.extend(new_entries); - batches.push(new_cached_batch); - } } - - // Create span for this batch to add context to inference calls - let next_batch_size = entries.len(); - let next_batch_span = - info_span!(parent: None, "batch", batch_size = next_batch_size); - entries.iter_mut().for_each(|(_, entry)| { - // Create a new span to link the batch back to this entry - let entry_batch_span = info_span!(parent: &entry.span, "infer"); - // Add relationships - next_batch_span.follows_from(&entry_batch_span); - entry_batch_span.follows_from(&next_batch_span); - // Update entry - entry.temp_span = Some(entry_batch_span); - }); - - cached_batch = decode(&mut client, batches, &mut entries, &generation_health) - .instrument(next_batch_span) - .await; - waiting_tokens += 1; - } - metrics::gauge!("tgi_batch_current_size").set(0.0); - metrics::gauge!("tgi_batch_current_max_tokens").set(0.0); - } - } -} - -#[instrument(skip_all)] -async fn prefill( - client: &mut ShardedClient, - batch: Batch, - entries: &mut IntMap, - generation_health: &Arc, -) -> Option { - let start_time = Instant::now(); - let batch_id = batch.id; - 
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1); - - match client.prefill(batch).await { - Ok((generations, next_batch, timings)) => { - // Update health - generation_health.store(true, Ordering::SeqCst); - - let start_filtering_time = Instant::now(); - // Send generated tokens and filter stopped entries - filter_send_generations(generations, entries); - - // Filter next batch and remove requests that were stopped - let next_batch = filter_batch(client, next_batch, entries).await; - - metrics::histogram!("tgi_batch_forward_duration","method" => "prefill") - .record(timings.forward.as_secs_f64()); - metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill") - .record(timings.decode.as_secs_f64()); - metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill") - .record(start_filtering_time.elapsed().as_secs_f64()); - metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill") - .record(start_time.elapsed().as_secs_f64()); - metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1); - next_batch - } - // If we have an error, we discard the whole batch - Err(err) => { - // Update health - generation_health.store(false, Ordering::SeqCst); - let _ = client.clear_cache(Some(batch_id)).await; - send_errors(err, entries); - metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1); - None - } - } -} - -#[instrument(skip_all)] -async fn decode( - client: &mut ShardedClient, - batches: Vec, - entries: &mut IntMap, - generation_health: &Arc, -) -> Option { - let start_time = Instant::now(); - let batch_ids: Vec = batches.iter().map(|b| b.id).collect(); - metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1); - - match client.decode(batches).await { - Ok((generations, next_batch, timings)) => { - // Update health - generation_health.store(true, Ordering::SeqCst); - - let start_filtering_time = Instant::now(); - // Send 
generated tokens and filter stopped entries - filter_send_generations(generations, entries); - - // Filter next batch and remove requests that were stopped - let next_batch = filter_batch(client, next_batch, entries).await; - - if let Some(concat_duration) = timings.concat { - metrics::histogram!("tgi_batch_concat_duration", "method" => "decode") - .record(concat_duration.as_secs_f64()); - } - metrics::histogram!("tgi_batch_forward_duration", "method" => "decode") - .record(timings.forward.as_secs_f64()); - metrics::histogram!("tgi_batch_decode_duration", "method" => "decode") - .record(timings.decode.as_secs_f64()); - metrics::histogram!("tgi_batch_filter_duration", "method" => "decode") - .record(start_filtering_time.elapsed().as_secs_f64()); - metrics::histogram!("tgi_batch_inference_duration", "method" => "decode") - .record(start_time.elapsed().as_secs_f64()); - metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1); - next_batch - } - // If we have an error, we discard the whole batch - Err(err) => { - generation_health.store(false, Ordering::SeqCst); - for id in batch_ids { - let _ = client.clear_cache(Some(id)).await; - } - send_errors(err, entries); - metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1); - None - } - } -} - -/// Filter a `batch` and remove all requests not present in `entries` -#[instrument(skip_all)] -async fn filter_batch( - client: &mut ShardedClient, - next_batch: Option, - entries: &IntMap, -) -> Option { - let mut batch = next_batch?; - - // No need to filter - if batch.size as usize == entries.len() { - return Some(batch); - } - - let id = batch.id; - - // Retain only requests that are still in entries - batch.request_ids.retain(|id| entries.contains_key(id)); - - if batch.request_ids.is_empty() { - // All requests have been filtered out - // Next batch is now empty - // Clear it from the Python shards cache - // We unwrap here as we need to panic since we cannot recover if 
this method fails - client.clear_cache(Some(id)).await.unwrap(); - None - } else { - // Filter Python shard cache - // We unwrap here as we need to panic since we cannot recover if this method fails - client.filter_batch(id, batch.request_ids).await.unwrap() - } -} - -/// Send one or multiple `InferStreamResponse` to Infer for all `entries` -/// and filter entries -#[instrument(skip_all)] -fn filter_send_generations(generations: Vec, entries: &mut IntMap) { - generations.into_iter().for_each(|generation| { - let id = generation.request_id; - // Get entry - // We can `expect` here as the request id should always be in the entries - let entry = entries - .get(&id) - .expect("ID not found in entries. This is a bug."); - - // Create and enter a span to link this function back to the entry - let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered(); - // Send generation responses back to the infer task - // If the receive an error from the Flume channel, it means that the client dropped the - // request and we need to stop generating hence why we unwrap_or(true) - let stopped = send_responses(generation, entry).map_err(|err| { - tracing::error!("Entry response channel error."); - metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); - err - }).unwrap_or(true); - if stopped { - entries.remove(&id).expect("ID not found in entries. 
This is a bug."); - } - }); -} - -/// Send responses through the `entry` response channel -fn send_responses( - generation: Generation, - entry: &Entry, -) -> Result>>> { - // Return directly if the channel is disconnected - if entry.response_tx.is_closed() { - metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); - return Ok(true); - } - - let mut stopped = false; - - if let Some(prefill_tokens) = generation.prefill_tokens { - // Create Token objects - // We do that here instead of in the Python code as Rust for loops are faster - let prefill_tokens = prefill_tokens - .ids - .into_iter() - .zip(prefill_tokens.logprobs) - .zip(prefill_tokens.texts) - .map(|((id, logprob), text)| PrefillToken { id, text, logprob }) - .collect(); - - // Send message - entry - .response_tx - .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?; - } - - // Create last Token - let tokens_ = generation.tokens.expect("Non empty tokens in generation"); - let n = tokens_.ids.len(); - metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64); - let mut iterator = tokens_ - .ids - .into_iter() - .zip(tokens_.logprobs) - .zip(tokens_.texts) - .zip(tokens_.is_special) - .enumerate() - .peekable(); - while let Some((i, (((id, logprob), text), special))) = iterator.next() { - let token = Token { - id, - text, - logprob, - special, - }; - let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) { - top_tokens_ - .ids - .iter() - .zip(top_tokens_.logprobs.iter()) - .zip(top_tokens_.texts.iter()) - .zip(top_tokens_.is_special.iter()) - .map(|(((&id, &logprob), text), &special)| Token { - id, - text: text.to_string(), - logprob, - special, - }) - .collect() - } else { - vec![] - }; - match (&generation.generated_text, iterator.peek()) { - (Some(generated_text), None) => { - // Generation has ended - stopped = true; - // Send message - entry.response_tx.send(Ok(InferStreamResponse::End { - token, - top_tokens, - generated_text: 
GeneratedText::from(generated_text.clone()), - queued: entry.queue_time, - start: entry.batch_time.unwrap(), - }))?; - } - _ => { - // Send message - entry - .response_tx - .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?; } } - } - Ok(stopped) -} + let messages: Vec = messages.into_iter().map(|c| c.into()).collect(); -/// Send errors to Infer for all `entries` -#[instrument(skip_all)] -fn send_errors(error: ClientError, entries: &mut IntMap) { - entries.drain().for_each(|(_, entry)| { - // Create and enter a span to link this function back to the entry - let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered(); - let err = InferError::GenerationError(error.to_string()); - metrics::counter!("tgi_request_failure", "err" => "generation").increment(1); - tracing::error!("{err}"); - - // unwrap_or is valid here as we don't care if the receiver is gone. - entry - .response_tx - .send(Err(err)) - .unwrap_or(()); - }); -} - -impl From for GeneratedText { - fn from(value: text_generation_client::v3::GeneratedText) -> Self { - let v3_finish_reason = - text_generation_client::v3::FinishReason::try_from(value.finish_reason).unwrap(); - let finish_reason = match v3_finish_reason { - text_generation_client::v3::FinishReason::Length => FinishReason::Length, - text_generation_client::v3::FinishReason::EosToken => FinishReason::EndOfSequenceToken, - text_generation_client::v3::FinishReason::StopSequence => FinishReason::StopSequence, - }; - - Self { - text: value.text, - generated_tokens: value.generated_tokens, - finish_reason, - seed: value.seed, - } + self.template + .render(ChatTemplateInputs { + messages, + bos_token: self.bos_token.as_deref(), + eos_token: self.eos_token.as_deref(), + add_generation_prompt: true, + tools: None, + tools_prompt: None, + }) + .map_err(InferError::TemplateError) } } // tests #[cfg(test)] mod tests { - use crate::infer::raise_exception; + use 
crate::infer::chat_template::raise_exception; use crate::{ChatTemplateInputs, TextMessage}; use minijinja::Environment; diff --git a/router/src/infer/health.rs b/router/src/infer/health.rs deleted file mode 100644 index 4320c1a4..00000000 --- a/router/src/infer/health.rs +++ /dev/null @@ -1,34 +0,0 @@ -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; -use text_generation_client::Health; - -#[derive(Clone)] -pub(crate) struct HealthCheck { - client: Arc, - generation_health: Arc, -} - -impl HealthCheck { - pub(crate) fn new( - client: Arc, - generation_health: Arc, - ) -> Self { - Self { - client, - generation_health, - } - } - - pub(crate) async fn check(&mut self) -> bool { - let value = if self.generation_health.load(Ordering::SeqCst) { - // Generation is healthy, we only check that the shards can allocate on device - self.client.device_health().await - } else { - self.client.model_health().await - } - .is_ok(); - // Update generation health - self.generation_health.store(value, Ordering::SeqCst); - value - } -} diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index db9070d4..534a2647 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -1,23 +1,18 @@ -mod health; -pub(crate) mod v2; -pub(crate) mod v3; - -pub(crate) use health::HealthCheck; +// pub(crate) mod v2; +mod chat_template; +pub mod tool_grammar; use crate::validation::{ValidGenerateRequest, Validation, ValidationError}; +use crate::GrammarType; use crate::{ - ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, - HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token, ToolChoice, -}; -use crate::{ - FunctionRef, FunctionsMap, GrammarType, Properties, TokenizerConfigToken, Tool, ToolType, Tools, + ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig, + Message, PrefillToken, Token, }; +use async_trait::async_trait; +use chat_template::ChatTemplate; use 
futures::future::try_join_all; -use minijinja::{Environment, ErrorKind, Template}; -use minijinja_contrib::pycompat; - -use serde_json::{json, Map, Value}; -use std::collections::HashMap; +use minijinja::ErrorKind; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use thiserror::Error; use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError}; @@ -26,12 +21,14 @@ use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::StreamExt; use tracing::instrument; -pub(crate) trait Scheduler { +#[async_trait] +pub trait Backend { fn schedule( &self, request: ValidGenerateRequest, - permit: OwnedSemaphorePermit, - ) -> Result; + ) -> Result>, InferError>; + + async fn health(&self, current_health: bool) -> bool; } /// Inference struct @@ -39,18 +36,20 @@ pub(crate) trait Scheduler { pub struct Infer { /// Validation validation: Validation, - /// Request scheduler - scheduler: Arc, + /// Request backend + backend: Arc, /// Chat template chat_template: Option, /// Inference limit limit_concurrent_requests: Arc, + /// Backend health + backend_health: Arc, } impl Infer { #[allow(clippy::too_many_arguments)] pub(crate) fn new( - scheduler: Arc, + backend: impl Backend + Send + Sync + 'static, validation: Validation, max_concurrent_requests: usize, tokenizer_config: HubTokenizerConfig, @@ -71,18 +70,22 @@ impl Infer { // Inference limit with a semaphore let semaphore = Arc::new(Semaphore::new(max_concurrent_requests)); + // Backend health + let backend_health = Arc::new(AtomicBool::new(false)); + Self { validation, - scheduler, + backend: Arc::new(backend), chat_template, limit_concurrent_requests: semaphore, + backend_health, } } /// Add a new request to the queue and return a stream of InferStreamResponse #[instrument(skip_all)] - pub(crate) async fn generate_stream( - &self, + pub(crate) async fn generate_stream<'a>( + &'a self, request: GenerateRequest, ) -> Result { // Limit concurrent requests by acquiring a permit from the semaphore 
@@ -103,7 +106,10 @@ impl Infer { err })?; - self.scheduler.schedule(valid_request, permit) + let input_length = valid_request.input_length; + let generation_stream = self.backend.schedule(valid_request)?; + + Ok((permit, input_length, generation_stream)) } /// Tokenizer the input @@ -155,7 +161,7 @@ impl Infer { let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0); // Create stream and keep semaphore permit as long as generate lives - let (_permit, _input_length, mut stream) = self.generate_stream(request).await?; + let (_permit, _input_length, stream) = self.generate_stream(request).await?; // Return values let mut result_prefill = Vec::new(); @@ -165,6 +171,8 @@ impl Infer { let mut result_start = None; let mut result_queued = None; + let mut stream = Box::pin(stream); + // Iterate on stream while let Some(response) = stream.next().await { match response? { @@ -256,207 +264,15 @@ impl Infer { let best_response = infer_responses.remove(max_index); Ok((best_response, infer_responses)) } -} -/// Raise a exception (custom function) used in the chat templates -fn raise_exception(err_text: String) -> Result { - Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text)) -} - -#[derive(Clone)] -struct ChatTemplate { - template: Template<'static, 'static>, - bos_token: Option, - eos_token: Option, - use_default_tool_template: bool, -} - -impl ChatTemplate { - fn new( - template: String, - bos_token: Option, - eos_token: Option, - ) -> Self { - let mut env = Box::new(Environment::new()); - // enable things like .strip() or .capitalize() - env.set_unknown_method_callback(pycompat::unknown_method_callback); - let template_str = template.into_boxed_str(); - env.add_function("raise_exception", raise_exception); - - // check if contains the tools variable within the template - let use_default_tool_template = - !template_str.as_ref().replace(' ', "").contains("{{tools}}"); - // leaking env and template_str as read-only, static resources for performance. 
- let template = Box::leak(env) - .template_from_str(Box::leak(template_str)) - .unwrap(); - - Self { - template, - bos_token: bos_token.map(|token| token.as_str().to_string()), - eos_token: eos_token.map(|token| token.as_str().to_string()), - use_default_tool_template, - } - } - - fn apply( - &self, - mut messages: Vec, - grammar_with_prompt: Option<(GrammarType, String)>, - ) -> Result { - if self.use_default_tool_template { - if let Some(last_message) = messages.last_mut() { - if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt { - last_message.content.push(MessageChunk::Text { - text: format!("\n---\n{}\n{}", tool_prompt, tools), - }); - } - } - } - - let messages: Vec = messages.into_iter().map(|c| c.into()).collect(); - - self.template - .render(ChatTemplateInputs { - messages, - bos_token: self.bos_token.as_deref(), - eos_token: self.eos_token.as_deref(), - add_generation_prompt: true, - tools: None, - tools_prompt: None, - }) - .map_err(InferError::TemplateError) - } -} - -pub struct ToolGrammar {} - -impl ToolGrammar { - // find a tool by name - fn find_tool_by_name(tools: &[Tool], name: &str) -> Result { - tools - .iter() - .find(|tool| tool.function.name == name) - .cloned() - .ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name))) - } - - pub fn apply( - tools: Option>, - tool_choice: ToolChoice, - ) -> Result, InferError> { - // if no tools are provided, we return None - let tools = match tools { - Some(tools) if !tools.is_empty() => tools, - _ => return Ok(None), - }; - - let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf); - - // if tools are provided and no tool_choice we default to the OneOf - let tools_to_use = match tool_choice { - ToolType::FunctionName(name) => { - vec![Self::find_tool_by_name(&tools, &name)?] - } - ToolType::Function { function } => { - vec![Self::find_tool_by_name(&tools, &function.name)?] 
- } - ToolType::OneOf => tools, - ToolType::NoTool => return Ok(None), - }; - - // adds the error notification function for LLM feedback if required - let mut text_response_properties = Map::new(); - text_response_properties.insert( - "error".to_string(), - serde_json::json!({ - "type": "string", - "description": "The error or issue to notify" - }), - ); - text_response_properties.insert( - "_name".to_string(), - serde_json::json!({ - "type": "string", - "const": "notify_error" - }), - ); - - let functions: HashMap = tools_to_use - .iter() - .map(|tool| { - let func = tool.function.clone(); - - // Clone the existing parameters, which are expected to be a JSON object - let mut params = if let Value::Object(params) = &func.arguments { - params.clone() - } else { - Map::new() - }; - - // Insert the function's description at the top level, outside of properties - params.insert( - "description".to_string(), - Value::String(func.description.clone().unwrap_or_default()), - ); - - // Ensure 'properties' exists and is an object - let properties = params - .entry("properties".to_string()) - .or_insert_with(|| json!({})) - .as_object_mut() - .unwrap(); - - // Insert the constant for the function name inside 'properties' - properties.insert( - "_name".to_string(), - json!({ - "type": "string", - "const": func.name.clone(), - // "description": "The name of the function" - }), - ); - - // Check if 'required' exists, and it is an array. If not, create an empty array. 
- let required = params - .entry("required".to_string()) - .or_insert_with(|| json!([])) - .as_array_mut() - .unwrap(); - - // Add 'name' to the 'required' array if it is not already present - if !required.iter().any(|r| r == "_name") { - required.push(json!("_name")); - } - - (func.name, Value::Object(params)) - }) - .chain([( - "notify_error".to_string(), - serde_json::json!({ - "properties": text_response_properties, - "required": ["error", "_name"], - "type": "object" - }), - )]) - .collect(); - - let tools = Tools { - functions_map: FunctionsMap { functions }, - properties: Properties { - function: tools_to_use - .iter() - .map(|tool| FunctionRef { - ref_path: format!("#/$functions/{}", tool.function.name.clone()), - }) - .chain(std::iter::once(FunctionRef { - ref_path: "#/$functions/notify_error".to_string(), - })) - .collect(), - }, - }; - - Ok(Some(tools)) + #[instrument(skip(self))] + pub(crate) async fn health(&self) -> bool { + let health = self + .backend + .health(self.backend_health.load(Ordering::SeqCst)) + .await; + self.backend_health.store(health, Ordering::SeqCst); + health } } @@ -468,15 +284,15 @@ pub(crate) type GenerateStreamResponse = ( ); #[derive(Debug)] -pub(crate) struct GeneratedText { - pub(crate) text: String, - pub(crate) generated_tokens: u32, - pub(crate) finish_reason: FinishReason, - pub(crate) seed: Option, +pub struct GeneratedText { + pub text: String, + pub generated_tokens: u32, + pub finish_reason: FinishReason, + pub seed: Option, } #[derive(Debug)] -pub(crate) enum InferStreamResponse { +pub enum InferStreamResponse { // Optional first message Prefill(Vec), // Intermediate messages diff --git a/router/src/infer/tool_grammar.rs b/router/src/infer/tool_grammar.rs new file mode 100644 index 00000000..05027f30 --- /dev/null +++ b/router/src/infer/tool_grammar.rs @@ -0,0 +1,135 @@ +use crate::infer::InferError; +use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolChoice, ToolType, Tools}; +use serde_json::{json, Map, 
Value}; +use std::collections::HashMap; + +pub(crate) struct ToolGrammar {} + +impl ToolGrammar { + // find a tool by name + fn find_tool_by_name(tools: &[Tool], name: &str) -> Result { + tools + .iter() + .find(|tool| tool.function.name == name) + .cloned() + .ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name))) + } + + pub fn apply( + tools: Option>, + tool_choice: ToolChoice, + ) -> Result, InferError> { + // if no tools are provided, we return None + let tools = match tools { + Some(tools) if !tools.is_empty() => tools, + _ => return Ok(None), + }; + + let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf); + + // if tools are provided and no tool_choice we default to the OneOf + let tools_to_use = match tool_choice { + ToolType::FunctionName(name) => { + vec![Self::find_tool_by_name(&tools, &name)?] + } + ToolType::Function { function } => { + vec![Self::find_tool_by_name(&tools, &function.name)?] + } + ToolType::OneOf => tools, + ToolType::NoTool => return Ok(None), + }; + + // adds the error notification function for LLM feedback if required + let mut text_response_properties = Map::new(); + text_response_properties.insert( + "error".to_string(), + serde_json::json!({ + "type": "string", + "description": "The error or issue to notify" + }), + ); + text_response_properties.insert( + "_name".to_string(), + serde_json::json!({ + "type": "string", + "const": "notify_error" + }), + ); + + let functions: HashMap = tools_to_use + .iter() + .map(|tool| { + let func = tool.function.clone(); + + // Clone the existing parameters, which are expected to be a JSON object + let mut params = if let Value::Object(params) = &func.arguments { + params.clone() + } else { + Map::new() + }; + + // Insert the function's description at the top level, outside of properties + params.insert( + "description".to_string(), + Value::String(func.description.clone().unwrap_or_default()), + ); + + // Ensure 'properties' exists and is an object + let 
properties = params + .entry("properties".to_string()) + .or_insert_with(|| json!({})) + .as_object_mut() + .unwrap(); + + // Insert the constant for the function name inside 'properties' + properties.insert( + "_name".to_string(), + json!({ + "type": "string", + "const": func.name.clone(), + // "description": "The name of the function" + }), + ); + + // Check if 'required' exists, and it is an array. If not, create an empty array. + let required = params + .entry("required".to_string()) + .or_insert_with(|| json!([])) + .as_array_mut() + .unwrap(); + + // Add 'name' to the 'required' array if it is not already present + if !required.iter().any(|r| r == "_name") { + required.push(json!("_name")); + } + + (func.name, Value::Object(params)) + }) + .chain([( + "notify_error".to_string(), + serde_json::json!({ + "properties": text_response_properties, + "required": ["error", "_name"], + "type": "object" + }), + )]) + .collect(); + + let tools = Tools { + functions_map: FunctionsMap { functions }, + properties: Properties { + function: tools_to_use + .iter() + .map(|tool| FunctionRef { + ref_path: format!("#/$functions/{}", tool.function.name.clone()), + }) + .chain(std::iter::once(FunctionRef { + ref_path: "#/$functions/notify_error".to_string(), + })) + .collect(), + }, + }; + + Ok(Some(tools)) + } +} diff --git a/router/src/infer/v2/mod.rs b/router/src/infer/v2/mod.rs index 8b4f6bab..6a91a433 100644 --- a/router/src/infer/v2/mod.rs +++ b/router/src/infer/v2/mod.rs @@ -1,4 +1,4 @@ mod queue; mod scheduler; -pub(crate) use scheduler::SchedulerV2; +pub(crate) use scheduler::BackendV2; diff --git a/router/src/infer/v2/scheduler.rs b/router/src/infer/v2/scheduler.rs index 97379bc5..3d6c36cf 100644 --- a/router/src/infer/v2/scheduler.rs +++ b/router/src/infer/v2/scheduler.rs @@ -1,7 +1,7 @@ /// Batching and inference logic use crate::infer::v2::queue::{Entry, Queue}; use crate::infer::{ - GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler, + 
Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, }; use crate::validation::ValidGenerateRequest; use crate::{FinishReason, PrefillToken, Token}; @@ -18,14 +18,14 @@ use tokio::time::Instant; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::{info_span, instrument, Instrument, Span}; -pub(crate) struct SchedulerV2 { +pub(crate) struct BackendV2 { /// Request queue queue: Queue, /// Notify batcher on queue appends batching_task_notifier: Arc, } -impl SchedulerV2 { +impl BackendV2 { #[allow(clippy::too_many_arguments)] pub(crate) fn new( client: ShardedClient, @@ -69,7 +69,7 @@ impl SchedulerV2 { } } -impl Scheduler for SchedulerV2 { +impl Backend for BackendV2 { #[instrument(skip_all)] fn schedule( &self, diff --git a/router/src/infer/v3/mod.rs b/router/src/infer/v3/mod.rs deleted file mode 100644 index f9effab8..00000000 --- a/router/src/infer/v3/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod block_allocator; -mod queue; -mod scheduler; - -pub(crate) use scheduler::SchedulerV3; diff --git a/router/src/lib.rs b/router/src/lib.rs index b6e0d09d..14bb8270 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1,11 +1,12 @@ /// Text Generation Inference Webserver pub mod config; -mod infer; +pub mod infer; pub mod server; -mod validation; +pub mod validation; #[cfg(feature = "kserve")] mod kserve; +pub mod logging; pub mod usage_stats; @@ -148,12 +149,13 @@ pub struct Info { pub model_id: String, #[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")] pub model_sha: Option, - #[schema(example = "torch.float16")] - pub model_dtype: String, - #[schema(example = "cuda")] - pub model_device_type: String, + // #[schema(example = "torch.float16")] + // pub model_dtype: String, + // #[schema(example = "cuda")] + // pub model_device_type: String, #[schema(nullable = true, example = "text-generation")] pub model_pipeline_tag: Option, + /// Router Parameters #[schema(example = "128")] pub 
max_concurrent_requests: usize, @@ -165,18 +167,11 @@ pub struct Info { pub max_input_tokens: usize, #[schema(example = "2048")] pub max_total_tokens: usize, - #[schema(example = "1.2")] - pub waiting_served_ratio: f32, - #[schema(example = "32000")] - pub max_batch_total_tokens: u32, - #[schema(example = "20")] - pub max_waiting_tokens: usize, - #[schema(nullable = true, example = "null")] - pub max_batch_size: Option, #[schema(example = "2")] pub validation_workers: usize, #[schema(example = "32")] pub max_client_batch_size: usize, + /// Router Info #[schema(example = "text-generation-router")] pub router: &'static str, @@ -1068,23 +1063,23 @@ impl From for GenerateRequest { #[derive(Debug, Serialize, ToSchema)] pub struct PrefillToken { #[schema(example = 0)] - id: u32, + pub id: u32, #[schema(example = "test")] - text: String, + pub text: String, #[schema(nullable = true, example = - 0.34)] - logprob: f32, + pub logprob: f32, } #[derive(Debug, Serialize, ToSchema, Clone)] pub struct Token { #[schema(example = 0)] - id: u32, + pub id: u32, #[schema(example = "test")] - text: String, + pub text: String, #[schema(nullable = true, example = - 0.34)] - logprob: f32, + pub logprob: f32, #[schema(example = "false")] - special: bool, + pub special: bool, } #[derive(Debug, Serialize, ToSchema)] @@ -1102,7 +1097,7 @@ pub struct SimpleToken { #[derive(Debug, Serialize, ToSchema)] #[serde(rename_all(serialize = "snake_case"))] #[schema(example = "Length")] -pub(crate) enum FinishReason { +pub enum FinishReason { #[schema(rename = "length")] Length, #[serde(rename = "eos_token")] diff --git a/router/src/logging.rs b/router/src/logging.rs new file mode 100644 index 00000000..5a98ef57 --- /dev/null +++ b/router/src/logging.rs @@ -0,0 +1,81 @@ +use opentelemetry::sdk::propagation::TraceContextPropagator; +use opentelemetry::sdk::trace; +use opentelemetry::sdk::trace::Sampler; +use opentelemetry::sdk::Resource; +use opentelemetry::{global, KeyValue}; +use 
opentelemetry_otlp::WithExportConfig; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; +use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer}; + +/// Init logging using env variables LOG_LEVEL and LOG_FORMAT: +/// - otlp_endpoint is an optional URL to an Open Telemetry collector +/// - otlp_service_name service name to appear in APM +/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO) +/// - LOG_FORMAT may be TEXT or JSON (default to TEXT) +/// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms) +pub fn init_logging(otlp_endpoint: Option, otlp_service_name: String, json_output: bool) { + let mut layers = Vec::new(); + + // STDOUT/STDERR layer + let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string()); + let fmt_layer = tracing_subscriber::fmt::layer() + .with_file(true) + .with_ansi(ansi) + .with_line_number(true); + + let fmt_layer = match json_output { + true => fmt_layer.json().flatten_event(true).boxed(), + false => fmt_layer.boxed(), + }; + layers.push(fmt_layer); + + // OpenTelemetry tracing layer + if let Some(otlp_endpoint) = otlp_endpoint { + global::set_text_map_propagator(TraceContextPropagator::new()); + + let tracer = opentelemetry_otlp::new_pipeline() + .tracing() + .with_exporter( + opentelemetry_otlp::new_exporter() + .tonic() + .with_endpoint(otlp_endpoint), + ) + .with_trace_config( + trace::config() + .with_resource(Resource::new(vec![KeyValue::new( + "service.name", + otlp_service_name, + )])) + .with_sampler(Sampler::AlwaysOn), + ) + .install_batch(opentelemetry::runtime::Tokio); + + if let Ok(tracer) = tracer { + layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed()); + init_tracing_opentelemetry::init_propagator().unwrap(); + }; + } + + // Filter events with LOG_LEVEL + let varname = "LOG_LEVEL"; + let env_filter = if let Ok(log_level) = std::env::var(varname) { + // Override to avoid simple logs to be 
spammed with tokio level informations + let log_level = match &log_level[..] { + "warn" => "text_generation_launcher=warn,text_generation_router=warn", + "info" => "text_generation_launcher=info,text_generation_router=info", + "debug" => "text_generation_launcher=debug,text_generation_router=debug", + log_level => log_level, + }; + EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .parse_lossy(log_level) + } else { + EnvFilter::new("info") + }; + + tracing_subscriber::registry() + .with(env_filter) + .with(layers) + .init(); +} diff --git a/router/src/main.rs b/router/src/main.rs.back similarity index 100% rename from router/src/main.rs rename to router/src/main.rs.back diff --git a/router/src/server.rs b/router/src/server.rs index 0fd5aade..ccbd1535 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1,14 +1,13 @@ /// HTTP Server logic use crate::config::Config; -use crate::infer::v2::SchedulerV2; -use crate::infer::v3::SchedulerV3; -use crate::infer::{HealthCheck, Scheduler}; -use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, ToolGrammar}; +use crate::infer::tool_grammar::ToolGrammar; +use crate::infer::{Backend, Infer, InferError, InferResponse, InferStreamResponse}; #[cfg(feature = "kserve")] use crate::kserve::{ kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer, kserve_model_metadata, kserve_model_metadata_ready, }; +use crate::usage_stats; use crate::validation::ValidationError; use crate::{ BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, GenerateParameters, @@ -27,7 +26,7 @@ use crate::{ use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType}; use async_stream::__private::AsyncStream; use axum::extract::Extension; -use axum::http::{HeaderMap, Method, StatusCode}; +use axum::http::{HeaderMap, HeaderValue, Method, StatusCode}; use axum::response::sse::{Event, KeepAlive, Sse}; use axum::response::{IntoResponse, 
Response}; use axum::routing::{get, post}; @@ -37,15 +36,18 @@ use futures::stream::StreamExt; use futures::stream::{FuturesOrdered, FuturesUnordered}; use futures::Stream; use futures::TryStreamExt; +use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo}; +use hf_hub::{Cache, Repo, RepoType}; use http::header::AUTHORIZATION; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use serde_json::Value; use std::convert::Infallible; -use std::net::SocketAddr; -use std::sync::atomic::AtomicBool; -use std::sync::Arc; -use text_generation_client::{v2, v3, ClientError, ShardInfo}; +use std::fs::File; +use std::io::BufReader; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::path::{Path, PathBuf}; use thiserror::Error; +use tokenizers::processors::template::TemplateProcessing; use tokenizers::Tokenizer; use tokio::select; use tokio::signal; @@ -124,12 +126,10 @@ responses( example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})), ) )] -#[instrument(skip(health))] +#[instrument(skip(infer))] /// Health check method -async fn health( - mut health: Extension, -) -> Result<(), (StatusCode, Json)> { - match health.check().await { +async fn health(infer: Extension) -> Result<(), (StatusCode, Json)> { + match infer.health().await { true => Ok(()), false => Err(( StatusCode::SERVICE_UNAVAILABLE, @@ -430,8 +430,9 @@ async fn generate_stream_internal( } else { match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await { // Keep permit as long as generate_stream lives - Ok((_permit, _input_length, mut response_stream)) => { + Ok((_permit, _input_length, response_stream)) => { let mut index = 0; + let mut response_stream = Box::pin(response_stream); // Server-Sent Event stream while let Some(response) = response_stream.next().await { index += 1; @@ -1396,262 +1397,456 @@ async fn metrics(prom_handle: Extension) -> String { #[derive(Clone, Debug)] pub(crate) struct ComputeType(String); +// OpenAPI 
documentation +#[derive(OpenApi)] +#[openapi( +paths( +health, +get_model_info, +compat_generate, +generate, +generate_stream, +chat_completions, +completions, +tokenize, +metrics, +), +components( +schemas( +Info, +CompatGenerateRequest, +GenerateRequest, +GrammarType, +ChatRequest, +Message, +MessageContent, +MessageChunk, +Url, +FunctionName, +OutputMessage, +TextMessage, +ToolCallMessage, +ToolCallDelta, +ChatCompletionComplete, +ChatCompletionChoice, +ChatCompletionDelta, +ChatCompletionChunk, +ChatCompletionLogprob, +ChatCompletionLogprobs, +ChatCompletionTopLogprob, +ChatCompletion, +CompletionRequest, +CompletionComplete, +Chunk, +Completion, +CompletionFinal, +Prompt, +GenerateParameters, +PrefillToken, +Token, +GenerateResponse, +TokenizeResponse, +SimpleToken, +BestOfSequence, +Details, +FinishReason, +StreamResponse, +StreamDetails, +ErrorResponse, +GrammarType, +Usage, +DeltaToolCall, +ToolType, +Tool, +ToolCall, +Function, +FunctionDefinition, +ToolChoice, +) +), +tags( +(name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API") +), +info( +title = "Text Generation Inference", +license( +name = "Apache 2.0", +url = "https://www.apache.org/licenses/LICENSE-2.0" +) +) +)] +pub struct ApiDoc; + +pub fn schema() -> ApiDoc { + ApiDoc +} + /// Serving method #[allow(clippy::too_many_arguments)] pub async fn run( - master_shard_uds_path: String, - model_info: HubModelInfo, - compat_return_full_text: bool, + backend: impl Backend + Send + Sync + 'static, max_concurrent_requests: usize, max_best_of: usize, max_stop_sequences: usize, max_top_n_tokens: u32, max_input_tokens: usize, max_total_tokens: usize, - waiting_served_ratio: f32, - max_batch_prefill_tokens: u32, - max_batch_total_tokens: Option, - max_waiting_tokens: usize, - max_batch_size: Option, - tokenizer: Option, - config: Option, validation_workers: usize, - addr: SocketAddr, - allow_origin: Option, api_key: Option, + tokenizer_name: String, + 
tokenizer_config_path: Option, + revision: Option, + hostname: String, + port: u16, + cors_allow_origin: Option>, ngrok: bool, _ngrok_authtoken: Option, _ngrok_edge: Option, - tokenizer_config: HubTokenizerConfig, - preprocessor_config: Option, - processor_config: HubProcessorConfig, messages_api_enabled: bool, - grammar_support: bool, + disable_grammar_support: bool, max_client_batch_size: usize, - print_schema_command: bool, + disable_usage_stats: bool, + disable_crash_reports: bool, ) -> Result<(), WebServerError> { - // OpenAPI documentation - #[derive(OpenApi)] - #[openapi( - paths( - health, - get_model_info, - compat_generate, - generate, - generate_stream, - chat_completions, - completions, - tokenize, - metrics, - ), - components( - schemas( - Info, - CompatGenerateRequest, - GenerateRequest, - GrammarType, - ChatRequest, - Message, - MessageContent, - MessageChunk, - Url, - FunctionName, - OutputMessage, - TextMessage, - ToolCallMessage, - ToolCallDelta, - ChatCompletionComplete, - ChatCompletionChoice, - ChatCompletionDelta, - ChatCompletionChunk, - ChatCompletionLogprob, - ChatCompletionLogprobs, - ChatCompletionTopLogprob, - ChatCompletion, - CompletionRequest, - CompletionComplete, - Chunk, - Completion, - CompletionFinal, - Prompt, - GenerateParameters, - PrefillToken, - Token, - GenerateResponse, - TokenizeResponse, - SimpleToken, - BestOfSequence, - Details, - FinishReason, - StreamResponse, - StreamDetails, - ErrorResponse, - GrammarType, - Usage, - DeltaToolCall, - ToolType, - Tool, - ToolCall, - Function, - FunctionDefinition, - ToolChoice, - ) - ), - tags( - (name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API") - ), - info( - title = "Text Generation Inference", - license( - name = "Apache 2.0", - url = "https://www.apache.org/licenses/LICENSE-2.0" - ) - ) - )] - struct ApiDoc; + // CORS allowed origins + // map to go inside the option and then map to parse from String to HeaderValue + // Finally, 
convert to AllowOrigin + let allow_origin: Option = cors_allow_origin.map(|cors_allow_origin| { + AllowOrigin::list( + cors_allow_origin + .iter() + .map(|origin| origin.parse::().unwrap()), + ) + }); - // Create state - if print_schema_command { - let api_doc = ApiDoc::openapi(); - let api_doc = serde_json::to_string_pretty(&api_doc).unwrap(); - println!("{}", api_doc); - std::process::exit(0); + // Parse Huggingface hub token + let authorization_token = std::env::var("HF_TOKEN") + .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN")) + .ok(); + + // Tokenizer instance + // This will only be used to validate payloads + let local_path = Path::new(&tokenizer_name); + + // Shared API builder initialization + let api_builder = || { + let mut builder = ApiBuilder::new() + .with_progress(false) + .with_token(authorization_token); + + if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") { + builder = builder.with_cache_dir(cache_dir.into()); + } + + builder + }; + + // Decide if we need to use the API based on the revision and local path + let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir(); + + // Initialize API if needed + #[derive(Clone)] + enum Type { + Api(Api), + Cache(Cache), + None, } - - // Open connection, get model info and warmup - let (scheduler, health_ext, shard_info, max_batch_total_tokens): ( - Arc, - HealthCheck, - ShardInfo, - u32, - ) = { - // Helper function to check both v2 and v3 - let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option| { - match max_supported_batch_total_tokens { - // Older models do not support automatic max-batch-total-tokens - None => { - let max_batch_total_tokens = max_batch_total_tokens.unwrap_or( - 16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)), - ); - tracing::warn!("Model does not support automatic max batch total tokens"); - Ok(max_batch_total_tokens) + let api = if use_api { + if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) { + let 
cache = std::env::var("HUGGINGFACE_HUB_CACHE") + .map_err(|_| ()) + .map(|cache_dir| Cache::new(cache_dir.into())) + .unwrap_or_else(|_| Cache::default()); + tracing::warn!("Offline mode active using cache defaults"); + Type::Cache(cache) + } else { + tracing::info!("Using the Hugging Face API"); + match api_builder().build() { + Ok(api) => Type::Api(api), + Err(_) => { + tracing::warn!("Unable to build the Hugging Face API"); + Type::None } - // Flash attention models return their max supported total tokens - Some(max_supported_batch_total_tokens) => { - // Warn if user added his own max-batch-total-tokens as we will ignore it - if max_batch_total_tokens.is_some() { - tracing::warn!( - "`--max-batch-total-tokens` is deprecated for Flash \ - Attention models." - ); - tracing::warn!( - "Inferred max batch total tokens: {max_supported_batch_total_tokens}" - ); - } - if max_total_tokens as u32 > max_supported_batch_total_tokens { - return Err(WebServerError::NotEnoughMemory(max_total_tokens)); - } - - Ok(max_supported_batch_total_tokens) - } - } - }; - - let generation_health = Arc::new(AtomicBool::new(false)); - - match v3::ShardedClient::connect_uds(master_shard_uds_path.clone()).await { - Ok(mut sharded_client) => { - // server is running on v3 - // Clear the cache; useful if the webserver rebooted - sharded_client - .clear_cache(None) - .await - .map_err(WebServerError::Cache)?; - // Get info from the shard - let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?; - - // Warmup model - tracing::info!("Warming up model"); - let max_batch_total_tokens = check_max_batch_total_tokens( - sharded_client - .warmup( - max_input_tokens as u32, - max_batch_prefill_tokens, - max_total_tokens as u32, - max_batch_size, - ) - .await - .map_err(WebServerError::Warmup)?, - )?; - - let health_ext = - HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone()); - let scheduler = Arc::new(SchedulerV3::new( - sharded_client, - 
waiting_served_ratio, - max_batch_prefill_tokens, - max_batch_total_tokens, - max_waiting_tokens, - max_batch_size, - shard_info.requires_padding, - shard_info.window_size, - shard_info.speculate, - generation_health, - )); - tracing::info!("Using scheduler V3"); - - (scheduler, health_ext, shard_info, max_batch_total_tokens) - } - Err(_) => { - let mut sharded_client = v2::ShardedClient::connect_uds(master_shard_uds_path) - .await - .map_err(WebServerError::Connection)?; - - // server is running on v2 - // Clear the cache; useful if the webserver rebooted - sharded_client - .clear_cache(None) - .await - .map_err(WebServerError::Cache)?; - // Get info from the shard - let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?; - - // Warmup model - tracing::info!("Warming up model"); - let max_batch_total_tokens = check_max_batch_total_tokens( - sharded_client - .warmup( - max_input_tokens as u32, - max_batch_prefill_tokens, - max_total_tokens as u32, - max_batch_size, - ) - .await - .map_err(WebServerError::Warmup)?, - )?; - - let health_ext = - HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone()); - let scheduler = Arc::new(SchedulerV2::new( - sharded_client, - waiting_served_ratio, - max_batch_prefill_tokens, - max_batch_total_tokens, - max_waiting_tokens, - max_batch_size, - shard_info.requires_padding, - shard_info.window_size, - shard_info.speculate, - generation_health, - )); - tracing::info!("Using scheduler V2"); - - (scheduler, health_ext, shard_info, max_batch_total_tokens) } } + } else { + Type::None }; - tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}"); + // Load tokenizer and model info + let ( + tokenizer_filename, + config_filename, + tokenizer_config_filename, + preprocessor_config_filename, + processor_config_filename, + model_info, + ) = match api { + Type::None => ( + Some(local_path.join("tokenizer.json")), + Some(local_path.join("config.json")), + 
Some(local_path.join("tokenizer_config.json")), + Some(local_path.join("preprocessor_config.json")), + Some(local_path.join("processor_config.json")), + None, + ), + Type::Api(api) => { + let api_repo = api.repo(Repo::with_revision( + tokenizer_name.to_string(), + RepoType::Model, + revision.clone().unwrap_or_else(|| "main".to_string()), + )); + + let tokenizer_filename = match api_repo.get("tokenizer.json").await { + Ok(tokenizer_filename) => Some(tokenizer_filename), + Err(_) => get_base_tokenizer(&api, &api_repo).await, + }; + let config_filename = api_repo.get("config.json").await.ok(); + let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok(); + let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok(); + let processor_config_filename = api_repo.get("processor_config.json").await.ok(); + + let model_info = if let Some(model_info) = get_hub_model_info(&api_repo).await { + Some(model_info) + } else { + tracing::warn!("Could not retrieve model info from the Hugging Face hub."); + None + }; + ( + tokenizer_filename, + config_filename, + tokenizer_config_filename, + preprocessor_config_filename, + processor_config_filename, + model_info, + ) + } + Type::Cache(cache) => { + let repo = cache.repo(Repo::with_revision( + tokenizer_name.to_string(), + RepoType::Model, + revision.clone().unwrap_or_else(|| "main".to_string()), + )); + ( + repo.get("tokenizer.json"), + repo.get("config.json"), + repo.get("tokenizer_config.json"), + repo.get("preprocessor_config.json"), + repo.get("processor_config.json"), + None, + ) + } + }; + + // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'. 
+ let tokenizer_config: Option = if let Some(filename) = tokenizer_config_path + { + HubTokenizerConfig::from_file(filename) + } else { + tokenizer_config_filename.and_then(HubTokenizerConfig::from_file) + }; + let tokenizer_config = tokenizer_config.unwrap_or_else(|| { + tracing::warn!("Could not find tokenizer config locally and no API specified"); + HubTokenizerConfig::default() + }); + + let tokenizer: Option = tokenizer_filename.and_then(|filename| { + let mut tokenizer = Tokenizer::from_file(filename).ok(); + if let Some(tokenizer) = &mut tokenizer { + if let Some(class) = &tokenizer_config.tokenizer_class { + if class == "LlamaTokenizer" || class == "LlamaTokenizerFast"{ + if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) { + tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205"); + tokenizer.with_post_processor(post_processor); + } + } + } + } + tokenizer + }); + + let config: Option = config_filename.and_then(|filename| { + std::fs::read_to_string(filename) + .ok() + .as_ref() + .and_then(|c| { + let config: Result = serde_json::from_str(c); + if let Err(err) = &config { + tracing::warn!("Could not parse config {err:?}"); + } + config.ok() + }) + }); + let model_info = model_info.unwrap_or_else(|| HubModelInfo { + model_id: tokenizer_name.to_string(), + sha: None, + pipeline_tag: None, + }); + + let processor_config = processor_config_filename + .and_then(HubProcessorConfig::from_file) + .unwrap_or_default(); + + let preprocessor_config: Option = + preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file); + + tracing::info!("Using config {config:?}"); + if tokenizer.is_none() { + tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}"); + tracing::warn!("Rust input length 
validation and truncation is disabled"); + } + + // Only send usage stats when TGI is run in container and the function returns Some + let is_container = matches!(usage_stats::is_container(), Ok(true)); + + let user_agent = if !disable_usage_stats && is_container { + let reduced_args = usage_stats::Args::new( + config.clone(), + tokenizer_config.tokenizer_class.clone(), + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + // waiting_served_ratio, + // max_batch_prefill_tokens, + // max_batch_total_tokens, + // max_waiting_tokens, + // max_batch_size, + revision.clone(), + validation_workers, + messages_api_enabled, + disable_grammar_support, + max_client_batch_size, + disable_usage_stats, + disable_crash_reports, + ); + Some(usage_stats::UserAgent::new(reduced_args)) + } else { + None + }; + + if let Some(ref ua) = user_agent { + let start_event = + usage_stats::UsageStatsEvent::new(ua.clone(), usage_stats::EventType::Start, None); + tokio::spawn(async move { + start_event.send().await; + }); + }; + let compat_return_full_text = match &model_info.pipeline_tag { + None => { + tracing::warn!("no pipeline tag found for model {tokenizer_name}"); + true + } + Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation", + }; + let result = start( + backend, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + validation_workers, + api_key, + config, + (tokenizer, tokenizer_config), + (preprocessor_config, processor_config), + hostname, + port, + ngrok, + _ngrok_authtoken, + _ngrok_edge, + messages_api_enabled, + disable_grammar_support, + max_client_batch_size, + model_info, + compat_return_full_text, + allow_origin, + ) + .await; + + if let Some(ua) = user_agent { + match result { + Ok(_) => { + let stop_event = usage_stats::UsageStatsEvent::new( + ua.clone(), + usage_stats::EventType::Stop, + None, + ); + 
stop_event.send().await; + Ok(()) + } + Err(e) => { + if !disable_crash_reports { + let error_event = usage_stats::UsageStatsEvent::new( + ua.clone(), + usage_stats::EventType::Error, + Some(e.to_string()), + ); + error_event.send().await; + } else { + let unknow_error_event = usage_stats::UsageStatsEvent::new( + ua.clone(), + usage_stats::EventType::Error, + Some("unknow_error".to_string()), + ); + unknow_error_event.send().await; + } + Err(e) + } + } + } else { + result + } +} + +#[allow(clippy::too_many_arguments)] +async fn start( + backend: impl Backend + Send + Sync + 'static, + max_concurrent_requests: usize, + max_best_of: usize, + max_stop_sequences: usize, + max_top_n_tokens: u32, + max_input_tokens: usize, + max_total_tokens: usize, + validation_workers: usize, + api_key: Option, + config: Option, + (tokenizer, tokenizer_config): (Option, HubTokenizerConfig), + (preprocessor_config, processor_config): (Option, HubProcessorConfig), + hostname: String, + port: u16, + ngrok: bool, + _ngrok_authtoken: Option, + _ngrok_edge: Option, + messages_api_enabled: bool, + disable_grammar_support: bool, + max_client_batch_size: usize, + model_info: HubModelInfo, + compat_return_full_text: bool, + allow_origin: Option, +) -> Result<(), WebServerError> { + // Determine the server port based on the feature and environment variable. 
+ let port = if cfg!(feature = "google") { + std::env::var("AIP_HTTP_PORT") + .map(|aip_http_port| aip_http_port.parse::().unwrap_or(port)) + .unwrap_or(port) + } else { + port + }; + + let addr = match hostname.parse() { + Ok(ip) => SocketAddr::new(ip, port), + Err(_) => { + tracing::warn!("Invalid hostname, defaulting to 0.0.0.0"); + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port) + } + }; + + // Create state let validation = Validation::new( validation_workers, tokenizer, @@ -1662,11 +1857,11 @@ pub async fn run( max_top_n_tokens, max_input_tokens, max_total_tokens, - grammar_support, + disable_grammar_support, ); let infer = Infer::new( - scheduler, + backend, validation, max_concurrent_requests, tokenizer_config, @@ -1703,8 +1898,8 @@ pub async fn run( let batch_size_matcher = Matcher::Full(String::from("tgi_batch_next_size")); let batch_size_buckets: Vec = (0..1024).map(|x| (x + 1) as f64).collect(); // Speculated tokens buckets - let skipped_matcher = Matcher::Full(String::from("tgi_request_skipped_tokens")); - let skipped_buckets: Vec = (0..shard_info.speculate + 1).map(|x| x as f64).collect(); + // let skipped_matcher = Matcher::Full(String::from("tgi_request_skipped_tokens")); + // let skipped_buckets: Vec = (0..shard_info.speculate + 1).map(|x| x as f64).collect(); // Prometheus handler let builder = PrometheusBuilder::new() @@ -1717,9 +1912,9 @@ pub async fn run( .set_buckets_for_metric(max_new_tokens_matcher, &max_new_tokens_buckets) .unwrap() .set_buckets_for_metric(batch_size_matcher, &batch_size_buckets) - .unwrap() - .set_buckets_for_metric(skipped_matcher, &skipped_buckets) .unwrap(); + // .set_buckets_for_metric(skipped_matcher, &skipped_buckets) + // .unwrap(); let prom_handle = builder .install_recorder() .expect("failed to install metrics recorder"); @@ -1735,18 +1930,18 @@ pub async fn run( let info = Info { model_id: model_info.model_id, model_sha: model_info.sha, - model_dtype: shard_info.dtype, - model_device_type: 
shard_info.device_type, + // model_dtype: shard_info.dtype, + // model_device_type: shard_info.device_type, model_pipeline_tag: model_info.pipeline_tag, max_concurrent_requests, max_best_of, max_stop_sequences, max_input_tokens, max_total_tokens, - waiting_served_ratio, - max_batch_total_tokens, - max_waiting_tokens, - max_batch_size, + // waiting_served_ratio, + // max_batch_total_tokens, + // max_waiting_tokens, + // max_batch_size, validation_workers, max_client_batch_size, router: env!("CARGO_PKG_NAME"), @@ -1907,7 +2102,6 @@ pub async fn run( // add layers after routes app = app .layer(Extension(info)) - .layer(Extension(health_ext.clone())) .layer(Extension(compat_return_full_text)) .layer(Extension(infer)) .layer(Extension(compute_type)) @@ -1945,6 +2139,68 @@ pub async fn run( Ok(()) } +/// get model info from the Huggingface Hub +pub async fn get_hub_model_info(api: &ApiRepo) -> Option { + let response = api.info_request().send().await.ok()?; + + if response.status().is_success() { + let hub_model_info: HubModelInfo = + serde_json::from_str(&response.text().await.ok()?).ok()?; + if let Some(sha) = &hub_model_info.sha { + tracing::info!( + "Serving revision {sha} of model {}", + hub_model_info.model_id + ); + } + Some(hub_model_info) + } else { + None + } +} + +/// get base tokenizer +pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option { + let config_filename = api_repo.get("config.json").await.ok()?; + + // Open the file in read-only mode with buffer. + let file = File::open(config_filename).ok()?; + let reader = BufReader::new(file); + + // Read the JSON contents of the file as an instance of `User`. 
+ let config: serde_json::Value = serde_json::from_reader(reader).ok()?; + + if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") { + let api_base_repo = api.repo(Repo::with_revision( + base_model_id.to_string(), + RepoType::Model, + "main".to_string(), + )); + + api_base_repo.get("tokenizer.json").await.ok() + } else { + None + } +} + +/// get tokenizer_config from the Huggingface Hub +pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option { + let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?; + + // Open the file in read-only mode with buffer. + let file = File::open(tokenizer_config_filename).ok()?; + let reader = BufReader::new(file); + + // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'. + let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader) + .map_err(|e| { + tracing::warn!("Unable to parse tokenizer config: {}", e); + e + }) + .ok()?; + + Some(tokenizer_config) +} + /// Shutdown signal handler async fn shutdown_signal() { let ctrl_c = async { @@ -2008,16 +2264,77 @@ impl From for Event { #[derive(Debug, Error)] pub enum WebServerError { - #[error("Unable to connect to the Python model shards: {0}")] - Connection(ClientError), - #[error("Unable to clear the Python model shards cache: {0}")] - Cache(ClientError), - #[error("Unable to get the Python model shards info: {0}")] - Info(ClientError), - #[error("Unable to warmup the Python model shards: {0}")] - Warmup(ClientError), - #[error("Not enough memory to handle `max_total_tokens={0}`")] - NotEnoughMemory(usize), #[error("Axum error: {0}")] Axum(#[from] axum::BoxError), } + +/// Create a post_processor for the LlamaTokenizer +fn create_post_processor( + tokenizer: &Tokenizer, + tokenizer_config: &HubTokenizerConfig, +) -> Result { + let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true); + let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false); + + let 
bos_token = tokenizer_config.bos_token.as_ref(); + let eos_token = tokenizer_config.eos_token.as_ref(); + + if add_bos_token && bos_token.is_none() { + panic!("add_bos_token = true but bos_token is None"); + } + + if add_eos_token && eos_token.is_none() { + panic!("add_eos_token = true but eos_token is None"); + } + + let mut single = Vec::new(); + let mut pair = Vec::new(); + let mut special_tokens = Vec::new(); + + if add_bos_token { + if let Some(bos) = bos_token { + let bos_token_id = tokenizer + .token_to_id(bos.as_str()) + .expect("Should have found the bos token id"); + special_tokens.push((bos.as_str(), bos_token_id)); + single.push(format!("{}:0", bos.as_str())); + pair.push(format!("{}:0", bos.as_str())); + } + } + + single.push("$A:0".to_string()); + pair.push("$A:0".to_string()); + + if add_eos_token { + if let Some(eos) = eos_token { + let eos_token_id = tokenizer + .token_to_id(eos.as_str()) + .expect("Should have found the eos token id"); + special_tokens.push((eos.as_str(), eos_token_id)); + single.push(format!("{}:0", eos.as_str())); + pair.push(format!("{}:0", eos.as_str())); + } + } + + if add_bos_token { + if let Some(bos) = bos_token { + pair.push(format!("{}:1", bos.as_str())); + } + } + + pair.push("$B:1".to_string()); + + if add_eos_token { + if let Some(eos) = eos_token { + pair.push(format!("{}:1", eos.as_str())); + } + } + + let post_processor = TemplateProcessing::builder() + .try_single(single)? + .try_pair(pair)? 
+ .special_tokens(special_tokens) + .build()?; + + Ok(post_processor) +} diff --git a/router/src/usage_stats.rs b/router/src/usage_stats.rs index 8559ae90..fa9f3637 100644 --- a/router/src/usage_stats.rs +++ b/router/src/usage_stats.rs @@ -78,11 +78,11 @@ pub struct Args { max_top_n_tokens: u32, max_input_tokens: usize, max_total_tokens: usize, - waiting_served_ratio: f32, - max_batch_prefill_tokens: u32, - max_batch_total_tokens: Option, - max_waiting_tokens: usize, - max_batch_size: Option, + // waiting_served_ratio: f32, + // max_batch_prefill_tokens: u32, + // max_batch_total_tokens: Option, + // max_waiting_tokens: usize, + // max_batch_size: Option, revision: Option, validation_workers: usize, messages_api_enabled: bool, @@ -103,11 +103,11 @@ impl Args { max_top_n_tokens: u32, max_input_tokens: usize, max_total_tokens: usize, - waiting_served_ratio: f32, - max_batch_prefill_tokens: u32, - max_batch_total_tokens: Option, - max_waiting_tokens: usize, - max_batch_size: Option, + // waiting_served_ratio: f32, + // max_batch_prefill_tokens: u32, + // max_batch_total_tokens: Option, + // max_waiting_tokens: usize, + // max_batch_size: Option, revision: Option, validation_workers: usize, messages_api_enabled: bool, @@ -125,11 +125,11 @@ impl Args { max_top_n_tokens, max_input_tokens, max_total_tokens, - waiting_served_ratio, - max_batch_prefill_tokens, - max_batch_total_tokens, - max_waiting_tokens, - max_batch_size, + // waiting_served_ratio, + // max_batch_prefill_tokens, + // max_batch_total_tokens, + // max_waiting_tokens, + // max_batch_size, revision, validation_workers, messages_api_enabled, diff --git a/router/src/validation.rs b/router/src/validation.rs index 1913a2ce..3d1a4103 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -5,13 +5,12 @@ use crate::{ GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor, }; use base64::{engine::general_purpose::STANDARD, Engine}; -use image::{io::Reader as 
ImageReader, ImageFormat}; +use image::{ImageFormat, ImageReader}; use jsonschema::{Draft, JSONSchema}; use rand::{thread_rng, Rng}; use serde_json::Value; use std::io::Cursor; use std::iter; -use text_generation_client::{Chunk, Image, InputChunk}; use thiserror::Error; use tokenizers::tokenizer::Tokenizer; use tokio::sync::mpsc; @@ -96,7 +95,7 @@ impl Validation { &self, inputs: String, truncate: Option, - ) -> Result)>, ValidationError> { + ) -> Result)>, ValidationError> { // If we have a fast tokenizer if let Some(sender) = &self.sender { // Create response channel @@ -122,7 +121,7 @@ impl Validation { inputs: String, truncate: Option, max_new_tokens: Option, - ) -> Result<(Vec, usize, u32), ValidationError> { + ) -> Result<(Vec, usize, u32), ValidationError> { // If we have a fast tokenizer if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? { // Create response channel @@ -181,11 +180,7 @@ impl Validation { input_length = input_length.saturating_sub(max_new_tokens as usize); } - Ok(( - vec![Chunk::Text(inputs).into()], - input_length, - max_new_tokens, - )) + Ok((vec![Chunk::Text(inputs)], input_length, max_new_tokens)) } } @@ -589,7 +584,7 @@ fn prepare_input( tokenizer: &Tokenizer, config: Option<&Config>, preprocessor_config: Option<&HubPreprocessorConfig>, -) -> Result<(tokenizers::Encoding, Vec), ValidationError> { +) -> Result<(tokenizers::Encoding, Vec), ValidationError> { use Config::*; static RE: Lazy = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap()); let (tokenizer_query, input_chunks) = match config { @@ -601,16 +596,16 @@ fn prepare_input( let chunk_start = chunk.start(); let chunk_end = chunk.end(); if chunk_start != start { - input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into()); + input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string())); tokenizer_query.push_str(&inputs[start..chunk_start]); } let (data, mimetype, height, width) = 
fetch_image(&inputs[chunk_start..chunk_end])?; - input_chunks.push(Chunk::Image(Image { data, mimetype }).into()); + input_chunks.push(Chunk::Image(Image { data, mimetype })); tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width)); start = chunk_end; } if start != inputs.len() { - input_chunks.push(Chunk::Text(inputs[start..].to_string()).into()); + input_chunks.push(Chunk::Text(inputs[start..].to_string())); tokenizer_query.push_str(&inputs[start..]); } @@ -618,7 +613,7 @@ fn prepare_input( (tokenizer_query, input_chunks) } - _ => (inputs.clone(), vec![Chunk::Text(inputs).into()]), + _ => (inputs.clone(), vec![Chunk::Text(inputs)]), }; // Get the number of tokens in the input @@ -631,18 +626,51 @@ fn prepare_input( type TokenizerRequest = ( (String, Option), - oneshot::Sender), ValidationError>>, + oneshot::Sender), ValidationError>>, Span, ); +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct Image { + pub data: Vec, + pub mimetype: String, +} + +#[derive(Debug, Clone, Eq, PartialEq)] +pub enum Chunk { + Text(String), + Image(Image), +} + +/// Convert input chunks to a stringly-typed input for backwards +/// compat for backends that haven't implemented chunked inputs. +pub trait ChunksToString { + /// Convert chunks to string. 
+ fn chunks_to_string(&self) -> String; +} + +impl ChunksToString for Vec { + fn chunks_to_string(&self) -> String { + let mut output = String::new(); + self.iter().for_each(|c| match &c { + Chunk::Text(text) => output.push_str(text), + Chunk::Image(Image { data, mimetype }) => { + let encoded = STANDARD.encode(data); + output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded)) + } + }); + output + } +} + #[derive(Debug, Clone)] -pub(crate) enum ValidGrammar { +pub enum ValidGrammar { Json(String), Regex(String), } #[derive(Debug, Clone)] -pub(crate) struct ValidParameters { +pub struct ValidParameters { /// / exponential scaling output probability distribution pub temperature: f32, /// / restricting to the k highest probability elements @@ -666,7 +694,7 @@ pub(crate) struct ValidParameters { } #[derive(Debug, Clone)] -pub(crate) struct ValidStoppingParameters { +pub struct ValidStoppingParameters { /// / Maximum number of generated tokens pub max_new_tokens: u32, /// / Optional stopping sequences @@ -677,8 +705,8 @@ pub(crate) struct ValidStoppingParameters { } #[derive(Debug, Clone)] -pub(crate) struct ValidGenerateRequest { - pub inputs: Vec, +pub struct ValidGenerateRequest { + pub inputs: Vec, pub input_length: u32, pub truncate: u32, pub decoder_input_details: bool, @@ -750,6 +778,8 @@ pub enum ValidationError { InvalidImageContent(String), #[error("Could not fetch image: {0}")] FailedFetchImage(#[from] reqwest::Error), + #[error("{0} modality is not supported")] + UnsupportedModality(&'static str), } #[cfg(test)] diff --git a/update_doc.py b/update_doc.py index bfa7e4e9..428d4452 100644 --- a/update_doc.py +++ b/update_doc.py @@ -167,22 +167,24 @@ def check_openapi(check: bool): else: os.rename(tmp_filename, filename) print("OpenAPI documentation updated.") - errors = subprocess.run( + p = subprocess.run( [ - "swagger-cli", + "redocly", # allow for trailing whitespace since it's not significant # and the precommit hook will remove it - 
"validate", + "lint", filename, ], capture_output=True, - ).stderr.decode("utf-8") + ) + errors = p.stderr.decode("utf-8") # The openapi specs fails on `exclusive_minimum` which is expected to be a boolean where # utoipa outputs a value instead: https://github.com/juhaku/utoipa/issues/969 - if not errors.startswith("Swagger schema validation failed."): + print(errors) + if p.returncode != 0: print(errors) raise Exception( - f"OpenAPI documentation is invalid, `swagger-cli validate` showed some error:\n {errors}" + f"OpenAPI documentation is invalid, `redocly lint {filename}` showed some error:\n {errors}" ) return True