diff --git a/.devcontainer/Dockerfile.trtllm b/.devcontainer/Dockerfile.trtllm new file mode 100644 index 00000000..e69de29b diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 00000000..e69de29b diff --git a/.dockerignore b/.dockerignore index c69283ec..0b495924 100644 --- a/.dockerignore +++ b/.dockerignore @@ -2,3 +2,6 @@ aml target server/transformers server/flash-attention +cmake-build-debug/ +cmake-build-release/ +Dockerfile* diff --git a/.gitignore b/.gitignore index e9ad1808..4270a1ae 100644 --- a/.gitignore +++ b/.gitignore @@ -3,9 +3,12 @@ target router/tokenizer.json *__pycache__* +backends/v2/src/client/pb +backends/v3/src/client/pb + # ROCm auto-generated files *.hip -server/exllamav2_kernels/exllamav2_kernels/hip/ +server/exllamav2 server/exllama_kernels/exllama_kernels/hip/ server/exllama_kernels/exllama_kernels/hip_func/ *_hip.cuh @@ -14,3 +17,7 @@ server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp data/ load_tests/*.json +server/fbgemmm + +.direnv/ +.venv/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 45bc07a5..0c8b6885 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ repos: - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace - exclude: docs/source/basic_tutorials/launcher.md + exclude: docs/source/reference/launcher.md - repo: https://github.com/psf/black rev: 24.2.0 hooks: @@ -13,6 +13,11 @@ repos: - repo: https://github.com/doublify/pre-commit-rust rev: v1.0 hooks: - - id: fmt - id: cargo-check + - id: fmt - id: clippy +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.3.0 + hooks: + - id: ruff + args: [--fix, --exit-non-zero-on-fix] diff --git a/.redocly.lint-ignore.yaml b/.redocly.lint-ignore.yaml new file mode 100644 index 00000000..fb02c00f --- /dev/null +++ b/.redocly.lint-ignore.yaml @@ -0,0 +1,82 @@ +# This file instructs Redocly's linter to ignore the rules contained for specific parts of your 
API. +# See https://redoc.ly/docs/cli/ for more information. +docs/openapi.json: + no-empty-servers: + - '#/openapi' + spec: + - >- + #/components/schemas/GenerateParameters/properties/best_of/exclusiveMinimum + - >- + #/components/schemas/GenerateParameters/properties/frequency_penalty/exclusiveMinimum + - '#/components/schemas/GenerateParameters/properties/grammar/nullable' + - >- + #/components/schemas/GenerateParameters/properties/repetition_penalty/exclusiveMinimum + - '#/components/schemas/GenerateParameters/properties/seed/exclusiveMinimum' + - >- + #/components/schemas/GenerateParameters/properties/temperature/exclusiveMinimum + - '#/components/schemas/GenerateParameters/properties/top_k/exclusiveMinimum' + - >- + #/components/schemas/GenerateParameters/properties/top_n_tokens/exclusiveMinimum + - '#/components/schemas/GenerateParameters/properties/top_p/exclusiveMinimum' + - >- + #/components/schemas/GenerateParameters/properties/typical_p/exclusiveMinimum + - '#/components/schemas/GenerateResponse/properties/details/nullable' + - '#/components/schemas/StreamResponse/properties/details/nullable' + - '#/components/schemas/ChatRequest/properties/response_format/nullable' + - '#/components/schemas/ChatRequest/properties/stream_options/nullable' + - '#/components/schemas/ChatRequest/properties/tool_choice/nullable' + - '#/components/schemas/ToolChoice/nullable' + - '#/components/schemas/ChatCompletionComplete/properties/logprobs/nullable' + - '#/components/schemas/ChatCompletionChunk/properties/usage/nullable' + - '#/components/schemas/ChatCompletionChoice/properties/logprobs/nullable' + no-invalid-media-type-examples: + - '#/paths/~1/post/responses/422/content/application~1json/example' + - '#/paths/~1/post/responses/424/content/application~1json/example' + - '#/paths/~1/post/responses/429/content/application~1json/example' + - '#/paths/~1/post/responses/500/content/application~1json/example' + - 
'#/paths/~1generate/post/responses/422/content/application~1json/example' + - '#/paths/~1generate/post/responses/424/content/application~1json/example' + - '#/paths/~1generate/post/responses/429/content/application~1json/example' + - '#/paths/~1generate/post/responses/500/content/application~1json/example' + - >- + #/paths/~1generate_stream/post/responses/422/content/text~1event-stream/example + - >- + #/paths/~1generate_stream/post/responses/424/content/text~1event-stream/example + - >- + #/paths/~1generate_stream/post/responses/429/content/text~1event-stream/example + - >- + #/paths/~1generate_stream/post/responses/500/content/text~1event-stream/example + - '#/paths/~1tokenize/post/responses/404/content/application~1json/example' + - >- + #/paths/~1v1~1chat~1completions/post/responses/422/content/application~1json/example + - >- + #/paths/~1v1~1chat~1completions/post/responses/424/content/application~1json/example + - >- + #/paths/~1v1~1chat~1completions/post/responses/429/content/application~1json/example + - >- + #/paths/~1v1~1chat~1completions/post/responses/500/content/application~1json/example + - >- + #/paths/~1v1~1completions/post/responses/422/content/application~1json/example + - >- + #/paths/~1v1~1completions/post/responses/424/content/application~1json/example + - >- + #/paths/~1v1~1completions/post/responses/429/content/application~1json/example + - >- + #/paths/~1v1~1completions/post/responses/500/content/application~1json/example + operation-4xx-response: + - '#/paths/~1health/get/responses' + - '#/paths/~1info/get/responses' + - '#/paths/~1metrics/get/responses' + no-unused-components: + - '#/components/schemas/Completion' + security-defined: + - '#/paths/~1/post' + - '#/paths/~1generate/post' + - '#/paths/~1generate_stream/post' + - '#/paths/~1health/get' + - '#/paths/~1info/get' + - '#/paths/~1metrics/get' + - '#/paths/~1tokenize/post' + - '#/paths/~1v1~1chat~1completions/post' + - '#/paths/~1v1~1completions/post' + - '#/paths/~1v1~1models/get' 
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 00000000..b23f3150 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,133 @@ + +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, caste, color, religion, or sexual +identity and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the overall + community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or advances of + any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email address, + without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, 
offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +feedback@huggingface.co. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series of +actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. 
This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or permanent +ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within the +community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.1, available at +[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. + +Community Impact Guidelines were inspired by +[Mozilla's code of conduct enforcement ladder][Mozilla CoC]. + +For answers to common questions about this code of conduct, see the FAQ at +[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at +[https://www.contributor-covenant.org/translations][translations]. 
+ +[homepage]: https://www.contributor-covenant.org +[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html +[Mozilla CoC]: https://github.com/mozilla/diversity +[FAQ]: https://www.contributor-covenant.org/faq +[translations]: https://www.contributor-covenant.org/translations diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000..d541e47f --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,120 @@ + + +# Contribute to text-generation-inference + +Everyone is welcome to contribute, and we value everybody's contribution. Code +contributions are not the only way to help the community. Answering questions, helping +others, and improving the documentation are also immensely valuable. + +It also helps us if you spread the word! Reference the library in blog posts +about the awesome projects it made possible, shout out on Twitter every time it has +helped you, or simply ⭐️ the repository to say thank you. + +However you choose to contribute, please be mindful and respect our +[code of conduct](https://github.com/huggingface/text-generation-inference/blob/main/CODE_OF_CONDUCT.md). + +**This guide was heavily inspired by the awesome [scikit-learn guide to contributing](https://github.com/scikit-learn/scikit-learn/blob/main/CONTRIBUTING.md).** + +## Ways to contribute + +There are several ways you can contribute to text-generation-inference. + +* Fix outstanding issues with the existing code. +* Submit issues related to bugs or desired new features. +* Contribute to the examples or to the documentation. + +> All contributions are equally valuable to the community. 🥰 + +## Fixing outstanding issues + +If you notice an issue with the existing code and have a fix in mind, feel free to [start contributing](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request) and open +a Pull Request! 
+ +## Submitting a bug-related issue or feature request + +Do your best to follow these guidelines when submitting a bug-related issue or a feature +request. It will make it easier for us to come back to you quickly and with good +feedback. + +### Did you find a bug? + +The text-generation-inference library is robust and reliable thanks to users who report the problems they encounter. + +Before you report an issue, we would really appreciate it if you could **make sure the bug was not +already reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the +library itself, and not your code. + +Once you've confirmed the bug hasn't already been reported, please include the following information in your issue so +we can quickly resolve it: + +* Your **OS type and version**, as well as your environment versions (versions of Rust, Python, and dependencies). +* A short, self-contained code snippet that allows us to reproduce the bug. +* The *full* traceback if an exception is raised. +* Attach any additional information, like screenshots, you think may help. + +To get the OS and software versions automatically, you can re-run the launcher with the `--env` flag: + +```bash +text-generation-launcher --env +``` + +This will precede the launch of the model with the information about your environment. We recommend pasting +that in your issue report. + +### Do you want a new feature? + +If there is a new feature you'd like to see in text-generation-inference, please open an issue and describe: + +1. What is the *motivation* behind this feature? Is it related to a problem or frustration with the library? Is it + a feature related to something you need for a project? Is it something you worked on and think it could benefit + the community? + + Whatever it is, we'd love to hear about it! + +2. Describe your requested feature in as much detail as possible. 
The more you can tell us about it, the better + we'll be able to help you. +3. Provide a *code snippet* that demonstrates the feature's usage. +4. If the feature is related to a paper, please include a link. + +If your issue is well written, we're already 80% of the way there by the time you create it. + +We have added [templates](https://github.com/huggingface/text-generation-inference/tree/main/.github/ISSUE_TEMPLATE) +to help you get started with your issue. + +## Do you want to implement a new model? + +New models are constantly released and if you want to implement a new model, please provide the following information: + +* A short description of the model and a link to the paper. +* Link to the implementation if it is open-sourced. +* Link to the model weights if they are available. + +If you are willing to contribute the model yourself, let us know so we can help you add it to text-generation-inference! + +## Do you want to add documentation? + +We're always looking for improvements to the documentation that make it clearer and more accurate. Please let us know +how the documentation can be improved, for example by pointing out typos or any content that is missing, unclear or inaccurate. We'll be +happy to make the changes or help you make a contribution if you're interested! + +## I want to become a maintainer of the project. How do I get there? + +TGI is a project led and managed by Hugging Face as it powers our internal services. However, we are happy to have +motivated individuals from other organizations join us as maintainers with the goal of making TGI the best inference +service. + +If you are such an individual (or organization), please reach out to us and let's collaborate. 
diff --git a/Cargo.lock b/Cargo.lock index 2e75fe8f..27499cd4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,9 +4,9 @@ version = 3 [[package]] name = "addr2line" -version = "0.22.0" +version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e4503c46a5c0c7844e948c9a4d6acd9f50cccb4de1c48eb9e291ea17470c678" +checksum = "f5fb1d8e4442bd405fdfd1dacb42792696b0cf9cb15882e5d097b742a676d375" dependencies = [ "gimli", ] @@ -17,6 +17,12 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "adler2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" + [[package]] name = "ahash" version = "0.8.11" @@ -47,10 +53,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4aa90d7ce82d4be67b64039a3d588d38dbcc6736577de4a847025ce5b0c468d1" [[package]] -name = "anstream" -version = "0.6.14" +name = "allocator-api2" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "418c75fa768af9c03be99d17643f93f79bbba589895012a80e3452a19ddda15b" +checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" + +[[package]] +name = "anstream" +version = "0.6.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64e15c1ab1f89faffbf04a634d5e1962e9074f2741eef6d97f3c4e322426d526" dependencies = [ "anstyle", "anstyle-parse", @@ -63,33 +75,33 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.7" +version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "038dfcf04a5feb68e9c60b21c9625a54c2c0616e79b72b0fd87075a056ae1d1b" +checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" [[package]] name = "anstyle-parse" -version = "0.2.4" +version = 
"0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c03a11a9034d92058ceb6ee011ce58af4a9bf61491aa7e1e59ecd24bd40d22d4" +checksum = "eb47de1e80c2b463c735db5b217a0ddc39d612e7ac9e2e96a5aed1f57616c1cb" dependencies = [ "utf8parse", ] [[package]] name = "anstyle-query" -version = "1.0.3" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a64c907d4e79225ac72e2a354c9ce84d50ebb4586dee56c82b3ee73004f537f5" +checksum = "6d36fc52c7f6c869915e99412912f22093507da8d9e942ceaf66fe4b7c14422a" dependencies = [ "windows-sys 0.52.0", ] [[package]] name = "anstyle-wincon" -version = "3.0.3" +version = "3.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61a38449feb7068f52bb06c12759005cf459ee52bb4adc1d5a7c4322d716fb19" +checksum = "5bf74e1b6e971609db8ca7a9ce79fd5768ab6ae46441c572e46cf596f59e57f8" dependencies = [ "anstyle", "windows-sys 0.52.0", @@ -97,9 +109,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.86" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" +checksum = "86fdf8605db99b54d3cd748a44c6d04df638eb5dafb219b135d0149bd0db01f6" [[package]] name = "arbitrary" @@ -121,14 +133,14 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.79", ] [[package]] name = "arrayvec" -version = "0.7.4" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "async-rustls" @@ -160,25 +172,42 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.79", ] [[package]] name = 
"async-trait" -version = "0.1.80" +version = "0.1.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" +checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.79", +] + +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + +[[package]] +name = "atty" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" +dependencies = [ + "hermit-abi 0.1.19", + "libc", + "winapi", ] [[package]] name = "autocfg" -version = "1.3.0" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "av1-grain" @@ -226,6 +255,33 @@ dependencies = [ "slotmap", ] +[[package]] +name = "aws-lc-rs" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f95446d919226d587817a7d21379e6eb099b97b45110a7f272a444ca5c54070" +dependencies = [ + "aws-lc-sys", + "mirai-annotations", + "paste", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3ddc4a5b231dd6958b140ff3151b6412b3f4321fab354f399eec8f14b06df62" +dependencies = [ + "bindgen", + "cc", + "cmake", + "dunce", + "fs_extra", + "libc", + "paste", +] + [[package]] name = "axum" version = "0.6.20" @@ -233,13 +289,13 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b829e4e32b91e643de6eafe82b1d90675f5874230191a4ffbc1b336dec4d6bf" dependencies = [ 
"async-trait", - "axum-core", + "axum-core 0.3.4", "bitflags 1.3.2", "bytes", "futures-util", - "http", - "http-body", - "hyper", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.30", "itoa", "matchit", "memchr", @@ -251,13 +307,47 @@ dependencies = [ "serde_json", "serde_path_to_error", "serde_urlencoded", - "sync_wrapper", + "sync_wrapper 0.1.2", "tokio", - "tower", + "tower 0.4.13", "tower-layer", "tower-service", ] +[[package]] +name = "axum" +version = "0.7.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "504e3947307ac8326a5437504c517c4b56716c9d98fac0028c2acc7ca47d70ae" +dependencies = [ + "async-trait", + "axum-core 0.4.5", + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.1", + "http-body-util", + "hyper 1.4.1", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper 1.0.1", + "tokio", + "tower 0.5.1", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "axum-core" version = "0.3.4" @@ -267,8 +357,8 @@ dependencies = [ "async-trait", "bytes", "futures-util", - "http", - "http-body", + "http 0.2.12", + "http-body 0.4.6", "mime", "rustversion", "tower-layer", @@ -276,36 +366,57 @@ dependencies = [ ] [[package]] -name = "axum-tracing-opentelemetry" -version = "0.14.1" +name = "axum-core" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06985105829f176e9a3f113b1c71cc24e08f600ef0df4e70cd90d144f889e19f" +checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199" dependencies = [ - "axum", + "async-trait", + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.1", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper 1.0.1", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-tracing-opentelemetry" +version = 
"0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdad298231394729042d1f155b93f9fdf0b5ee1aea0b62404c4d7341f7d8fe08" +dependencies = [ + "axum 0.7.7", "futures-core", "futures-util", - "http", - "opentelemetry", + "http 1.1.0", + "opentelemetry 0.21.0", "pin-project-lite", - "tower", + "tower 0.4.13", "tracing", - "tracing-opentelemetry", + "tracing-opentelemetry 0.22.0", "tracing-opentelemetry-instrumentation-sdk", ] [[package]] name = "backtrace" -version = "0.3.73" +version = "0.3.74" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cc23269a4f8976d0a4d2e7109211a419fe30e8d88d677cd60b6bc79c5732e0a" +checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" dependencies = [ "addr2line", - "cc", "cfg-if", "libc", - "miniz_oxide", + "miniz_oxide 0.8.0", "object", "rustc-demangle", + "windows-targets 0.52.6", ] [[package]] @@ -326,6 +437,29 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "bindgen" +version = "0.69.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a00dc851838a2120612785d195287475a3ac45514741da670b735818822129a0" +dependencies = [ + "bitflags 2.6.0", + "cexpr", + "clang-sys", + "itertools 0.12.1", + "lazy_static", + "lazycell", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash", + "shlex", + "syn 2.0.79", + "which", +] + [[package]] name = "bit-set" version = "0.5.3" @@ -355,15 +489,15 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.5.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" +checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" [[package]] name = 
"bitstream-io" -version = "2.3.0" +version = "2.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c12d1856e42f0d817a835fe55853957c85c8c8a470114029143d3f12671446e" +checksum = "b81e1519b0d82120d2fd469d5bfb2919a9361c48b02d82d04befc1cdd2002452" [[package]] name = "block-buffer" @@ -376,9 +510,9 @@ dependencies = [ [[package]] name = "built" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6a6c0b39c38fd754ac338b00a88066436389c0f029da5d37d1e01091d9b7c17" +checksum = "236e6289eda5a812bc6b53c3b024039382a2895fbbeef2d748b2931546d392c4" [[package]] name = "bumpalo" @@ -394,9 +528,9 @@ checksum = "5ce89b21cab1437276d2650d57e971f9d548a2d9037cc231abdc0562b97498ce" [[package]] name = "bytemuck" -version = "1.16.0" +version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78834c15cb5d5efe3452d58b1e8ba890dd62d21907f867f383358198e56ebca5" +checksum = "94bbb0ad554ad961ddc5da507a12a29b14e4ae5bda06b19f575a3e6079d2e2ae" [[package]] name = "byteorder" @@ -412,15 +546,15 @@ checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" [[package]] name = "bytes" -version = "1.6.0" +version = "1.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" +checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3" [[package]] name = "camino" -version = "1.1.7" +version = "1.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0ec6b951b160caa93cc0c7b209e5a3bff7aae9062213451ac99493cd844c239" +checksum = "8b96ec4966b5813e2c0507c1f86115c8c5abaadc3980879c3424042a02fd1ad3" dependencies = [ "serde", ] @@ -455,14 +589,38 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53" [[package]] -name = "cc" -version = 
"1.0.99" +name = "cast" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96c51067fd44124faa7f870b4b1c969379ad32b2ba805aa959430ceaa384f695" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + +[[package]] +name = "castaway" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0abae9be0aaf9ea96a3b1b8b1b55c602ca751eba1b1500220cea4ecbafe7c0d5" +dependencies = [ + "rustversion", +] + +[[package]] +name = "cc" +version = "1.1.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9540e661f81799159abee814118cc139a2004b3a3aa3ea37724a1b66530b90e0" dependencies = [ "jobserver", "libc", - "once_cell", + "shlex", +] + +[[package]] +name = "cexpr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom", ] [[package]] @@ -488,10 +646,38 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e" [[package]] -name = "clap" -version = "4.5.7" +name = "cfg_aliases" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5db83dced34638ad474f39f250d7fea9598bdd239eaced1bdf45d597da0f433f" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + +[[package]] +name = "clang-sys" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" +dependencies = [ + "glob", + "libc", + "libloading", +] + +[[package]] +name = "clap" +version = "2.34.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0610544180c38b88101fecf2dd634b174a62eef6946f84dfc6a7127512b381c" +dependencies = [ + "bitflags 1.3.2", + "textwrap", + "unicode-width", +] + 
+[[package]] +name = "clap" +version = "4.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0956a43b323ac1afaffc053ed5c4b7c1f1800bacd1683c353aabbb752515dd3" dependencies = [ "clap_builder", "clap_derive", @@ -499,9 +685,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.7" +version = "4.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7e204572485eb3fbf28f871612191521df159bc3e15a9f5064c66dba3a8c05f" +checksum = "4d72166dd41634086d5803a47eb71ae740e61d84709c36f3c34110173db3961b" dependencies = [ "anstream", "anstyle", @@ -511,21 +697,40 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.5" +version = "4.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c780290ccf4fb26629baa7a1081e68ced113f1d3ec302fa5948f1c381ebf06c6" +checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab" dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.79", ] [[package]] name = "clap_lex" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b82cf0babdbd58558212896d1a4272303a57bdb245c2bf1147185fb45640e70" +checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" + +[[package]] +name = "cmake" +version = "0.1.51" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb1e43aa7fd152b1f968787f7dbcdeb306d1867ff373c69955211876c053f91a" +dependencies = [ + "cc", +] + +[[package]] +name = "codespan-reporting" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3538270d33cc669650c4b093848450d380def10c331d38c768e34cac80576e6e" +dependencies = [ + "termcolor", + "unicode-width", +] [[package]] name = "color_quant" @@ -535,9 +740,23 @@ checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" [[package]] name = "colorchoice" -version = 
"1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422" +checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" + +[[package]] +name = "compact_str" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6050c3a16ddab2e412160b31f2c871015704239bca62f72f6e5f0be631d3f644" +dependencies = [ + "castaway", + "cfg-if", + "itoa", + "rustversion", + "ryu", + "static_assertions", +] [[package]] name = "console" @@ -564,15 +783,15 @@ dependencies = [ [[package]] name = "core-foundation-sys" -version = "0.8.6" +version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" [[package]] name = "cpufeatures" -version = "0.2.12" +version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" +checksum = "608697df725056feaccfa42cffdaeeec3fccc4ffc38358ecd19b243e716a78e0" dependencies = [ "libc", ] @@ -586,6 +805,42 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "criterion" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b01d6de93b2b6c65e17c634a26653a29d107b3c98c607c765bf38d041531cd8f" +dependencies = [ + "atty", + "cast", + "clap 2.34.0", + "criterion-plot", + "csv", + "itertools 0.10.5", + "lazy_static", + "num-traits", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_cbor", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2673cc8207403546f45f5fd319a974b1e6983ad1a3ee7e6041650013be041876" +dependencies 
= [ + "cast", + "itertools 0.10.5", +] + [[package]] name = "crossbeam-channel" version = "0.5.13" @@ -622,15 +877,15 @@ checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" [[package]] name = "crossterm" -version = "0.27.0" +version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f476fe445d41c9e991fd07515a6f463074b782242ccf4a5b7b1d1012e70824df" +checksum = "829d955a0bb380ef178a640b91779e3987da38c9aea133b20614cfed8cdea9c6" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "crossterm_winapi", - "libc", "mio", "parking_lot", + "rustix", "signal-hook", "signal-hook-mio", "winapi", @@ -662,20 +917,85 @@ dependencies = [ ] [[package]] -name = "ctrlc" -version = "3.4.4" +name = "csv" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "672465ae37dc1bc6380a6547a8883d5dd397b0f1faaad4f265726cc7042a5345" +checksum = "ac574ff4d437a7b5ad237ef331c17ccca63c46479e5b5453eb8e10bb99a759fe" dependencies = [ - "nix", - "windows-sys 0.52.0", + "csv-core", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "csv-core" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70" +dependencies = [ + "memchr", +] + +[[package]] +name = "ctrlc" +version = "3.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90eeab0aa92f3f9b4e87f258c72b139c207d251f9cbc1080a0086b86a8870dd3" +dependencies = [ + "nix 0.29.0", + "windows-sys 0.59.0", +] + +[[package]] +name = "cxx" +version = "1.0.128" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54ccead7d199d584d139148b04b4a368d1ec7556a1d9ea2548febb1b9d49f9a4" +dependencies = [ + "cc", + "cxxbridge-flags", + "cxxbridge-macro", + "link-cplusplus", +] + +[[package]] +name = "cxx-build" +version = "1.0.128" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "c77953e99f01508f89f55c494bfa867171ef3a6c8cea03d26975368f2121a5c1" +dependencies = [ + "cc", + "codespan-reporting", + "once_cell", + "proc-macro2", + "quote", + "scratch", + "syn 2.0.79", +] + +[[package]] +name = "cxxbridge-flags" +version = "1.0.128" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65777e06cc48f0cb0152024c77d6cf9e4bdb4408e7b48bea993d42fa0f5b02b6" + +[[package]] +name = "cxxbridge-macro" +version = "1.0.128" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "98532a60dedaebc4848cb2cba5023337cc9ea3af16a5b062633fabfd9f18fb60" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.79", ] [[package]] name = "darling" -version = "0.20.9" +version = "0.20.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83b2eb4d90d12bdda5ed17de686c2acb4c57914f8f921b8da7e112b5a36f3fe1" +checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989" dependencies = [ "darling_core", "darling_macro", @@ -683,27 +1003,27 @@ dependencies = [ [[package]] name = "darling_core" -version = "0.20.9" +version = "0.20.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "622687fe0bac72a04e5599029151f5796111b90f1baaa9b544d807a5e31cd120" +checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5" dependencies = [ "fnv", "ident_case", "proc-macro2", "quote", "strsim", - "syn 2.0.66", + "syn 2.0.79", ] [[package]] name = "darling_macro" -version = "0.20.9" +version = "0.20.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "733cabb43482b1a1b53eee8583c2b9e8684d592215ea83efd305dd31bc2f0178" +checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" dependencies = [ "darling_core", "quote", - "syn 2.0.66", + "syn 2.0.79", ] [[package]] @@ -717,33 +1037,33 @@ dependencies = [ [[package]] name = "derive_builder" -version = "0.20.0" +version = "0.20.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "0350b5cb0331628a5916d6c5c0b72e97393b8b6b03b47a9284f4e7f5a405ffd7" +checksum = "cd33f37ee6a119146a1781d3356a7c26028f83d779b2e04ecd45fdc75c76877b" dependencies = [ "derive_builder_macro", ] [[package]] name = "derive_builder_core" -version = "0.20.0" +version = "0.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d48cda787f839151732d396ac69e3473923d54312c070ee21e9effcaa8ca0b1d" +checksum = "7431fa049613920234f22c47fdc33e6cf3ee83067091ea4277a3f8c4587aae38" dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.79", ] [[package]] name = "derive_builder_macro" -version = "0.20.0" +version = "0.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "206868b8242f27cecce124c19fd88157fbd0dd334df2587f36417bafbc85097b" +checksum = "4abae7035bf79b9877b779505d8cf3749285b80c43941eda66604841889451dc" dependencies = [ "derive_builder_core", - "syn 2.0.66", + "syn 2.0.79", ] [[package]] @@ -756,33 +1076,13 @@ dependencies = [ "crypto-common", ] -[[package]] -name = "dirs" -version = "4.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca3aa72a6f96ea37bbc5aa912f6788242832f75369bdfdadcb0e38423f100059" -dependencies = [ - "dirs-sys 0.3.7", -] - [[package]] name = "dirs" version = "5.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225" dependencies = [ - "dirs-sys 0.4.1", -] - -[[package]] -name = "dirs-sys" -version = "0.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b1d1d91c932ef41c0f2663aa8b0ca0342d444d842c06914aa0a7e352d0bada6" -dependencies = [ - "libc", - "redox_users", - "winapi", + "dirs-sys", ] [[package]] @@ -798,15 +1098,10 @@ dependencies = [ ] [[package]] -name = "displaydoc" -version = "0.2.4" +name = "dunce" +version = "1.0.5" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "487585f4d0c6655fe74905e2504d8ad6908e4db67f744eb140876906c2f3175d" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.66", -] +checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" [[package]] name = "easy-cast" @@ -819,9 +1114,9 @@ dependencies = [ [[package]] name = "either" -version = "1.12.0" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3dca9240753cf90908d7e4aac30f630662b02aebaa1b58a3cadabdb23385b58b" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" [[package]] name = "encode_unicode" @@ -871,9 +1166,9 @@ checksum = "887d93f60543e9a9362ef8a21beedd0a833c5d9610e18c67abe15a5963dcb1a4" dependencies = [ "bit_field", "flume", - "half", + "half 2.4.1", "lebe", - "miniz_oxide", + "miniz_oxide 0.7.4", "rayon-core", "smallvec", "zune-inflate", @@ -891,15 +1186,15 @@ dependencies = [ [[package]] name = "fastrand" -version = "2.1.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" +checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6" [[package]] name = "fdeflate" -version = "0.3.4" +version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f9bfee30e4dedf0ab8b422f03af778d9612b63f502710fc500a334ebe2de645" +checksum = "d8090f921a24b04994d9929e204f50b498a33ea6ba559ffaa05e04f7ee7fb5ab" dependencies = [ "simd-adler32", ] @@ -912,12 +1207,12 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" [[package]] name = "flate2" -version = "1.0.30" +version = "1.0.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f54427cfd1c7829e2a139fcefea601bf088ebca651d2bf53ebc600eac295dae" +checksum = "a1b589b4dc103969ad3cf85c950899926ec64300a1a46d76c03a6072957036f0" dependencies = [ 
"crc32fast", - "miniz_oxide", + "miniz_oxide 0.8.0", ] [[package]] @@ -981,6 +1276,12 @@ dependencies = [ "num", ] +[[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + [[package]] name = "futures" version = "0.3.30" @@ -1037,7 +1338,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.79", ] [[package]] @@ -1114,18 +1415,24 @@ dependencies = [ [[package]] name = "gimli" -version = "0.29.0" +version = "0.31.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" +checksum = "32085ea23f3234fc7846555e85283ba4de91e21016dc0455a16286d87a292d64" + +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" [[package]] name = "grpc-metadata" version = "0.1.0" dependencies = [ - "opentelemetry", + "opentelemetry 0.20.0", "tonic 0.10.2", "tracing", - "tracing-opentelemetry", + "tracing-opentelemetry 0.21.0", ] [[package]] @@ -1139,14 +1446,39 @@ dependencies = [ "futures-core", "futures-sink", "futures-util", - "http", - "indexmap 2.2.6", + "http 0.2.12", + "indexmap 2.5.0", "slab", "tokio", "tokio-util", "tracing", ] +[[package]] +name = "h2" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "524e8ac6999421f49a846c2d4411f337e53497d8ec55d67753beffa43c5d9205" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http 1.1.0", + "indexmap 2.5.0", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "half" +version = "1.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403" + [[package]] name = "half" version = "2.4.1" @@ -1163,20 +1495,15 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" -[[package]] -name = "hashbrown" -version = "0.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33ff8ae62cd3a9102e5637afc8452c55acf3844001bd5374e0b0bd7b6616c038" -dependencies = [ - "ahash", -] - [[package]] name = "hashbrown" version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "ahash", + "allocator-api2", +] [[package]] name = "heck" @@ -1190,6 +1517,15 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hermit-abi" +version = "0.1.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" +dependencies = [ + "libc", +] + [[package]] name = "hermit-abi" version = "0.3.9" @@ -1202,7 +1538,7 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b780635574b3d92f036890d8373433d6f9fc7abb320ee42a5c25897fc8ed732" dependencies = [ - "dirs 5.0.1", + "dirs", "futures", "indicatif", "log", @@ -1217,6 +1553,15 @@ dependencies = [ "ureq", ] +[[package]] +name = "home" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" +dependencies = [ + "windows-sys 0.52.0", +] + [[package]] name = "hostname" version = "0.3.1" @@ -1239,6 +1584,17 @@ dependencies = [ "itoa", ] +[[package]] +name = "http" +version = "1.1.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http-body" version = "0.4.6" @@ -1246,21 +1602,38 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" dependencies = [ "bytes", - "http", + "http 0.2.12", "pin-project-lite", ] [[package]] -name = "http-range-header" -version = "0.3.1" +name = "http-body" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "add0ab9360ddbd88cfeb3bd9574a1d85cfdfa14db10b3e21d3700dbc4328758f" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http 1.1.0", +] + +[[package]] +name = "http-body-util" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" +dependencies = [ + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.1", + "pin-project-lite", +] [[package]] name = "httparse" -version = "1.9.3" +version = "1.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0e7a4dd27b9476dc40cb050d3632d3bba3a70ddbff012285f7f8559a1e7e545" +checksum = "0fcc0b4a115bf80b728eb8ea024ad5bd707b615bfed49e0665b6e0f86fd082d9" [[package]] name = "httpdate" @@ -1270,17 +1643,17 @@ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" [[package]] name = "hyper" -version = "0.14.29" +version = "0.14.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f361cde2f109281a220d4307746cdfd5ee3f410da58a70377762396775634b33" +checksum = "a152ddd61dfaec7273fe8419ab357f33aee0d914c5f4efbf0d96fa749eea5ec9" dependencies = [ "bytes", "futures-channel", "futures-core", "futures-util", - "h2", - "http", - "http-body", + "h2 
0.3.26", + "http 0.2.12", + "http-body 0.4.6", "httparse", "httpdate", "itoa", @@ -1292,13 +1665,53 @@ dependencies = [ "want", ] +[[package]] +name = "hyper" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50dfd22e0e76d0f662d429a5f80fcaf3855009297eab6a0a9f8543834744ba05" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "h2 0.4.6", + "http 1.1.0", + "http-body 1.0.1", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "smallvec", + "tokio", + "want", +] + +[[package]] +name = "hyper-rustls" +version = "0.27.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333" +dependencies = [ + "futures-util", + "http 1.1.0", + "hyper 1.4.1", + "hyper-util", + "log", + "rustls 0.23.13", + "rustls-native-certs", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tower-service", +] + [[package]] name = "hyper-timeout" version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1" dependencies = [ - "hyper", + "hyper 0.14.30", "pin-project-lite", "tokio", "tokio-io-timeout", @@ -1311,128 +1724,29 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" dependencies = [ "bytes", - "hyper", + "hyper 0.14.30", "native-tls", "tokio", "tokio-native-tls", ] [[package]] -name = "icu_collections" -version = "1.5.0" +name = "hyper-util" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526" +checksum = "41296eb09f183ac68eec06e03cdbea2e759633d4067b2f6552fc2e009bcad08b" dependencies = [ - "displaydoc", - "yoke", - "zerofrom", - "zerovec", -] - -[[package]] -name = "icu_locid" -version = "1.5.0" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637" -dependencies = [ - "displaydoc", - "litemap", - "tinystr", - "writeable", - "zerovec", -] - -[[package]] -name = "icu_locid_transform" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e" -dependencies = [ - "displaydoc", - "icu_locid", - "icu_locid_transform_data", - "icu_provider", - "tinystr", - "zerovec", -] - -[[package]] -name = "icu_locid_transform_data" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdc8ff3388f852bede6b579ad4e978ab004f139284d7b28715f773507b946f6e" - -[[package]] -name = "icu_normalizer" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f" -dependencies = [ - "displaydoc", - "icu_collections", - "icu_normalizer_data", - "icu_properties", - "icu_provider", - "smallvec", - "utf16_iter", - "utf8_iter", - "write16", - "zerovec", -] - -[[package]] -name = "icu_normalizer_data" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8cafbf7aa791e9b22bec55a167906f9e1215fd475cd22adfcf660e03e989516" - -[[package]] -name = "icu_properties" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f8ac670d7422d7f76b32e17a5db556510825b29ec9154f235977c9caba61036" -dependencies = [ - "displaydoc", - "icu_collections", - "icu_locid_transform", - "icu_properties_data", - "icu_provider", - "tinystr", - "zerovec", -] - -[[package]] -name = "icu_properties_data" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67a8effbc3dd3e4ba1afa8ad918d5684b8868b3b26500753effea8d2eed19569" - -[[package]] -name = "icu_provider" -version = 
"1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" -dependencies = [ - "displaydoc", - "icu_locid", - "icu_provider_macros", - "stable_deref_trait", - "tinystr", - "writeable", - "yoke", - "zerofrom", - "zerovec", -] - -[[package]] -name = "icu_provider_macros" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.66", + "bytes", + "futures-channel", + "futures-util", + "http 1.1.0", + "http-body 1.0.1", + "hyper 1.4.1", + "pin-project-lite", + "socket2", + "tokio", + "tower-service", + "tracing", ] [[package]] @@ -1443,24 +1757,22 @@ checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" [[package]] name = "idna" -version = "1.0.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4716a3a0933a1d01c2f72450e89596eb51dd34ef3c211ccd875acdf1f8fe47ed" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" dependencies = [ - "icu_normalizer", - "icu_properties", - "smallvec", - "utf8_iter", + "unicode-bidi", + "unicode-normalization", ] [[package]] name = "image" -version = "0.25.1" +version = "0.25.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd54d660e773627692c524beaad361aca785a4f9f5730ce91f42aabe5bce3d11" +checksum = "99314c8a2152b8ddb211f924cdae532d8c5e4c8bb54728e12fff1b0cd5963a10" dependencies = [ "bytemuck", - "byteorder", + "byteorder-lite", "color_quant", "exr", "gif", @@ -1478,12 +1790,12 @@ dependencies = [ [[package]] name = "image-webp" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d730b085583c4d789dfd07fdcf185be59501666a90c97c40162b37e4fdad272d" +checksum = 
"f79afb8cbee2ef20f59ccd477a218c12a93943d075b492015ecb1bb81f8ee904" dependencies = [ "byteorder-lite", - "thiserror", + "quick-error", ] [[package]] @@ -1504,9 +1816,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.2.6" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" +checksum = "68b900aa2f7301e21c36462b170ee99994de34dff39a4a6a528e80e7376d07e5" dependencies = [ "equivalent", "hashbrown 0.14.5", @@ -1538,11 +1850,21 @@ version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94bd26b1b737bc11f183620072e188d1c6ede67e0e78682228d66b49ec510e17" dependencies = [ - "opentelemetry", + "opentelemetry 0.20.0", "opentelemetry-otlp", "thiserror", "tracing", - "tracing-opentelemetry", + "tracing-opentelemetry 0.21.0", +] + +[[package]] +name = "instability" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b23a0c8dfe501baac4adf6ebbfa6eddf8f0c07f56b058cc1288017e32397846c" +dependencies = [ + "quote", + "syn 2.0.79", ] [[package]] @@ -1562,20 +1884,20 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.79", ] [[package]] name = "ipnet" -version = "2.9.0" +version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" +checksum = "187674a687eed5fe42285b40c6291f9a01517d415fad1c3cbc6a9f778af7fcd4" [[package]] name = "is_terminal_polyfill" -version = "1.70.0" +version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8478577c03552c21db0e2724ffb8986a5ce7af88107e6be5d2ee6e158c12800" +checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" [[package]] name = "iso8601" @@ -1613,6 +1935,15 @@ dependencies = [ "either", ] 
+[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.11" @@ -1621,9 +1952,9 @@ checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" [[package]] name = "jobserver" -version = "0.1.31" +version = "0.1.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2b099aaa34a9751c5bf0878add70444e1ed2dd73f347be99003d4577277de6e" +checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" dependencies = [ "libc", ] @@ -1636,9 +1967,9 @@ checksum = "f5d4a7da358eff58addd2877a45865158f0d78c911d43a5784ceb7bbf52833b0" [[package]] name = "js-sys" -version = "0.3.69" +version = "0.3.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" +checksum = "1868808506b929d7b0cfa8f75951347aa71bb21144b7791bae35d9bccfcfe37a" dependencies = [ "wasm-bindgen", ] @@ -1653,7 +1984,7 @@ dependencies = [ "anyhow", "base64 0.21.7", "bytecount", - "clap", + "clap 4.5.18", "fancy-regex", "fraction", "getrandom", @@ -1675,9 +2006,15 @@ dependencies = [ [[package]] name = "lazy_static" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + +[[package]] +name = "lazycell" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "lebe" @@ -1687,9 +2024,9 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8" [[package]] name = "libc" -version = "0.2.155" +version = "0.2.159" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" +checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" [[package]] name = "libfuzzer-sys" @@ -1702,6 +2039,16 @@ dependencies = [ "once_cell", ] +[[package]] +name = "libloading" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4" +dependencies = [ + "cfg-if", + "windows-targets 0.52.6", +] + [[package]] name = "libm" version = "0.2.8" @@ -1714,22 +2061,25 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "libc", ] +[[package]] +name = "link-cplusplus" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d240c6f7e1ba3a28b0249f774e6a9dd0175054b52dfbb61b16eb8505c3785c9" +dependencies = [ + "cc", +] + [[package]] name = "linux-raw-sys" version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" -[[package]] -name = "litemap" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "643cb0b8d4fcc284004d5fd0d67ccf61dfffadb7f75e1e71bc420f4688a3a704" - [[package]] name = "lock_api" version = "0.4.12" @@ -1742,9 +2092,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.21" +version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" +checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" [[package]] name = "loop9" @@ -1756,12 +2106,12 @@ dependencies = [ ] [[package]] -name = "mach2" -version = "0.4.2" +name = 
"lru" +version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19b955cdeb2a02b9117f121ce63aa52d08ade45de53e48fe6a38b39c10f6f709" +checksum = "37ee39891760e7d94734f6f63fedc29a2e4a152f836120753a72503f09fcf904" dependencies = [ - "libc", + "hashbrown 0.14.5", ] [[package]] @@ -1808,7 +2158,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ea1f30cedd69f0a2954655f7188c6a834246d2bcf1e315e2ac40c4b24dc9519" dependencies = [ "cfg-if", - "rayon", ] [[package]] @@ -1818,25 +2167,36 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] -name = "metrics" -version = "0.21.1" +name = "memoffset" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fde3af1a009ed76a778cb84fdef9e7dbbdf5775ae3e4cc1f434a6a307f6f76c5" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + +[[package]] +name = "metrics" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "884adb57038347dfbaf2d5065887b6cf4312330dc8e94bc30a1a839bd79d3261" dependencies = [ "ahash", - "metrics-macros", "portable-atomic", ] [[package]] name = "metrics-exporter-prometheus" -version = "0.12.2" +version = "0.15.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d4fa7ce7c4862db464a37b0b31d89bca874562f034bd7993895572783d02950" +checksum = "b4f0c8427b39666bf970460908b213ec09b3b350f20c0c2eabcbba51704a08e6" dependencies = [ - "base64 0.21.7", - "hyper", - "indexmap 1.9.3", + "base64 0.22.1", + "http-body-util", + "hyper 1.4.1", + "hyper-rustls", + "hyper-util", + "indexmap 2.5.0", "ipnet", "metrics", "metrics-util", @@ -1846,26 +2206,15 @@ dependencies = [ "tracing", ] -[[package]] -name = "metrics-macros" -version = "0.7.1" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "38b4faf00617defe497754acde3024865bc143d44a86799b24e191ecff91354f" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.66", -] - [[package]] name = "metrics-util" -version = "0.15.1" +version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4de2ed6e491ed114b40b732e4d1659a9d53992ebd87490c44a6ffe23739d973e" +checksum = "4259040465c955f9f2f1a4a8a16dc46726169bca0f88e8fb2dbeced487c3e828" dependencies = [ "crossbeam-epoch", "crossbeam-utils", - "hashbrown 0.13.1", + "hashbrown 0.14.5", "metrics", "num_cpus", "quanta", @@ -1880,9 +2229,9 @@ checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" [[package]] name = "mime_guess" -version = "2.0.4" +version = "2.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" dependencies = [ "mime", "unicase", @@ -1890,10 +2239,22 @@ dependencies = [ [[package]] name = "minijinja" -version = "1.0.12" -source = "git+https://github.com/mitsuhiko/minijinja.git?rev=5cd4efb#5cd4efb9e2639247df275fe6e22a5dbe0ce71b28" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1028b628753a7e1a88fc59c9ba4b02ecc3bc0bd3c7af23df667bc28df9b3310e" dependencies = [ "serde", + "serde_json", +] + +[[package]] +name = "minijinja-contrib" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39ffd46ee854be23604a20efd6c9655374fefbe4d44b949dc0f907305d92873a" +dependencies = [ + "minijinja", + "serde", ] [[package]] @@ -1904,26 +2265,42 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"87dfd01fe195c66b572b37921ad8803d010623c0aca821bea2302239d155cdae" +checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08" dependencies = [ "adler", +] + +[[package]] +name = "miniz_oxide" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" +dependencies = [ + "adler2", "simd-adler32", ] [[package]] name = "mio" -version = "0.8.11" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" +checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec" dependencies = [ + "hermit-abi 0.3.9", "libc", "log", "wasi", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] +[[package]] +name = "mirai-annotations" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c9be0862c1b3f26a88803c4a49de6889c10e608b3ee9344e6ef5b45fb37ad3d1" + [[package]] name = "monostate" version = "0.1.13" @@ -1942,7 +2319,7 @@ checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.79", ] [[package]] @@ -2003,17 +2380,17 @@ dependencies = [ "async-rustls", "async-trait", "awaitdrop", - "axum", + "axum 0.6.20", "base64 0.13.1", "bytes", "futures", "hostname", - "hyper", + "hyper 0.14.30", "muxado", "once_cell", "parking_lot", "regex", - "rustls-pemfile", + "rustls-pemfile 1.0.4", "serde", "serde_json", "thiserror", @@ -2030,9 +2407,21 @@ version = "0.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab2156c4fce2f8df6c499cc1c763e4394b7482525bf2a9701c9d79d215f519e4" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "cfg-if", - "cfg_aliases", + "cfg_aliases 0.1.1", + "libc", +] + +[[package]] +name = "nix" +version = "0.29.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46" +dependencies = [ + "bitflags 2.6.0", + "cfg-if", + "cfg_aliases 0.2.1", "libc", ] @@ -2093,9 +2482,9 @@ dependencies = [ [[package]] name = "num-bigint" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c165a9ab64cf766f73521c0dd2cfdff64f488b8f0b3e621face3462d3db536d7" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" dependencies = [ "num-integer", "num-traits", @@ -2130,7 +2519,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.79", ] [[package]] @@ -2180,7 +2569,7 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" dependencies = [ - "hermit-abi", + "hermit-abi 0.3.9", "libc", ] @@ -2201,18 +2590,21 @@ checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" [[package]] name = "object" -version = "0.36.0" +version = "0.36.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "576dfe1fc8f9df304abb159d767a29d0476f7750fbf8aa7ad07816004a207434" +checksum = "084f1a5821ac4c651660a94a7153d27ac9d8a53736203f58b31945ded098070a" dependencies = [ "memchr", ] [[package]] name = "once_cell" -version = "1.19.0" +version = "1.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +checksum = "82881c4be219ab5faaf2ad5e5e5ecdff8c66bd7402ca3160975c93b24961afd1" +dependencies = [ + "portable-atomic", +] [[package]] name = "onig" @@ -2237,12 +2629,18 @@ dependencies = [ ] [[package]] -name = "openssl" -version = "0.10.64" +name = "oorandom" +version = "11.1.4" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f" +checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" + +[[package]] +name = "openssl" +version = "0.10.66" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9529f4786b70a3e8c61e11179af17ab6188ad8d0ded78c5529441ed39d4bd9c1" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "cfg-if", "foreign-types", "libc", @@ -2259,7 +2657,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.79", ] [[package]] @@ -2270,9 +2668,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.102" +version = "0.9.103" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c597637d56fbc83893a35eb0dd04b2b8e7a50c91e64e9493e398b5df4fb45fa2" +checksum = "7f9e8deee91df40a943c71b917e5874b951d32a802526c85721ce3b776c929d6" dependencies = [ "cc", "libc", @@ -2287,19 +2685,37 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9591d937bc0e6d2feb6f71a559540ab300ea49955229c347a517a28d27784c54" dependencies = [ "opentelemetry_api", - "opentelemetry_sdk", + "opentelemetry_sdk 0.20.0", ] [[package]] -name = "opentelemetry-http" -version = "0.9.0" +name = "opentelemetry" +version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7594ec0e11d8e33faf03530a4c49af7064ebba81c1480e01be67d90b356508b" +checksum = "1e32339a5dc40459130b3bd269e9892439f55b33e772d2a9d402a789baaf4e8a" dependencies = [ - "async-trait", - "bytes", - "http", - "opentelemetry_api", + "futures-core", + "futures-sink", + "indexmap 2.5.0", + "js-sys", + "once_cell", + "pin-project-lite", + "thiserror", + "urlencoding", +] + +[[package]] +name = "opentelemetry" +version = "0.23.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b69a91d4893e713e06f724597ad630f1fa76057a5e1026c0ca67054a9032a76" +dependencies = [ + "futures-core", + "futures-sink", + "js-sys", + "once_cell", + "pin-project-lite", + "thiserror", ] [[package]] @@ -2310,11 +2726,11 @@ checksum = "7e5e5a5c4135864099f3faafbe939eb4d7f9b80ebf68a8448da961b32a7c1275" dependencies = [ "async-trait", "futures-core", - "http", + "http 0.2.12", "opentelemetry-proto", "opentelemetry-semantic-conventions", "opentelemetry_api", - "opentelemetry_sdk", + "opentelemetry_sdk 0.20.0", "prost 0.11.9", "thiserror", "tokio", @@ -2328,7 +2744,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b1e3f814aa9f8c905d0ee4bde026afd3b2577a97c10e1699912e3e44f0c4cbeb" dependencies = [ "opentelemetry_api", - "opentelemetry_sdk", + "opentelemetry_sdk 0.20.0", "prost 0.11.9", "tonic 0.9.2", ] @@ -2339,7 +2755,7 @@ version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73c9f9340ad135068800e7f1b24e9e09ed9e7143f5bf8518ded3d3ec69789269" dependencies = [ - "opentelemetry", + "opentelemetry 0.20.0", ] [[package]] @@ -2371,7 +2787,7 @@ dependencies = [ "futures-util", "once_cell", "opentelemetry_api", - "ordered-float", + "ordered-float 3.9.2", "percent-encoding", "rand", "regex", @@ -2381,6 +2797,46 @@ dependencies = [ "tokio-stream", ] +[[package]] +name = "opentelemetry_sdk" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f16aec8a98a457a52664d69e0091bac3a0abd18ead9b641cb00202ba4e0efe4" +dependencies = [ + "async-trait", + "crossbeam-channel", + "futures-channel", + "futures-executor", + "futures-util", + "glob", + "once_cell", + "opentelemetry 0.21.0", + "ordered-float 4.3.0", + "percent-encoding", + "rand", + "thiserror", +] + +[[package]] +name = "opentelemetry_sdk" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"ae312d58eaa90a82d2e627fd86e075cf5230b3f11794e2ed74199ebbe572d4fd" +dependencies = [ + "async-trait", + "futures-channel", + "futures-executor", + "futures-util", + "glob", + "lazy_static", + "once_cell", + "opentelemetry 0.23.0", + "ordered-float 4.3.0", + "percent-encoding", + "rand", + "thiserror", +] + [[package]] name = "option-ext" version = "0.2.0" @@ -2396,6 +2852,15 @@ dependencies = [ "num-traits", ] +[[package]] +name = "ordered-float" +version = "4.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d501f1a72f71d3c063a6bbc8f7271fa73aa09fe5d6283b6571e2ed176a2537" +dependencies = [ + "num-traits", +] + [[package]] name = "overload" version = "0.1.1" @@ -2433,7 +2898,7 @@ dependencies = [ "libc", "redox_syscall", "smallvec", - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -2455,7 +2920,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" dependencies = [ "fixedbitset", - "indexmap 2.2.6", + "indexmap 2.5.0", ] [[package]] @@ -2475,7 +2940,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.79", ] [[package]] @@ -2492,28 +2957,56 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkg-config" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" +checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" + +[[package]] +name = "plotters" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + 
+[[package]] +name = "plotters-backend" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" + +[[package]] +name = "plotters-svg" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" +dependencies = [ + "plotters-backend", +] [[package]] name = "png" -version = "0.17.13" +version = "0.17.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06e4b0d3d1312775e782c86c91a111aa1f910cbb65e1337f9975b5f9a554b5e1" +checksum = "52f9d46a34a05a6a57566bc2bfae066ef07585a6e3fa30fbbdff5936380623f0" dependencies = [ "bitflags 1.3.2", "crc32fast", "fdeflate", "flate2", - "miniz_oxide", + "miniz_oxide 0.8.0", ] [[package]] name = "portable-atomic" -version = "1.6.0" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" +checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2" [[package]] name = "powerfmt" @@ -2523,18 +3016,21 @@ checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" [[package]] name = "ppv-lite86" -version = "0.2.17" +version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +dependencies = [ + "zerocopy", +] [[package]] name = "prettyplease" -version = "0.2.20" +version = "0.2.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f12335488a2f3b0a83b14edad48dca9879ce89b2edd10e80237e4e852dd645e" +checksum = "479cf940fbbb3426c32c5d5176f62ad57549a0bb84773423ba8be9d089f5faba" dependencies = [ "proc-macro2", - "syn 2.0.66", + "syn 2.0.79", ] [[package]] @@ -2563,9 
+3059,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.85" +version = "1.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22244ce15aa966053a896d1accb3a6e68469b97c7f33f284b99f0d576879fc23" +checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" dependencies = [ "unicode-ident", ] @@ -2586,7 +3082,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8021cf59c8ec9c432cfc2526ac6b8aa508ecaf29cd415f271b8406c1b851c3fd" dependencies = [ "quote", - "syn 2.0.66", + "syn 2.0.79", ] [[package]] @@ -2616,8 +3112,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4" dependencies = [ "bytes", - "heck 0.4.1", - "itertools 0.10.5", + "heck 0.5.0", + "itertools 0.12.1", "log", "multimap", "once_cell", @@ -2626,7 +3122,7 @@ dependencies = [ "prost 0.12.6", "prost-types", "regex", - "syn 2.0.66", + "syn 2.0.79", "tempfile", ] @@ -2650,10 +3146,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "81bddcdb20abf9501610992b6759a4c888aef7d1a7247ef75e2404275ac24af1" dependencies = [ "anyhow", - "itertools 0.10.5", + "itertools 0.12.1", "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.79", ] [[package]] @@ -2665,6 +3161,69 @@ dependencies = [ "prost 0.12.6", ] +[[package]] +name = "pyo3" +version = "0.22.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15ee168e30649f7f234c3d49ef5a7a6cbf5134289bc46c29ff3155fa3221c225" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.22.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e61cef80755fe9e46bb8a0b8f20752ca7676dcc07a5277d8b7768c6172e529b3" +dependencies = [ + "once_cell", + 
"target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.22.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67ce096073ec5405f5ee2b8b31f03a68e02aa10d5d4f565eca04acc41931fa1c" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.22.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2440c6d12bc8f3ae39f1e775266fa5122fd0c8891ce7520fa6048e683ad3de28" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn 2.0.79", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.22.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1be962f0e06da8f8465729ea2cb71a416d2257dff56cbe40a70d3e62a93ae5d1" +dependencies = [ + "heck 0.5.0", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn 2.0.79", +] + [[package]] name = "qoi" version = "0.4.1" @@ -2676,13 +3235,12 @@ dependencies = [ [[package]] name = "quanta" -version = "0.11.1" +version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a17e662a7a8291a865152364c20c7abc5e60486ab2001e8ec10b24862de0b9ab" +checksum = "8e5167a477619228a0b284fac2674e3c388cba90631d7b7de620e6f1fcd08da5" dependencies = [ "crossbeam-utils", "libc", - "mach2", "once_cell", "raw-cpuid", "wasi", @@ -2698,9 +3256,9 @@ checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" [[package]] name = "quote" -version = "1.0.36" +version = "1.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" dependencies = [ "proc-macro2", ] @@ -2737,18 +3295,22 @@ dependencies = [ [[package]] name = "ratatui" -version = "0.23.0" +version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"2e2e4cd95294a85c3b4446e63ef054eea43e0205b1fd60120c16b74ff7ff96ad" +checksum = "fdef7f9be5c0122f890d58bdf4d964349ba6a6161f705907526d891efabba57d" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "cassowary", + "compact_str", "crossterm", - "indoc", - "itertools 0.11.0", + "instability", + "itertools 0.13.0", + "lru", "paste", "strum", + "strum_macros", "unicode-segmentation", + "unicode-truncate", "unicode-width", ] @@ -2789,26 +3351,25 @@ dependencies = [ [[package]] name = "ravif" -version = "0.11.5" +version = "0.11.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc13288f5ab39e6d7c9d501759712e6969fcc9734220846fc9ed26cae2cc4234" +checksum = "a8f0bfd976333248de2078d350bfdf182ff96e168a24d23d2436cef320dd4bdd" dependencies = [ "avif-serialize", "imgref", "loop9", "quick-error", "rav1e", - "rayon", "rgb", ] [[package]] name = "raw-cpuid" -version = "10.7.0" +version = "11.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332" +checksum = "cb9ee317cfe3fbd54b36a511efc1edd42e216903c9cd575e686dd68a2ba90d8d" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.6.0", ] [[package]] @@ -2844,18 +3405,18 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.2" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c82cf8cff14456045f55ec4241383baeff27af886adb72ffb2162f99911de0fd" +checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", ] [[package]] name = "redox_users" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd283d9651eeda4b2a83a43c1c91b266c40fd76ecd39a50a8c630ae69dc72891" +checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ "getrandom", "libredox", @@ -2864,14 +3425,14 @@ dependencies = [ 
[[package]] name = "regex" -version = "1.10.5" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f" +checksum = "38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.7", - "regex-syntax 0.8.4", + "regex-automata 0.4.8", + "regex-syntax 0.8.5", ] [[package]] @@ -2885,13 +3446,13 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.7" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" +checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3" dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.8.4", + "regex-syntax 0.8.5", ] [[package]] @@ -2902,9 +3463,9 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" [[package]] name = "regex-syntax" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "reqwest" @@ -2917,10 +3478,10 @@ dependencies = [ "encoding_rs", "futures-core", "futures-util", - "h2", - "http", - "http-body", - "hyper", + "h2 0.3.26", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.30", "hyper-tls", "ipnet", "js-sys", @@ -2930,11 +3491,11 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", - "rustls-pemfile", + "rustls-pemfile 1.0.4", "serde", "serde_json", "serde_urlencoded", - "sync_wrapper", + "sync_wrapper 0.1.2", "system-configuration", "tokio", "tokio-native-tls", @@ -2948,9 +3509,9 @@ dependencies = [ [[package]] name = "rgb" -version = "0.8.37" +version = "0.8.50" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "05aaa8004b64fd573fc9d002f4e632d51ad4f026c2b5ba95fcb6c2f32c2c47d8" +checksum = "57397d16646700483b67d2dd6511d79318f9d057fdbd21a4066aeac8b41d310a" dependencies = [ "bytemuck", ] @@ -2987,9 +3548,9 @@ dependencies = [ [[package]] name = "rust-embed" -version = "6.8.1" +version = "8.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a36224c3276f8c4ebc8c20f158eca7ca4359c8db89991c4925132aaaf6702661" +checksum = "fa66af4a4fdd5e7ebc276f115e895611a34739a9c1c01028383d612d550953c0" dependencies = [ "rust-embed-impl", "rust-embed-utils", @@ -2998,23 +3559,22 @@ dependencies = [ [[package]] name = "rust-embed-impl" -version = "6.8.1" +version = "8.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49b94b81e5b2c284684141a2fb9e2a31be90638caf040bf9afbc5a0416afe1ac" +checksum = "6125dbc8867951125eec87294137f4e9c2c96566e61bf72c45095a7c77761478" dependencies = [ "proc-macro2", "quote", "rust-embed-utils", - "shellexpand", - "syn 2.0.66", + "syn 2.0.79", "walkdir", ] [[package]] name = "rust-embed-utils" -version = "7.8.1" +version = "8.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d38ff6bf570dc3bb7100fce9f7b60c33fa71d80e88da3f2580df4ff2bdded74" +checksum = "2e5347777e9aacb56039b0e1f28785929a8a3b709e87482e7442c72e7c12529d" dependencies = [ "sha2", "walkdir", @@ -3027,21 +3587,27 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" [[package]] -name = "rustc_version" -version = "0.4.0" +name = "rustc-hash" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + +[[package]] +name = "rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" dependencies = [ "semver", ] [[package]] name = "rustix" -version = "0.38.34" +version = "0.38.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" +checksum = "8acb788b847c24f28525660c4d7758620a7210875711f79e7f663cc152726811" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "errno", "libc", "linux-raw-sys", @@ -3074,6 +3640,34 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustls" +version = "0.23.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2dabaac7466917e566adb06783a81ca48944c6898a1b08b9374106dd671f4c8" +dependencies = [ + "aws-lc-rs", + "log", + "once_cell", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-native-certs" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcaf18a4f2be7326cd874a5fa579fae794320a0f388d365dca7e480e55f83f8a" +dependencies = [ + "openssl-probe", + "rustls-pemfile 2.1.3", + "rustls-pki-types", + "schannel", + "security-framework", +] + [[package]] name = "rustls-pemfile" version = "1.0.4" @@ -3084,17 +3678,28 @@ dependencies = [ ] [[package]] -name = "rustls-pki-types" -version = "1.7.0" +name = "rustls-pemfile" +version = "2.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d" +checksum = "196fe16b00e106300d3e45ecfcb764fa292a535d7326a29a5875c579c7417425" +dependencies = [ + "base64 0.22.1", + "rustls-pki-types", +] + +[[package]] +name = "rustls-pki-types" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e696e35370c65c9c541198af4543ccd580cf17fc25d8e05c5a242b202488c55" [[package]] name = "rustls-webpki" -version = "0.102.4" +version = "0.102.8" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff448f7e92e913c4b7d4c6d8e4540a1724b319b4152b8aef6d4cf8339712b33e" +checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" dependencies = [ + "aws-lc-rs", "ring 0.17.8", "rustls-pki-types", "untrusted 0.9.0", @@ -3123,11 +3728,11 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.23" +version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbc91545643bcf3a0bbb6569265615222618bdf33ce4ffbbd13c4bbd4c093534" +checksum = "e9aaafd5a2b6e3d657ff009d82fbd630b6bd54dd4eb06f21693925cdf80f9b8b" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -3136,6 +3741,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "scratch" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3cf7c11c38cb994f3d40e8a8cde3bbd1f72a435e4c49e85d6553d8312306152" + [[package]] name = "sct" version = "0.7.1" @@ -3148,11 +3759,11 @@ dependencies = [ [[package]] name = "security-framework" -version = "2.11.0" +version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c627723fd09706bacdb5cf41499e95098555af3c3c29d014dc3c458ef6be11c0" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "core-foundation", "core-foundation-sys", "libc", @@ -3161,9 +3772,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.11.0" +version = "2.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "317936bbbd05227752583946b9e66d7ce3b489f84e11a94a510b4437fef407d7" +checksum = "ea4a292869320c0272d7bc55a5a6aafaff59b4f63404a003887b679a2e05b4b6" dependencies = [ "core-foundation-sys", "libc", @@ -3180,31 
+3791,42 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.203" +version = "1.0.210" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7253ab4de971e72fb7be983802300c30b5a7f0c2e56fab8abfc6a214307c0094" +checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" dependencies = [ "serde_derive", ] [[package]] -name = "serde_derive" -version = "1.0.203" +name = "serde_cbor" +version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" +checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5" +dependencies = [ + "half 1.8.3", + "serde", +] + +[[package]] +name = "serde_derive" +version = "1.0.210" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.79", ] [[package]] name = "serde_json" -version = "1.0.117" +version = "1.0.128" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "455182ea6142b14f93f4bc5320a2b31c1f266b66a4a5c858b013302a5d8cbfc3" +checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" dependencies = [ "itoa", + "memchr", "ryu", "serde", ] @@ -3221,9 +3843,9 @@ dependencies = [ [[package]] name = "serde_spanned" -version = "0.6.6" +version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79e674e01f999af37c49f70a6ede167a8a60b2503e56c5599532a65baa5969a0" +checksum = "87607cb1398ed59d48732e575a4c28a7a8ebf2454b964fe3f224f2afc07909e1" dependencies = [ "serde", ] @@ -3261,13 +3883,10 @@ dependencies = [ ] [[package]] -name = "shellexpand" -version = "2.1.2" +name = "shlex" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ccc8076840c4da029af4f87e4e8daeb0fca6b87bbb02e10cb60b791450e11e4" 
-dependencies = [ - "dirs 4.0.0", -] +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "signal-hook" @@ -3281,9 +3900,9 @@ dependencies = [ [[package]] name = "signal-hook-mio" -version = "0.2.3" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29ad2e15f37ec9a6cc544097b78a1ec90001e9f71b81338ca39f430adaca99af" +checksum = "34db1a06d485c9142248b7a054f034b349b212551f3dfd19c94d45a754a217cd" dependencies = [ "libc", "mio", @@ -3382,10 +4001,10 @@ dependencies = [ ] [[package]] -name = "stable_deref_trait" -version = "1.2.0" +name = "static_assertions" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" [[package]] name = "strsim" @@ -3395,31 +4014,31 @@ checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" [[package]] name = "strum" -version = "0.25.0" +version = "0.26.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" +checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" dependencies = [ "strum_macros", ] [[package]] name = "strum_macros" -version = "0.25.3" +version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23dc1fa9ac9c169a78ba62f0b841814b7abae11bdd047b9c58f893439e309ea0" +checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" dependencies = [ - "heck 0.4.1", + "heck 0.5.0", "proc-macro2", "quote", "rustversion", - "syn 2.0.66", + "syn 2.0.79", ] [[package]] name = "subtle" -version = "2.5.0" +version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" +checksum = 
"13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" @@ -3434,9 +4053,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.66" +version = "2.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c42f3f41a2de00b01c0aaad383c5a45241efc8b2d1eda5661812fda5f3cdcff5" +checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590" dependencies = [ "proc-macro2", "quote", @@ -3450,27 +4069,23 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" [[package]] -name = "synstructure" -version = "0.13.1" +name = "sync_wrapper" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.66", -] +checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" [[package]] name = "sysinfo" -version = "0.30.12" +version = "0.30.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "732ffa00f53e6b2af46208fba5718d9662a421049204e156328b66791ffa15ae" +checksum = "0a5b4ddaee55fb2bea2bf0e5000747e5f5c0de765e5a5ff87f4cd106439f4bb3" dependencies = [ "cfg-if", "core-foundation-sys", "libc", "ntapi", "once_cell", + "rayon", "windows", ] @@ -3534,29 +4149,61 @@ dependencies = [ [[package]] name = "target-lexicon" -version = "0.12.14" +version = "0.12.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1fc403891a21bcfb7c37834ba66a547a8f402146eba7265b5a6d88059c9ff2f" +checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" [[package]] name = "tempfile" -version = "3.10.1" +version = "3.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1" +checksum = 
"f0f2c9fc62d0beef6951ccffd757e241266a2c833136efbe35af6cd2567dca5b" dependencies = [ "cfg-if", "fastrand", + "once_cell", "rustix", - "windows-sys 0.52.0", + "windows-sys 0.59.0", +] + +[[package]] +name = "termcolor" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "text-generation-backends-trtllm" +version = "2.3.1-dev0" +dependencies = [ + "async-stream", + "async-trait", + "clap 4.5.18", + "cmake", + "cxx", + "cxx-build", + "log", + "parking_lot", + "pkg-config", + "text-generation-router", + "thiserror", + "tokenizers 0.19.1", + "tokio", + "tokio-stream", + "tracing", + "tracing-opentelemetry 0.24.0", + "tracing-subscriber", ] [[package]] name = "text-generation-benchmark" -version = "2.0.4" +version = "2.3.1-dev0" dependencies = [ "average", - "clap", - "crossterm", + "clap 4.5.18", "float-ord", "hf-hub", "ratatui", @@ -3565,7 +4212,7 @@ dependencies = [ "tabled", "text-generation-client", "thiserror", - "tokenizers", + "tokenizers 0.20.0", "tokio", "tracing", "tracing-subscriber", @@ -3573,31 +4220,34 @@ dependencies = [ [[package]] name = "text-generation-client" -version = "2.0.4" +version = "2.3.1-dev0" dependencies = [ + "async-trait", + "base64 0.22.1", "futures", "grpc-metadata", "prost 0.12.6", "prost-build", - "rand", "thiserror", "tokio", "tonic 0.10.2", "tonic-build", - "tower", + "tower 0.4.13", "tracing", ] [[package]] name = "text-generation-launcher" -version = "2.0.4" +version = "2.3.1-dev0" dependencies = [ - "clap", + "clap 4.5.18", "ctrlc", "float_eq", "hf-hub", - "nix", + "nix 0.28.0", "once_cell", + "pyo3", + "regex", "reqwest", "serde", "serde_json", @@ -3609,64 +4259,180 @@ dependencies = [ [[package]] name = "text-generation-router" -version = "2.0.4" +version = "2.3.1-dev0" dependencies = [ "async-stream", - "axum", + "async-trait", + "axum 0.7.7", 
"axum-tracing-opentelemetry", "base64 0.22.1", - "clap", + "clap 4.5.18", + "csv", "futures", "futures-util", "hf-hub", "image", "init-tracing-opentelemetry", + "itertools 0.10.5", + "jsonschema", + "metrics", + "metrics-exporter-prometheus", + "minijinja", + "minijinja-contrib", + "ngrok", + "nohash-hasher", + "once_cell", + "opentelemetry 0.20.0", + "opentelemetry-otlp", + "pyo3", + "rand", + "regex", + "reqwest", + "serde", + "serde_json", + "sysinfo", + "thiserror", + "tokenizers 0.20.0", + "tokio", + "tokio-stream", + "tower-http", + "tracing", + "tracing-opentelemetry 0.21.0", + "tracing-subscriber", + "ureq", + "utoipa", + "utoipa-swagger-ui", + "uuid", + "vergen", +] + +[[package]] +name = "text-generation-router-v2" +version = "2.3.1-dev0" +dependencies = [ + "async-stream", + "async-trait", + "axum 0.7.7", + "axum-tracing-opentelemetry", + "base64 0.22.1", + "clap 4.5.18", + "futures", + "futures-util", + "grpc-metadata", + "hf-hub", + "image", + "init-tracing-opentelemetry", "jsonschema", "metrics", "metrics-exporter-prometheus", "minijinja", - "ngrok", + "minijinja-contrib", "nohash-hasher", "once_cell", - "opentelemetry", + "opentelemetry 0.20.0", "opentelemetry-otlp", + "prost 0.12.6", + "prost-build", "rand", "regex", "reqwest", "serde", "serde_json", - "text-generation-client", + "slotmap", + "text-generation-router", "thiserror", - "tokenizers", + "tokenizers 0.20.0", "tokio", "tokio-stream", + "tonic 0.10.2", + "tonic-build", + "tower 0.4.13", "tower-http", "tracing", - "tracing-opentelemetry", + "tracing-opentelemetry 0.21.0", "tracing-subscriber", "utoipa", "utoipa-swagger-ui", - "vergen", +] + +[[package]] +name = "text-generation-router-v3" +version = "2.3.1-dev0" +dependencies = [ + "async-stream", + "async-trait", + "axum 0.7.7", + "axum-tracing-opentelemetry", + "base64 0.22.1", + "clap 4.5.18", + "criterion", + "futures", + "futures-util", + "grpc-metadata", + "hf-hub", + "image", + "init-tracing-opentelemetry", + "itertools 0.13.0", + 
"jsonschema", + "metrics", + "metrics-exporter-prometheus", + "minijinja", + "minijinja-contrib", + "nohash-hasher", + "once_cell", + "opentelemetry 0.20.0", + "opentelemetry-otlp", + "prost 0.12.6", + "prost-build", + "rand", + "regex", + "reqwest", + "serde", + "serde_json", + "slotmap", + "text-generation-router", + "thiserror", + "tokenizers 0.20.0", + "tokio", + "tokio-stream", + "tonic 0.10.2", + "tonic-build", + "tower 0.4.13", + "tower-http", + "tracing", + "tracing-opentelemetry 0.21.0", + "tracing-subscriber", + "utoipa", + "utoipa-swagger-ui", +] + +[[package]] +name = "textwrap" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060" +dependencies = [ + "unicode-width", ] [[package]] name = "thiserror" -version = "1.0.61" +version = "1.0.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709" +checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.61" +version = "1.0.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" +checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.79", ] [[package]] @@ -3724,15 +4490,30 @@ dependencies = [ ] [[package]] -name = "tinystr" -version = "0.7.6" +name = "tinytemplate" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" dependencies = [ - "displaydoc", - "zerovec", + "serde", + "serde_json", ] +[[package]] +name = "tinyvec" +version = 
"1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + [[package]] name = "tokenizers" version = "0.19.1" @@ -3756,7 +4537,40 @@ dependencies = [ "rayon", "rayon-cond", "regex", - "regex-syntax 0.8.4", + "regex-syntax 0.8.5", + "serde", + "serde_json", + "spm_precompiled", + "thiserror", + "unicode-normalization-alignments", + "unicode-segmentation", + "unicode_categories", +] + +[[package]] +name = "tokenizers" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8a24d7f7d6be5b9d1377418b893ab1808af0074f5d1bb2c64784452ddd2aa70" +dependencies = [ + "aho-corasick", + "derive_builder", + "esaxx-rs", + "getrandom", + "hf-hub", + "indicatif", + "itertools 0.12.1", + "lazy_static", + "log", + "macro_rules_attribute", + "monostate", + "onig", + "paste", + "rand", + "rayon", + "rayon-cond", + "regex", + "regex-syntax 0.8.5", "serde", "serde_json", "spm_precompiled", @@ -3768,21 +4582,20 @@ dependencies = [ [[package]] name = "tokio" -version = "1.38.0" +version = "1.40.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba4f4a02a7a80d6f274636f0aa95c7e383b912d41fe721a31f29e29698585a4a" +checksum = "e2b070231665d27ad9ec9b8df639893f46727666c6767db40317fbe920a5d998" dependencies = [ "backtrace", "bytes", "libc", "mio", - "num_cpus", "parking_lot", "pin-project-lite", "signal-hook-registry", "socket2", "tokio-macros", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -3797,13 +4610,13 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.3.0" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" +checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.79", ] [[package]] @@ -3828,10 +4641,21 @@ dependencies = [ ] [[package]] -name = "tokio-stream" -version = "0.1.15" +name = "tokio-rustls" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "267ac89e0bec6e691e5813911606935d77c476ff49024f98abcea3e7b15e37af" +checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" +dependencies = [ + "rustls 0.23.13", + "rustls-pki-types", + "tokio", +] + +[[package]] +name = "tokio-stream" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f4e6ce100d0eb49a2734f8c0812bcd324cf357d21810932c5df6b96ef2b86f1" dependencies = [ "futures-core", "pin-project-lite", @@ -3840,9 +4664,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.11" +version = "0.7.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cf6b47b3771c49ac75ad09a6162f53ad4b8088b76ac60e8ec1455b31a189fe1" +checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a" dependencies = [ "bytes", "futures-core", @@ -3854,9 +4678,9 @@ dependencies = [ [[package]] name = "toml" -version = "0.8.14" +version = "0.8.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f49eb2ab21d2f26bd6db7bf383edc527a7ebaee412d17af4d40fdccd442f335" +checksum = "a1ed1f98e3fdc28d6d910e6737ae6ab1a93bf1985935a1193e68f93eeb68d24e" dependencies = [ "serde", "serde_spanned", @@ -3866,20 +4690,20 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.6.6" +version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4badfd56924ae69bcc9039335b2e017639ce3f9b001c393c1b2d1ef846ce2cbf" +checksum = 
"0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" dependencies = [ "serde", ] [[package]] name = "toml_edit" -version = "0.22.14" +version = "0.22.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f21c7aaf97f1bd9ca9d4f9e73b0a6c74bd5afef56f2bc931943a6e1c37e04e38" +checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" dependencies = [ - "indexmap 2.2.6", + "indexmap 2.5.0", "serde", "serde_spanned", "toml_datetime", @@ -3893,22 +4717,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3082666a3a6433f7f511c7192923fa1fe07c69332d3c6a2e6bb040b569199d5a" dependencies = [ "async-trait", - "axum", + "axum 0.6.20", "base64 0.21.7", "bytes", "futures-core", "futures-util", - "h2", - "http", - "http-body", - "hyper", + "h2 0.3.26", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.30", "hyper-timeout", "percent-encoding", "pin-project", "prost 0.11.9", "tokio", "tokio-stream", - "tower", + "tower 0.4.13", "tower-layer", "tower-service", "tracing", @@ -3922,20 +4746,20 @@ checksum = "d560933a0de61cf715926b9cac824d4c883c2c43142f787595e48280c40a1d0e" dependencies = [ "async-stream", "async-trait", - "axum", + "axum 0.6.20", "base64 0.21.7", "bytes", - "h2", - "http", - "http-body", - "hyper", + "h2 0.3.26", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.30", "hyper-timeout", "percent-encoding", "pin-project", "prost 0.12.6", "tokio", "tokio-stream", - "tower", + "tower 0.4.13", "tower-layer", "tower-service", "tracing", @@ -3951,7 +4775,7 @@ dependencies = [ "proc-macro2", "prost-build", "quote", - "syn 2.0.66", + "syn 2.0.79", ] [[package]] @@ -3975,18 +4799,32 @@ dependencies = [ ] [[package]] -name = "tower-http" -version = "0.4.4" +name = "tower" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61c5bb1d698276a2443e5ecfabc1008bf15a36c12e6a7176e7bf089ea9131140" +checksum = 
"2873938d487c3cfb9aed7546dc9f2711d867c9f90c46b889989a2cb84eba6b4f" dependencies = [ - "bitflags 2.5.0", - "bytes", "futures-core", "futures-util", - "http", - "http-body", - "http-range-header", + "pin-project-lite", + "sync_wrapper 0.1.2", + "tokio", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-http" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5" +dependencies = [ + "bitflags 2.6.0", + "bytes", + "http 1.1.0", + "http-body 1.0.1", + "http-body-util", "pin-project-lite", "tower-layer", "tower-service", @@ -3994,15 +4832,15 @@ dependencies = [ [[package]] name = "tower-layer" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" [[package]] name = "tower-service" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" @@ -4024,7 +4862,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.79", ] [[package]] @@ -4066,8 +4904,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "75327c6b667828ddc28f5e3f169036cb793c3f588d83bf0f262a7f062ffed3c8" dependencies = [ "once_cell", - "opentelemetry", - "opentelemetry_sdk", + "opentelemetry 0.20.0", + "opentelemetry_sdk 0.20.0", "smallvec", "tracing", "tracing-core", @@ -4076,16 +4914,51 @@ dependencies = [ ] [[package]] -name = "tracing-opentelemetry-instrumentation-sdk" -version = "0.14.2" +name = "tracing-opentelemetry" 
+version = "0.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f523eba1b52bb854b804d43a039aafeaee5a623015065adbfef8016825319c15" +checksum = "c67ac25c5407e7b961fafc6f7e9aa5958fd297aada2d20fa2ae1737357e55596" dependencies = [ - "http", - "opentelemetry-http", - "opentelemetry_api", + "js-sys", + "once_cell", + "opentelemetry 0.21.0", + "opentelemetry_sdk 0.21.2", + "smallvec", "tracing", - "tracing-opentelemetry", + "tracing-core", + "tracing-log 0.2.0", + "tracing-subscriber", + "web-time 0.2.4", +] + +[[package]] +name = "tracing-opentelemetry" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f68803492bf28ab40aeccaecc7021096bd256baf7ca77c3d425d89b35a7be4e4" +dependencies = [ + "js-sys", + "once_cell", + "opentelemetry 0.23.0", + "opentelemetry_sdk 0.23.0", + "smallvec", + "tracing", + "tracing-core", + "tracing-log 0.2.0", + "tracing-subscriber", + "web-time 1.1.0", +] + +[[package]] +name = "tracing-opentelemetry-instrumentation-sdk" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9920abb6a3ee3a2af7d30c9ff02900f8481935d36723c3da95cf807468218e8c" +dependencies = [ + "http 1.1.0", + "opentelemetry 0.21.0", + "tracing", + "tracing-opentelemetry 0.22.0", ] [[package]] @@ -4141,10 +5014,25 @@ dependencies = [ ] [[package]] -name = "unicode-ident" -version = "1.0.12" +name = "unicode-bidi" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" + +[[package]] +name = "unicode-ident" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" + +[[package]] +name = "unicode-normalization" +version = "0.1.24" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "5033c97c4262335cded6d6fc3e5c18ab755e1a3dc96376350f3d8e9f009ad956" +dependencies = [ + "tinyvec", +] [[package]] name = "unicode-normalization-alignments" @@ -4157,15 +5045,26 @@ dependencies = [ [[package]] name = "unicode-segmentation" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202" +checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" + +[[package]] +name = "unicode-truncate" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3644627a5af5fa321c95b9b235a72fd24cd29c648c2c379431e6628655627bf" +dependencies = [ + "itertools 0.13.0", + "unicode-segmentation", + "unicode-width", +] [[package]] name = "unicode-width" -version = "0.1.13" +version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0336d538f7abc86d282a4189614dfaa90810dfc2c6f6427eaf88e16311dd225d" +checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" [[package]] name = "unicode_categories" @@ -4173,6 +5072,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" +[[package]] +name = "unindent" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" + [[package]] name = "untrusted" version = "0.7.1" @@ -4207,9 +5112,9 @@ dependencies = [ [[package]] name = "url" -version = "2.5.1" +version = "2.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7c25da092f0a868cdf09e8674cd3b7ef3a7d92a24253e663a2fb85e2496de56" +checksum = "22784dbdf76fdde8af1aeda5622b546b422b6fc585325248a2bf9f5e41e94d6c" dependencies = [ "form_urlencoded", 
"idna", @@ -4222,18 +5127,6 @@ version = "2.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" -[[package]] -name = "utf16_iter" -version = "1.0.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" - -[[package]] -name = "utf8_iter" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" - [[package]] name = "utf8parse" version = "0.2.2" @@ -4242,11 +5135,11 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "utoipa" -version = "3.5.0" +version = "4.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d82b1bc5417102a73e8464c686eef947bdfb99fcdfc0a4f228e81afa9526470a" +checksum = "c5afb1a60e207dca502682537fefcfd9921e71d0b83e9576060f09abc6efab23" dependencies = [ - "indexmap 2.2.6", + "indexmap 2.5.0", "serde", "serde_json", "utoipa-gen", @@ -4254,24 +5147,24 @@ dependencies = [ [[package]] name = "utoipa-gen" -version = "3.5.0" +version = "4.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05d96dcd6fc96f3df9b3280ef480770af1b7c5d14bc55192baa9b067976d920c" +checksum = "7bf0e16c02bc4bf5322ab65f10ab1149bdbcaa782cba66dc7057370a3f8190be" dependencies = [ "proc-macro-error", "proc-macro2", "quote", "regex", - "syn 2.0.66", + "syn 2.0.79", ] [[package]] name = "utoipa-swagger-ui" -version = "3.1.5" +version = "6.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84614caa239fb25b2bb373a52859ffd94605ceb256eeb1d63436325cf81e3653" +checksum = "0b39868d43c011961e04b41623e050aedf2cc93652562ff7935ce0f819aaf2da" dependencies = [ - "axum", + "axum 0.7.7", "mime_guess", "regex", "rust-embed", @@ -4283,9 +5176,25 @@ dependencies = [ [[package]] name 
= "uuid" -version = "1.8.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a183cf7feeba97b4dd1c0d46788634f6221d87fa961b305bed08c851829efcc0" +checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" +dependencies = [ + "getrandom", + "rand", + "uuid-macro-internal", +] + +[[package]] +name = "uuid-macro-internal" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee1cd046f83ea2c4e920d6ee9f7c3537ef928d75dce5d84a87c2c5d6b3999a3a" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.79", +] [[package]] name = "v_frame" @@ -4312,9 +5221,9 @@ checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" [[package]] name = "vergen" -version = "8.3.1" +version = "8.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e27d6bdd219887a9eadd19e1c34f32e47fa332301184935c6d9bca26f3cca525" +checksum = "2990d9ea5967266ea0ccf413a4aa5c42a93dbcfda9cb49a97de6931726b12566" dependencies = [ "anyhow", "cargo_metadata", @@ -4334,9 +5243,9 @@ checksum = "852e951cb7832cb45cb1169900d19760cfa39b82bc0ea9c0e5a14ae88411c98b" [[package]] name = "version_check" -version = "0.9.4" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" [[package]] name = "walkdir" @@ -4365,34 +5274,35 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" +checksum = "a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5" dependencies = [ "cfg-if", + "once_cell", "wasm-bindgen-macro", ] [[package]] name = 
"wasm-bindgen-backend" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" +checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.79", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.42" +version = "0.4.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76bc14366121efc8dbb487ab05bcc9d346b3b5ec0eaa76e46594cabbe51762c0" +checksum = "61e9300f63a621e96ed275155c108eb6f843b6a26d053f122ab69724559dc8ed" dependencies = [ "cfg-if", "js-sys", @@ -4402,9 +5312,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" +checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -4412,28 +5322,48 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" +checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.79", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" +checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" [[package]] name = "web-sys" -version = "0.3.69" +version = 
"0.3.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" +checksum = "26fdeaafd9bd129f65e7c031593c24d62186301e0c72c8978fa1678be7d532c0" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "web-time" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa30049b1c872b72c89866d458eae9f20380ab280ffd1b1e18df2d3e2d98cfe0" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" dependencies = [ "js-sys", "wasm-bindgen", @@ -4451,9 +5381,9 @@ dependencies = [ [[package]] name = "webpki-roots" -version = "0.26.2" +version = "0.26.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c452ad30530b54a4d8e71952716a212b08efd0f3562baa66c29a618b07da7c3" +checksum = "841c67bff177718f1d4dfefde8d8f0e78f9b6589319ba88312f567fc5841a958" dependencies = [ "rustls-pki-types", ] @@ -4464,6 +5394,18 @@ version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082" +[[package]] +name = "which" +version = "4.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" +dependencies = [ + "either", + "home", + "once_cell", + "rustix", +] + [[package]] name = "winapi" version = "0.3.9" @@ -4482,11 +5424,11 @@ checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" [[package]] name = "winapi-util" -version = "0.1.8" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d4cc384e1e73b93bafa6fb4f1df8c41695c8a91cf9c4c64358067d15a7b6c6b" +checksum = 
"cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -4502,7 +5444,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be" dependencies = [ "windows-core", - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -4511,7 +5453,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -4538,7 +5480,16 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.5", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", ] [[package]] @@ -4573,18 +5524,18 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ - "windows_aarch64_gnullvm 0.52.5", - "windows_aarch64_msvc 0.52.5", - "windows_i686_gnu 0.52.5", + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", "windows_i686_gnullvm", - "windows_i686_msvc 0.52.5", - "windows_x86_64_gnu 0.52.5", - "windows_x86_64_gnullvm 0.52.5", - "windows_x86_64_msvc 0.52.5", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + 
"windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", ] [[package]] @@ -4601,9 +5552,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" [[package]] name = "windows_aarch64_msvc" @@ -4619,9 +5570,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" [[package]] name = "windows_i686_gnu" @@ -4637,15 +5588,15 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" [[package]] name = "windows_i686_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] name = "windows_i686_msvc" @@ -4661,9 +5612,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] name = "windows_x86_64_gnu" @@ -4679,9 +5630,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] name = "windows_x86_64_gnullvm" @@ -4697,9 +5648,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" [[package]] name = "windows_x86_64_msvc" @@ -4715,15 +5666,15 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.6.13" +version = "0.6.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59b5e5f6c299a3c7890b876a2a587f3115162487e704907d9b6cd29473052ba1" +checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b" dependencies = [ "memchr", ] @@ -4738,81 +5689,25 @@ dependencies = [ "windows-sys 0.48.0", ] -[[package]] -name = "write16" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" - -[[package]] -name = "writeable" -version = "0.5.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" - -[[package]] -name = "yoke" -version = "0.7.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c5b1314b079b0930c31e3af543d8ee1757b1951ae1e1565ec704403a7240ca5" -dependencies = [ - "serde", - "stable_deref_trait", - "yoke-derive", - "zerofrom", -] - -[[package]] -name = "yoke-derive" -version = "0.7.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28cc31741b18cb6f1d5ff12f5b7523e3d6eb0852bbbad19d73905511d9849b95" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.66", - "synstructure", -] - [[package]] name = "zerocopy" -version = "0.7.34" +version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae87e3fcd617500e5d106f0380cf7b77f3c6092aae37191433159dda23cfb087" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ + "byteorder", "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.34" +version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", -] - -[[package]] -name = "zerofrom" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91ec111ce797d0e0784a1116d0ddcdbea84322cd79e5d5ad173daeba4f93ab55" -dependencies = [ - "zerofrom-derive", -] - -[[package]] -name = "zerofrom-derive" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ea7b4a3637ea8669cedf0f1fd5c286a17f3de97b8dd5a70a6c167a1730e63a5" -dependencies = 
[ - "proc-macro2", - "quote", - "syn 2.0.66", - "synstructure", + "syn 2.0.79", ] [[package]] @@ -4821,28 +5716,6 @@ version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" -[[package]] -name = "zerovec" -version = "0.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb2cc8827d6c0994478a15c53f374f46fbd41bea663d809b14744bc42e6b109c" -dependencies = [ - "yoke", - "zerofrom", - "zerovec-derive", -] - -[[package]] -name = "zerovec-derive" -version = "0.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97cf56601ee5052b4417d90c8755c6683473c926039908196cf35d99f893ebe7" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.66", -] - [[package]] name = "zip" version = "0.6.6" @@ -4872,9 +5745,9 @@ dependencies = [ [[package]] name = "zune-jpeg" -version = "0.4.11" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec866b44a2a1fd6133d363f073ca1b179f438f99e7e5bfb1e33f7181facfe448" +checksum = "16099418600b4d8f028622f73ff6e3deaabdff330fb9a2a131dea781ee8b0768" dependencies = [ "zune-core", ] diff --git a/Cargo.toml b/Cargo.toml index 83972519..a51c6d04 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,27 +1,53 @@ [workspace] members = [ - "benchmark", - "router", - "router/client", - "router/grpc-metadata", - "launcher" + "benchmark", + "backends/v2", + "backends/v3", + "backends/grpc-metadata", + "backends/trtllm", + "launcher", + "router" +] +default-members = [ + "benchmark", + "backends/v2", + "backends/v3", + "backends/grpc-metadata", + # "backends/trtllm", + "launcher", + "router" ] resolver = "2" [workspace.package] -version = "2.0.4" +version = "2.3.1" edition = "2021" authors = ["Olivier Dehaene"] homepage = "https://github.com/huggingface/text-generation-inference" [workspace.dependencies] +base64 = "0.22.0" tokenizers = { version = "0.20.0", 
features = ["http"] } hf-hub = { version = "0.3.1", features = ["tokio"] } +metrics = { version = "0.23.0" } +metrics-exporter-prometheus = { version = "0.15.1", features = [] } +minijinja = { version = "2.2.0", features = ["json"] } +minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } +pyo3 = { version = "0.22.2", features = ["auto-initialize"] } [profile.release] +incremental = true + +[profile.release-binary] +inherits = "release" debug = 1 incremental = true +panic = "abort" + +[profile.release-opt] +inherits = "release" +debug = 0 +incremental = false lto = "fat" opt-level = 3 codegen-units = 1 -panic = "abort" diff --git a/Dockerfile b/Dockerfile index d8198a3a..9276377f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,18 +1,24 @@ # Rust builder -FROM lukemathwalker/cargo-chef:latest-rust-1.78 AS chef +FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef WORKDIR /usr/src -FROM chef as planner +ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse + +FROM chef AS planner +COPY Cargo.lock Cargo.lock COPY Cargo.toml Cargo.toml COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto COPY benchmark benchmark COPY router router +COPY backends backends COPY launcher launcher RUN cargo chef prepare --recipe-path recipe.json FROM chef AS builder +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + python3.11-dev RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ @@ -20,22 +26,29 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ rm -f $PROTOC_ZIP COPY --from=planner /usr/src/recipe.json recipe.json -COPY Cargo.lock Cargo.lock -RUN cargo chef cook --release --recipe-path recipe.json +RUN cargo chef cook --profile release-opt --recipe-path recipe.json + +ARG GIT_SHA +ARG DOCKER_LABEL COPY Cargo.toml Cargo.toml COPY rust-toolchain.toml rust-toolchain.toml COPY 
proto proto COPY benchmark benchmark COPY router router +COPY backends backends COPY launcher launcher -RUN cargo build --release +RUN cargo build --profile release-opt # Text Generation Inference base image FROM vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest as base +ENV ATTENTION=default +ENV PREFIX_CACHING=0 +ENV PREFILL_CHUNKING=0 + # Text Generation Inference base env -ENV HUGGINGFACE_HUB_CACHE=/data \ +ENV HF_HOME=/data \ HF_HUB_ENABLE_HF_TRANSFER=1 \ PORT=80 @@ -51,6 +64,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins make \ curl \ git \ + python3.11-dev \ && rm -rf /var/lib/apt/lists/* # Install server @@ -59,21 +73,33 @@ COPY server server COPY server/Makefile server/Makefile RUN cd server && \ make gen-server && \ - pip install -r requirements.txt && \ + pip install --no-deps -r requirements.txt && \ bash ./dill-0.3.8-patch.sh && \ pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.18.0 && \ BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \ pip install . 
--no-cache-dir # Install benchmarker -COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark +COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark # Install router -COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router +COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router # Install launcher -COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher +COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher + + +# AWS Sagemaker compatible image +FROM base AS sagemaker + +COPY sagemaker-entrypoint.sh entrypoint.sh +RUN chmod +x entrypoint.sh + +ENTRYPOINT ["./entrypoint.sh"] # Final image FROM base -ENTRYPOINT ["text-generation-launcher"] +COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh +RUN chmod +x /tgi-entrypoint.sh + +ENTRYPOINT ["/tgi-entrypoint.sh"] CMD ["--json-output"] diff --git a/Dockerfile.nix b/Dockerfile.nix new file mode 100644 index 00000000..f1e7e0f5 --- /dev/null +++ b/Dockerfile.nix @@ -0,0 +1,24 @@ +# Build the image and get out the docker file: +# +# docker build -t tgi-nix-builder -f Dockerfile.nix +# docker run --log-driver=none tgi-nix-builder | docker load + +FROM nixos/nix:2.18.8 AS builder +RUN echo "experimental-features = nix-command flakes" >> /etc/nix/nix.conf +RUN nix profile install nixpkgs#cachix +RUN cachix use text-generation-inference +WORKDIR /root +ADD . . +RUN nix build . 
+RUN mkdir /tmp/nix-store-closure +RUN cp -R $(nix-store -qR result/) /tmp/nix-store-closure + +FROM ubuntu:24.04 + +WORKDIR /app + +# Copy /nix/store +COPY --from=builder /tmp/nix-store-closure /nix/store +COPY --from=builder /root/result /app +RUN ldconfig +CMD ["ldconfig", "/app/bin/text-generation-launcher"] diff --git a/Dockerfile.trtllm b/Dockerfile.trtllm new file mode 100644 index 00000000..4543ae80 --- /dev/null +++ b/Dockerfile.trtllm @@ -0,0 +1,23 @@ +# All the tooling for CUDA +FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04 AS cuda-builder + +WORKDIR /usr/src/tgi/backends/trtllm +RUN apt update && apt install -y cmake git git-lfs gcc g++ ninja-build libopenmpi-dev python3-dev python3-pip wget + +COPY . /usr/src/tgi +RUN chmod +x scripts/install_tensorrt.sh && scripts/install_tensorrt.sh +RUN cmake -G Ninja -B build -DTRT_LIB_DIR=/usr/local/tensorrt/lib -DTRT_INCLUDE_DIR=/usr/local/tensorrt/include . +RUN cmake --build build --parallel -t tgi_trtllm_backend_impl + +# All the tooling for Rust +FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef +WORKDIR /usr/src + +# Include CUDA related libraries and tools to the Rust based image +COPY --from=cuda-builder /usr/local/cuda /usr/local/cuda +COPY --from=cuda-builder /usr/local/tensorrt /usr/local/tensorrt +COPY --from=cuda-builder /usr/src/tgi/backends/trtllm/build /usr/local/tgi/trtllm/build +ENV PATH=/usr/local/cuda/bin:$PATH +ENV LD_LIBRARY_PATH=/usr/local/tensorrt/lib:$LD_LIBRARY_PATH + +RUN apt update && apt install -y cmake git gcc g++ ninja-build libopenmpi3 diff --git a/Dockerfile_amd b/Dockerfile_amd index 92dd0ea8..0b059f8c 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -1,23 +1,24 @@ # Rust builder -FROM lukemathwalker/cargo-chef:latest-rust-1.78 AS chef +FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse -FROM chef as planner +FROM chef AS planner +COPY Cargo.lock Cargo.lock COPY Cargo.toml Cargo.toml COPY 
rust-toolchain.toml rust-toolchain.toml COPY proto proto COPY benchmark benchmark COPY router router +COPY backends backends COPY launcher launcher RUN cargo chef prepare --recipe-path recipe.json FROM chef AS builder -ARG GIT_SHA -ARG DOCKER_LABEL - +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + python3.11-dev RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ @@ -25,18 +26,22 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ rm -f $PROTOC_ZIP COPY --from=planner /usr/src/recipe.json recipe.json -RUN cargo chef cook --release --recipe-path recipe.json +RUN cargo chef cook --profile release-opt --recipe-path recipe.json + +ARG GIT_SHA +ARG DOCKER_LABEL COPY Cargo.toml Cargo.toml COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto COPY benchmark benchmark COPY router router +COPY backends backends COPY launcher launcher -RUN cargo build --release +RUN cargo build --profile release-opt # Text Generation Inference base image for RoCm -FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update as base +FROM rocm/dev-ubuntu-22.04:6.2 AS base RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ build-essential \ @@ -45,33 +50,34 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins curl \ git \ make \ + libmsgpack-dev \ libssl-dev \ + llvm-dev \ g++ \ # Needed to build VLLM & flash. 
rocthrust-dev \ hipsparse-dev \ hipblas-dev \ - hipblaslt-dev \ + hipcub-dev \ rocblas-dev \ hiprand-dev \ + hipfft-dev \ rocrand-dev \ miopen-hip-dev \ - hipfft-dev \ - hipcub-dev \ hipsolver-dev \ rccl-dev \ cmake \ - python3-dev && \ + python3.11-venv && \ rm -rf /var/lib/apt/lists/* # Keep in sync with `server/pyproject.toml ARG MAMBA_VERSION=23.1.0-1 -ARG PYTORCH_VERSION='2.3.0' -ARG ROCM_VERSION='6.0.2' -ARG PYTHON_VERSION='3.10.10' +ARG PYTHON_VERSION='3.11.10' # Automatically set by buildx ARG TARGETPLATFORM -ENV PATH /opt/conda/bin:$PATH +ENV PATH=/opt/conda/bin:$PATH + +ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942" # TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda. # Install mamba @@ -86,42 +92,141 @@ RUN chmod +x ~/mambaforge.sh && \ mamba init && \ rm ~/mambaforge.sh +# RUN conda install intel::mkl-static intel::mkl-include +# Install pytorch +# On arm64 we exit with an error code +RUN case ${TARGETPLATFORM} in \ + "linux/arm64") exit 1 ;; \ + *) /opt/conda/bin/conda update -y conda && \ + /opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \ + esac && \ + /opt/conda/bin/conda clean -ya + # Install flash-attention, torch dependencies -RUN pip install numpy einops ninja --no-cache-dir +RUN python3 -m pip install --upgrade pip && pip install numpy einops ninja joblib msgpack cmake --no-cache-dir && rm -rf /var/lib/apt/lists/* -RUN conda install intel::mkl-static intel::mkl-include -RUN pip uninstall -y triton && \ - git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \ - cd triton/python && \ - pip install . 
+RUN conda install mkl=2021 +ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/opt/conda/lib/python3.11/site-packages/torch/lib:/opt/conda/lib/ -RUN git clone --depth 1 --recursive --single-branch --branch 2.3-patched https://github.com/fxmarty/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir -ARG _GLIBCXX_USE_CXX11_ABI="1" -ARG CMAKE_PREFIX_PATH="/opt/conda" +ARG COMMON_WORKDIR=/ +WORKDIR ${COMMON_WORKDIR} + + +# Install HIPBLASLt +FROM base AS build_hipblaslt +ARG HIPBLASLT_BRANCH="e6da924" +RUN git clone https://github.com/ROCm/hipBLASLt.git \ + && cd hipBLASLt \ + && git checkout ${HIPBLASLT_BRANCH} \ + && SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} --legacy_hipblas_direct \ + && cd build/release \ + && make package + +FROM scratch AS export_hipblaslt +ARG COMMON_WORKDIR +COPY --from=build_hipblaslt ${COMMON_WORKDIR}/hipBLASLt/build/release/*.deb / + +# RCCL build stages +FROM base AS build_rccl +ARG RCCL_BRANCH="rocm-6.2.0" +RUN git clone https://github.com/ROCm/rccl \ + && cd rccl \ + && git checkout ${RCCL_BRANCH} \ + && ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH} +FROM scratch AS export_rccl +ARG COMMON_WORKDIR +COPY --from=build_rccl ${COMMON_WORKDIR}/rccl/build/release/*.deb / + +# Triton build stages +FROM base AS build_triton +ARG TRITON_BRANCH="e192dba" +ARG TRITON_REPO="https://github.com/triton-lang/triton.git" +RUN python3 -m pip install ninja cmake wheel pybind11 && git clone ${TRITON_REPO} \ + && cd triton \ + && git checkout ${TRITON_BRANCH} \ + && cd python \ + && python3 setup.py bdist_wheel --dist-dir=dist +FROM scratch AS export_triton +ARG COMMON_WORKDIR +COPY --from=build_triton ${COMMON_WORKDIR}/triton/python/dist/*.whl / + +# # AMD-SMI build stages +FROM base AS build_amdsmi +RUN cd /opt/rocm/share/amd_smi \ + && pip wheel . 
--wheel-dir=dist +FROM scratch AS export_amdsmi +COPY --from=build_amdsmi /opt/rocm/share/amd_smi/dist/*.whl / + + +FROM base as build_pytorch + +RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \ + if ls /install/*.deb; then \ + dpkg -i /install/*.deb \ + && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \ + && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \ + fi + +ARG BUILD_ENVIRONMENT=pytorch-linux-jammy-rocm6.2-py3.11 ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942" -ARG BUILD_CAFFE2="0" \ - BUILD_CAFFE2_OPS="0" \ - USE_CUDA="0" \ - USE_ROCM="1" \ - BUILD_TEST="0" \ - USE_FBGEMM="0" \ - USE_NNPACK="0" \ - USE_QNNPACK="0" \ - USE_XNNPACK="0" \ - USE_FLASH_ATTENTION="1" \ - USE_MEM_EFF_ATTENTION="0" -RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install +# A commit to fix the output scaling factor issue in _scaled_mm +# Not yet in 2.5.0-rc1 +ARG PYTORCH_BRANCH="cedc116" +ARG PYTORCH_VISION_BRANCH="v0.19.1" +ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git" -# Set as recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm -ENV HIP_FORCE_DEV_KERNARG=1 +RUN git clone ${PYTORCH_REPO} pytorch \ + && cd pytorch && git checkout ${PYTORCH_BRANCH} && git submodule update --init --recursive \ + && pip install -r requirements.txt --no-cache-dir \ + && python tools/amd_build/build_amd.py \ + && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist +FROM scratch as export_pytorch +ARG COMMON_WORKDIR +COPY --from=build_pytorch ${COMMON_WORKDIR}/pytorch/dist/*.whl / -# On MI250 and MI300, performances for flash with Triton FA are slightly better than CK. -# However, Triton requires a tunning for each prompt length, which is prohibitive. 
-ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0 +FROM base AS install_deps -FROM base AS kernel-builder +ARG COMMON_WORKDIR + +# Install hipblaslt +RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \ + if ls /install/*.deb; then \ + dpkg -i /install/*.deb \ + && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \ + && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \ + fi + +RUN --mount=type=bind,from=export_rccl,src=/,target=/install \ + if ls /install/*.deb; then \ + dpkg -i /install/*.deb \ + # RCCL needs to be installed twice + && dpkg -i /install/*.deb \ + && sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \ + && sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status; \ + fi + +RUN --mount=type=bind,from=export_triton,src=/,target=/install \ + if ls /install/*.whl; then \ + # Preemptively uninstall to prevent pip same-version no-installs + pip uninstall -y triton \ + && pip install /install/*.whl; \ + fi + +RUN --mount=type=bind,from=export_amdsmi,src=/,target=/install \ + # Preemptively uninstall to prevent pip same-version no-installs + pip uninstall -y amdsmi \ + && pip install /install/*.whl; + +RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \ + if ls /install/*.whl; then \ + # Preemptively uninstall to prevent pip same-version no-installs + pip uninstall -y torch torchvision \ + && pip install /install/*.whl; \ + fi + +FROM install_deps AS kernel-builder # # Build vllm kernels FROM kernel-builder AS vllm-builder @@ -142,46 +247,46 @@ COPY server/Makefile-flash-att-v2 Makefile RUN make build-flash-attention-v2-rocm # Build Transformers CUDA kernels (gpt-neox and bloom) -FROM kernel-builder as custom-kernels-builder +FROM kernel-builder AS custom-kernels-builder WORKDIR /usr/src COPY server/custom_kernels/ . 
RUN python setup.py build # Build exllama kernels -FROM kernel-builder as exllama-kernels-builder +FROM kernel-builder AS exllama-kernels-builder WORKDIR /usr/src COPY server/exllama_kernels/ . RUN python setup.py build # Build exllama v2 kernels -FROM kernel-builder as exllamav2-kernels-builder +FROM kernel-builder AS exllamav2-kernels-builder WORKDIR /usr/src COPY server/exllamav2_kernels/ . RUN python setup.py build -FROM base as base-copy +FROM install_deps AS base-copy # Text Generation Inference base env -ENV HUGGINGFACE_HUB_CACHE=/data \ +ENV HF_HOME=/data \ HF_HUB_ENABLE_HF_TRANSFER=1 \ PORT=80 # Copy builds artifacts from vllm builder -COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages # Copy build artifacts from flash attention v2 builder -COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages # Copy build artifacts from custom kernels builder -COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages # Copy build artifacts from exllama kernels builder -COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages # Copy build artifacts from exllamav2 kernels builder -COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +COPY 
--from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages # Install server COPY proto proto @@ -193,14 +298,15 @@ RUN cd server && \ pip install ".[accelerate, peft, outlines]" --no-cache-dir # Install benchmarker -COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark +COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark # Install router -COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router +COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router # Install launcher -COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher +COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher +ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/" # AWS Sagemaker compatible image -FROM base as sagemaker +FROM base AS sagemaker COPY sagemaker-entrypoint.sh entrypoint.sh RUN chmod +x entrypoint.sh @@ -210,6 +316,19 @@ ENTRYPOINT ["./entrypoint.sh"] # Final image FROM base-copy +# Set as recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm +ENV HIP_FORCE_DEV_KERNARG=1 + +# On MI250 and MI300, performances for flash with Triton FA are slightly better than CK. +# However, Triton requires a tuning for each prompt length, which is prohibitive. 
+ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0 +ENV ROCM_USE_CUSTOM_PAGED_ATTN=1 +ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0 +ENV VLLM_MOE_PADDING=0 +ENV ATTENTION=paged +ENV USE_PREFIX_CACHING=0 +ENV ROCM_USE_SKINNY_GEMM=1 + COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh RUN chmod +x /tgi-entrypoint.sh diff --git a/Dockerfile_intel b/Dockerfile_intel index 809992e1..7ab6bba1 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -1,22 +1,25 @@ -FROM lukemathwalker/cargo-chef:latest-rust-1.78 AS chef +ARG PLATFORM=xpu + +FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse -FROM chef as planner +FROM chef AS planner +COPY Cargo.lock Cargo.lock COPY Cargo.toml Cargo.toml COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto COPY benchmark benchmark COPY router router +COPY backends backends COPY launcher launcher RUN cargo chef prepare --recipe-path recipe.json FROM chef AS builder -ARG GIT_SHA -ARG DOCKER_LABEL - +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + python3.11-dev RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ @@ -24,21 +27,52 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ rm -f $PROTOC_ZIP COPY --from=planner /usr/src/recipe.json recipe.json -RUN cargo chef cook --release --recipe-path recipe.json +RUN cargo chef cook --profile release-opt --recipe-path recipe.json + +ARG GIT_SHA +ARG DOCKER_LABEL COPY Cargo.toml Cargo.toml COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto COPY benchmark benchmark COPY router router +COPY backends backends COPY launcher launcher -RUN cargo build --release +RUN cargo build --profile release-opt # Text Generation Inference base image for Intel -FROM intel/intel-extension-for-pytorch:2.1.30-xpu as base + +FROM 
intel/intel-extension-for-pytorch:2.3.110-xpu AS xpu USER root + +ARG MAMBA_VERSION=23.1.0-1 +ARG PYTHON_VERSION='3.11.10' +# Automatically set by buildx +ARG TARGETPLATFORM +ENV PATH /opt/conda/bin:$PATH + +# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda. +# Install mamba +# translating Docker's TARGETPLATFORM into mamba arches +RUN case ${TARGETPLATFORM} in \ + "linux/arm64") MAMBA_ARCH=aarch64 ;; \ + *) MAMBA_ARCH=x86_64 ;; \ + esac && \ + curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh" +RUN chmod +x ~/mambaforge.sh && \ + bash ~/mambaforge.sh -b -p /opt/conda && \ + rm ~/mambaforge.sh + +RUN case ${TARGETPLATFORM} in \ + "linux/arm64") exit 1 ;; \ + *) /opt/conda/bin/conda update -y conda && \ + /opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \ + esac && \ + /opt/conda/bin/conda clean -ya + # libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \ dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb @@ -48,17 +82,16 @@ RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dea RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \ | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list -RUN apt-get update && apt install -y intel-basekit xpu-smi +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y intel-basekit xpu-smi cmake ninja-build pciutils # Text Generation Inference base env -ENV HUGGINGFACE_HUB_CACHE=/data \ +ENV HF_HOME=/data \ 
HF_HUB_ENABLE_HF_TRANSFER=1 \ PORT=80 WORKDIR /usr/src -RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.1.30a0-cp310-cp310-linux_x86_64.whl -RUN pip install intel_extension_for_pytorch-2.1.30a0-cp310-cp310-linux_x86_64.whl +RUN pip install torch==2.3.1+cxx11.abi torchvision==0.18.1+cxx11.abi torchaudio==2.3.1+cxx11.abi intel-extension-for-pytorch==2.3.110+xpu oneccl_bind_pt==2.3.100+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --no-cache-dir # Install server COPY proto proto @@ -66,26 +99,112 @@ COPY server server COPY server/Makefile server/Makefile RUN cd server && \ make gen-server && \ - pip install -r requirements_cuda.txt && \ + pip install -r requirements_intel.txt && \ pip install ".[accelerate, peft, outlines]" --no-cache-dir ENV CCL_ROOT=/opt/intel/oneapi/ccl/latest ENV I_MPI_ROOT=/opt/intel/oneapi/mpi/latest ENV FI_PROVIDER_PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib/prov:/usr/lib/x86_64-linux-gnu/libfabric ENV LIBRARY_PATH=/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mkl/latest/lib/:/opt/intel/oneapi/compiler/latest/lib -ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64: -ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin +ENV 
LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:/opt/conda/lib +ENV PATH=/opt/conda/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin ENV CCL_ZE_IPC_EXCHANGE=sockets +ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest +ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include # Install benchmarker -COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark +COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark # Install router -COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router +COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router # Install launcher -COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher +COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher -# Final image -FROM base +# Text Generation Inference base image for Intel-cpu +FROM ubuntu:22.04 AS cpu + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + curl \ + ca-certificates \ + make \ + g++ \ + git \ + wget \ + cmake \ + libnuma-dev + +ENV HUGGINGFACE_HUB_CACHE=/data \ + HF_HUB_ENABLE_HF_TRANSFER=1 \ + PORT=80 + +ARG MAMBA_VERSION=23.1.0-1 +ARG 
PYTHON_VERSION='3.11.10' +# Automatically set by buildx +ARG TARGETPLATFORM +ENV PATH /opt/conda/bin:$PATH + +# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda. +# Install mamba +# translating Docker's TARGETPLATFORM into mamba arches +RUN case ${TARGETPLATFORM} in \ + "linux/arm64") MAMBA_ARCH=aarch64 ;; \ + *) MAMBA_ARCH=x86_64 ;; \ + esac && \ + curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh" +RUN chmod +x ~/mambaforge.sh && \ + bash ~/mambaforge.sh -b -p /opt/conda && \ + rm ~/mambaforge.sh + +RUN case ${TARGETPLATFORM} in \ + "linux/arm64") exit 1 ;; \ + *) /opt/conda/bin/conda update -y conda && \ + /opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \ + esac && \ + /opt/conda/bin/conda clean -ya + +RUN conda install -c conda-forge gperftools mkl + +RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp311-cp311-linux_x86_64.whl +RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp311-cp311-linux_x86_64.whl +RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp311-cp311-linux_x86_64.whl +RUN pip install triton py-libnuma + +WORKDIR /usr/src + +RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout eda7a7c42df6f9a64e0de9c2b69304ee02f2c32a + +RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout ccl_torch_dev_0131 + +RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install + +RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install . 
+ +ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so +ENV CCL_ROOT=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch +ENV I_MPI_ROOT=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch +ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric +ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch/lib +ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/" + +# Install server +COPY proto proto +COPY server server +COPY server/Makefile server/Makefile +RUN cd server && \ + make gen-server && \ + pip install -r requirements_intel.txt && \ + pip install ".[accelerate, peft, outlines]" --no-cache-dir + +# Install benchmarker +COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark +# Install router +COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router +# Install launcher +COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher + +FROM ${PLATFORM} AS final +ENV ATTENTION=paged +ENV USE_PREFIX_CACHING=0 +ENV CUDA_GRAPHS=0 ENTRYPOINT ["text-generation-launcher"] CMD ["--json-output"] diff --git a/Makefile b/Makefile index 96e67e2b..664b869d 100644 --- a/Makefile +++ b/Makefile @@ -1,20 +1,22 @@ install-server: cd server && make install -install-integration-tests: - cd integration-tests && pip install -r requirements.txt - cd clients/python && pip install . +install-server-cpu: + cd server && make install-server install-router: - cd router && cargo install --locked --path . + cargo install --path backends/v3/ install-launcher: - cd launcher && cargo install --locked --path . + cargo install --path launcher/ install-benchmark: - cd benchmark && cargo install --locked --path . 
+ cargo install --path benchmark/ -install: install-server install-router install-launcher install-custom-kernels +install: install-server install-router install-launcher + + +install-cpu: install-server-cpu install-router install-launcher server-dev: cd server && make run-dev @@ -25,6 +27,10 @@ router-dev: rust-tests: install-router install-launcher cargo test +install-integration-tests: + cd integration-tests && pip install -r requirements.txt + cd clients/python && pip install . + integration-tests: install-integration-tests pytest -s -vv -m "not private" integration-tests @@ -44,6 +50,3 @@ run-falcon-7b-instruct: clean: rm -rf target aml - -debug_image_build: - docker build --no-cache --progress=plain -t debug_tgi . diff --git a/router/client/Cargo.toml b/backends/client/Cargo.toml similarity index 89% rename from router/client/Cargo.toml rename to backends/client/Cargo.toml index bc4ae72e..db423c4b 100644 --- a/router/client/Cargo.toml +++ b/backends/client/Cargo.toml @@ -6,10 +6,11 @@ authors.workspace = true homepage.workspace = true [dependencies] +async-trait = "^0.1" +base64 = { workspace = true } futures = "^0.3" grpc-metadata = { path = "../grpc-metadata" } prost = "^0.12" -rand = "0.8.5" thiserror = "^1.0" tokio = { version = "^1.32", features = ["sync"] } tonic = "^0.10" diff --git a/backends/client/build.rs b/backends/client/build.rs new file mode 100644 index 00000000..210cd603 --- /dev/null +++ b/backends/client/build.rs @@ -0,0 +1,35 @@ +use std::fs; + +fn main() -> Result<(), Box<dyn std::error::Error>> { + println!("cargo:rerun-if-changed=../../proto/"); + + fs::create_dir_all("src/v2/pb").unwrap_or(()); + let mut config = prost_build::Config::new(); + config.protoc_arg("--experimental_allow_proto3_optional"); + + tonic_build::configure() + .build_client(true) + .build_server(false) + .out_dir("src/v2/pb") + .include_file("mod.rs") + .compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"]) + .map_err(|e| match e.kind(){ + 
std::io::ErrorKind::NotFound => {panic!("`protoc` not found, install libprotoc")}, + std::io::ErrorKind::Other => {panic!("`protoc` version unsupported, upgrade protoc: https://github.com/protocolbuffers/protobuf/releases")}, + e => {e} + }).unwrap_or_else(|e| panic!("protobuf compilation failed: {e}")); + + fs::create_dir_all("src/v3/pb").unwrap_or(()); + let mut config = prost_build::Config::new(); + config.protoc_arg("--experimental_allow_proto3_optional"); + + tonic_build::configure() + .build_client(true) + .build_server(false) + .out_dir("src/v3/pb") + .include_file("mod.rs") + .compile_with_config(config, &["../../proto/v3/generate.proto"], &["../../proto"]) + .unwrap_or_else(|e| panic!("protobuf compilation failed: {e}")); + + Ok(()) +} diff --git a/backends/client/src/lib.rs b/backends/client/src/lib.rs new file mode 100644 index 00000000..45bee10c --- /dev/null +++ b/backends/client/src/lib.rs @@ -0,0 +1,91 @@ +//! Text Generation gRPC client library + +pub mod v2; +pub mod v3; + +use async_trait::async_trait; +use base64::{engine::general_purpose::STANDARD, Engine}; +use thiserror::Error; +use tonic::transport; +use tonic::Status; + +pub use v3::{Chunk, Image, Input, InputChunk}; + +#[async_trait] +pub trait Health { + /// Check if a generate server is healthy by asking it to allocate a tensor on device + async fn device_health(&self) -> Result<()>; + + /// Check if a generate server is healthy by doing a forward pass. 
+ /// EXPENSIVE + async fn model_health(&self) -> Result<()>; +} + +#[derive(Debug)] +pub struct ShardInfo { + pub requires_padding: bool, + pub dtype: String, + pub device_type: String, + pub window_size: Option, + pub speculate: u32, +} + +#[derive(Error, Debug, Clone)] +pub enum ClientError { + #[error("Could not connect to Text Generation server: {0}")] + Connection(String), + #[error("Server error: {0}")] + Generation(String), + #[error("Sharded results are empty")] + EmptyResults, +} + +impl From for ClientError { + fn from(err: Status) -> Self { + let err = Self::Generation(err.message().to_string()); + tracing::error!("{err}"); + err + } +} + +impl From for ClientError { + fn from(err: transport::Error) -> Self { + let err = Self::Connection(err.to_string()); + tracing::error!("{err}"); + err + } +} + +// Small convenience re-wrapping of `Chunk`. +impl From for InputChunk { + fn from(chunk: Chunk) -> Self { + InputChunk { chunk: Some(chunk) } + } +} + +/// Convert input chunks to a stringly-typed input for backwards +/// compat for backends that haven't implemented chunked inputs. +pub trait ChunksToString { + /// Convert chunks to string. + fn chunks_to_string(&self) -> String; +} + +impl ChunksToString for Vec { + fn chunks_to_string(&self) -> String { + let mut output = String::new(); + self.iter().for_each(|c| match &c.chunk { + Some(Chunk::Text(text)) => output.push_str(text), + Some(Chunk::Image(Image { data, mimetype })) => { + let encoded = STANDARD.encode(data); + output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded)) + } + // We don't create empty chunks, so this should be unreachable. 
+ None => unreachable!("Chunks should never be empty"), + }); + output + } +} + +static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII="; + +pub type Result = std::result::Result; diff --git a/backends/client/src/v2/client.rs b/backends/client/src/v2/client.rs new file mode 100644 index 00000000..7a6922f0 --- /dev/null +++ b/backends/client/src/v2/client.rs @@ -0,0 +1,260 @@ +/// Single shard Client +use crate::v2::pb; +use crate::{ClientError, Result}; + +use crate::WARMUP_IMAGE_BASE64; +use grpc_metadata::InjectTelemetryContext; +use pb::generate::v2::text_generation_service_client::TextGenerationServiceClient; +use pb::generate::v2::*; +use std::cmp::min; +use std::time::Duration; +use tonic::transport::{Channel, Uri}; +use tracing::instrument; + +/// Text Generation Inference gRPC client +#[derive(Debug, Clone)] +pub struct Client { + stub: TextGenerationServiceClient, +} + +impl Client { + /// Returns a client connected to the given url + pub async fn connect(uri: Uri) -> Result { + let channel = Channel::builder(uri).connect().await?; + + Ok(Self { + stub: TextGenerationServiceClient::new(channel), + }) + } + + /// Returns a client connected to the given unix socket + pub async fn connect_uds(path: String) -> Result { + let channel = 
Channel::from_shared("http://[::]:50051".to_string()) + .unwrap() + .connect_with_connector(tower::service_fn(move |_: Uri| { + tokio::net::UnixStream::connect(path.clone()) + })) + .await?; + + Ok(Self { + stub: TextGenerationServiceClient::new(channel), + }) + } + + /// Returns a list of uris or unix sockets of all shards + #[instrument(skip(self))] + pub async fn service_discovery(&mut self) -> Result> { + let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context(); + let response = self.stub.service_discovery(request).await.map_err(|_| { + ClientError::Connection("Server does not support v2 interface".to_string()) + })?; + let urls = response + .into_inner() + .urls + .into_iter() + // Remove unix socket prefix + .map(|url| match url.strip_prefix("unix://") { + None => url, + Some(stripped_url) => stripped_url.to_string(), + }) + .collect(); + Ok(urls) + } + + /// Get model info + #[instrument(skip(self))] + pub async fn info(&mut self) -> Result { + let request = tonic::Request::new(InfoRequest {}).inject_context(); + let response = self.stub.info(request).await?.into_inner(); + Ok(response) + } + + /// Get model health + #[instrument(skip(self))] + pub async fn health(&mut self) -> Result { + let request = tonic::Request::new(HealthRequest {}).inject_context(); + let response = self.stub.health(request).await?.into_inner(); + Ok(response) + } + + /// Clear the past generations cache + #[instrument(skip(self))] + pub async fn clear_cache(&mut self, batch_id: Option) -> Result<()> { + let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context(); + self.stub.clear_cache(request).await?; + Ok(()) + } + + /// Filter a cached batch + #[instrument(skip(self))] + pub async fn filter_batch( + &mut self, + batch_id: u64, + request_ids: Vec, + ) -> Result> { + let request = tonic::Request::new(FilterBatchRequest { + batch_id, + request_ids, + }) + .inject_context(); + let filtered_batch = 
self.stub.filter_batch(request).await?.into_inner(); + Ok(filtered_batch.batch) + } + + /// Warmup on a max size batch + /// + /// Returns the maximum amount of tokens supported by the hardware + #[instrument(skip_all)] + pub async fn warmup( + &mut self, + max_input_length: u32, + max_prefill_tokens: u32, + max_total_tokens: u32, + max_batch_total_tokens: u32, + max_batch_size: Option, + ) -> Result> { + let mut n_tokens = 0; + let mut requests = Vec::new(); + // Create requests + while n_tokens < max_prefill_tokens { + let truncate = min(max_input_length, max_prefill_tokens - n_tokens); + + let mut inputs = String::new(); + inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize)); + if n_tokens == 0 { + // 1 request is enough to test vision heads. + // Sending images on other queries messes up easily with truncation. + inputs.push_str(&format!( + "![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})", + )); + } + + requests.push(Request { + id: 0, + inputs, + // We truncate the input on the server side to be sure that it has the correct size + truncate, + // Set sampling parameters to also take these ops into account in the max memory + parameters: Some(NextTokenChooserParameters { + temperature: 0.9, + top_k: 10, + top_p: 0.9, + typical_p: 0.9, + do_sample: false, + seed: 0, + repetition_penalty: 1.2, + frequency_penalty: 0.1, + watermark: true, + grammar: String::new(), + grammar_type: GrammarType::None as i32, + }), + stopping_parameters: Some(StoppingCriteriaParameters { + max_new_tokens: max_total_tokens - truncate, + stop_sequences: vec![], + ignore_eos_token: true, + }), + prefill_logprobs: true, + top_n_tokens: 20, + }); + n_tokens += max_input_length; + + // Check max_batch_size + if Some(requests.len()) == max_batch_size { + break; + } + } + + let batch = Batch { + id: 0, + size: requests.len() as u32, + requests, + max_tokens: 0, + }; + + let request = tonic::Request::new(WarmupRequest { + batch: Some(batch), + max_input_length, + 
max_prefill_tokens, + max_total_tokens, + max_batch_total_tokens, + }) + .inject_context(); + let response = self.stub.warmup(request).await?.into_inner(); + Ok(response.max_supported_total_tokens) + } + + /// Generate one token for each request in the given batch + /// + /// Returns Generation for each request in batch + /// and the next cached batch + #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))] + pub async fn prefill( + &mut self, + batch: Batch, + ) -> Result<(Vec, Option, PrefillTimings)> { + let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context(); + let response = self.stub.prefill(request).await?.into_inner(); + Ok(( + response.generations, + response.batch, + PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns), + )) + } + + /// Generate one token for each request in the given cached batches + /// + /// Returns Generation for each request in batches + /// and the next cached batch + #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::()))] + pub async fn decode( + &mut self, + batches: Vec, + ) -> Result<(Vec, Option, DecodeTimings)> { + let request = tonic::Request::new(DecodeRequest { batches }).inject_context(); + let response = self.stub.decode(request).await?.into_inner(); + Ok(( + response.generations, + response.batch, + DecodeTimings::new( + response.concat_ns, + response.forward_ns, + response.decode_ns, + response.total_ns, + ), + )) + } +} + +pub struct PrefillTimings { + pub forward: Duration, + pub decode: Duration, + pub total: Duration, +} + +impl PrefillTimings { + fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self { + Self { + forward: Duration::from_nanos(forward_ns), + decode: Duration::from_nanos(decode_ns), + total: Duration::from_nanos(total_ns), + } + } +} + +pub struct DecodeTimings { + pub concat: Option, + pub forward: Duration, + pub decode: Duration, + pub total: Duration, +} + +impl DecodeTimings { + fn 
new(concat_ns: Option, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self { + Self { + concat: concat_ns.map(Duration::from_nanos), + forward: Duration::from_nanos(forward_ns), + decode: Duration::from_nanos(decode_ns), + total: Duration::from_nanos(total_ns), + } + } +} diff --git a/backends/client/src/v2/mod.rs b/backends/client/src/v2/mod.rs new file mode 100644 index 00000000..6b14b9f3 --- /dev/null +++ b/backends/client/src/v2/mod.rs @@ -0,0 +1,13 @@ +#[allow(clippy::derive_partial_eq_without_eq)] +mod pb; + +mod client; +mod sharded_client; + +pub use client::Client; +pub use pb::generate::v2::HealthResponse; +pub use pb::generate::v2::{ + Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, InfoResponse, + NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens, +}; +pub use sharded_client::ShardedClient; diff --git a/router/client/src/sharded_client.rs b/backends/client/src/v2/sharded_client.rs similarity index 75% rename from router/client/src/sharded_client.rs rename to backends/client/src/v2/sharded_client.rs index fdd84035..2709ea88 100644 --- a/router/client/src/sharded_client.rs +++ b/backends/client/src/v2/sharded_client.rs @@ -1,12 +1,17 @@ -/// Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. 
- -use crate::client::{DecodeTimings, PrefillTimings}; /// Multi shard Client -use crate::{Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo}; +use crate::{v2, Health, ShardInfo}; use crate::{ClientError, Result}; + +use crate::v2::InfoResponse; +use async_trait::async_trait; use futures::future::join_all; use tonic::transport::Uri; use tracing::instrument; +use v2::client::{DecodeTimings, PrefillTimings}; +use v2::{ + Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse, + NextTokenChooserParameters, Request, StoppingCriteriaParameters, +}; #[derive(Debug, Clone)] /// Text Generation Inference gRPC multi client @@ -49,7 +54,7 @@ impl ShardedClient { .iter_mut() .map(|client| client.info()) .collect(); - join_all(futures).await.pop().unwrap() + join_all(futures).await.pop().unwrap().map(ShardInfo::from) } /// GRPC health check @@ -99,8 +104,8 @@ impl ShardedClient { max_input_length: u32, max_prefill_tokens: u32, max_total_tokens: u32, + max_batch_total_tokens: u32, max_batch_size: Option, - model_id: &str, ) -> Result> { let futures: Vec<_> = self .clients @@ -110,8 +115,8 @@ impl ShardedClient { max_input_length, max_prefill_tokens, max_total_tokens, + max_batch_total_tokens, max_batch_size, - model_id )) }) .collect(); @@ -189,3 +194,60 @@ impl ShardedClient { Ok((generations, next_batch, timings)) } } + +impl From for ShardInfo { + fn from(value: InfoResponse) -> Self { + Self { + requires_padding: value.requires_padding, + dtype: value.dtype, + device_type: value.device_type, + window_size: value.window_size, + speculate: value.speculate, + } + } +} + +#[async_trait] +impl Health for ShardedClient { + async fn device_health(&self) -> Result<()> { + self.clone().health().await?; + Ok(()) + } + + async fn model_health(&self) -> Result<()> { + // Dummy batch of 1 token and 1 generated token + let liveness_request = Request { + id: u64::MAX, + inputs: "liveness".to_string(), + truncate: 10, + prefill_logprobs: false, + parameters: 
Some(NextTokenChooserParameters { + temperature: 1.0, + top_k: 0, + top_p: 1.0, + typical_p: 1.0, + do_sample: false, + seed: 0, + repetition_penalty: 1.0, + frequency_penalty: 0.0, + watermark: false, + grammar: String::new(), + grammar_type: GrammarType::None as i32, + }), + stopping_parameters: Some(StoppingCriteriaParameters { + max_new_tokens: 1, + stop_sequences: vec![], + ignore_eos_token: false, + }), + top_n_tokens: 0, + }; + let batch = Batch { + id: u64::MAX, + requests: vec![liveness_request], + size: 1, + max_tokens: 2, + }; + self.clone().prefill(batch).await?; + Ok(()) + } +} diff --git a/backends/client/src/v3/client.rs b/backends/client/src/v3/client.rs new file mode 100644 index 00000000..6274c359 --- /dev/null +++ b/backends/client/src/v3/client.rs @@ -0,0 +1,288 @@ +use crate::v3::{pb, Chunk}; +use crate::{ClientError, Result, WARMUP_IMAGE_BASE64}; +/// Single shard Client +use base64::engine::general_purpose::STANDARD; +use base64::Engine; +use grpc_metadata::InjectTelemetryContext; +use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient; +use pb::generate::v3::*; +use std::cmp::min; +use std::time::Duration; +use tonic::transport::{Channel, Uri}; +use tracing::instrument; + +/// Text Generation Inference gRPC client +#[derive(Debug, Clone)] +pub struct Client { + stub: TextGenerationServiceClient, +} + +impl Client { + /// Returns a client connected to the given url + pub async fn connect(uri: Uri) -> Result { + let channel = Channel::builder(uri).connect().await?; + + Ok(Self { + stub: TextGenerationServiceClient::new(channel), + }) + } + + /// Returns a client connected to the given unix socket + pub async fn connect_uds(path: String) -> Result { + let channel = Channel::from_shared("http://[::]:50051".to_string()) + .unwrap() + .connect_with_connector(tower::service_fn(move |_: Uri| { + tokio::net::UnixStream::connect(path.clone()) + })) + .await?; + + Ok(Self { + stub: TextGenerationServiceClient::new(channel), + 
}) + } + + /// Returns a list of uris or unix sockets of all shards + #[instrument(skip(self))] + pub async fn service_discovery(&mut self) -> Result> { + let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context(); + let response = self.stub.service_discovery(request).await.map_err(|_| { + ClientError::Connection("Server does not support v3 interface".to_string()) + })?; + let urls = response + .into_inner() + .urls + .into_iter() + // Remove unix socket prefix + .map(|url| match url.strip_prefix("unix://") { + None => url, + Some(stripped_url) => stripped_url.to_string(), + }) + .collect(); + Ok(urls) + } + + /// Get model info + #[instrument(skip(self))] + pub async fn info(&mut self) -> Result { + let request = tonic::Request::new(InfoRequest {}).inject_context(); + let response = self.stub.info(request).await?.into_inner(); + Ok(response) + } + + /// Get model health + #[instrument(skip(self))] + pub async fn health(&mut self) -> Result { + let request = tonic::Request::new(HealthRequest {}).inject_context(); + let response = self.stub.health(request).await?.into_inner(); + Ok(response) + } + + /// Clear the past generations cache + #[instrument(skip(self))] + pub async fn clear_cache(&mut self, batch_id: Option) -> Result<()> { + let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context(); + self.stub.clear_cache(request).await?; + Ok(()) + } + + /// Filter a cached batch + #[instrument(skip(self))] + pub async fn filter_batch( + &mut self, + batch_id: u64, + request_ids: Vec, + ) -> Result> { + let request = tonic::Request::new(FilterBatchRequest { + batch_id, + request_ids, + }) + .inject_context(); + let filtered_batch = self.stub.filter_batch(request).await?.into_inner(); + Ok(filtered_batch.batch) + } + + /// Warmup on a max size batch + /// + /// Returns the maximum amount of tokens supported by the hardware + #[instrument(skip_all)] + pub async fn warmup( + &mut self, + max_input_length: u32, + 
max_prefill_tokens: u32, + max_total_tokens: u32, + max_batch_total_tokens: u32, + max_batch_size: Option, + ) -> Result> { + let mut n_tokens = 0; + let mut requests = Vec::new(); + // Create requests + while n_tokens < max_prefill_tokens { + let truncate = min(max_input_length, max_prefill_tokens - n_tokens); + + let mut input_chunks = Vec::new(); + input_chunks + .push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into()); + if n_tokens == 0 { + input_chunks.push( + Chunk::Image(Image { + // Safe unwrap, because we control the data. + data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(), + mimetype: "image/jpeg;base64".to_string(), + }) + .into(), + ); + } + + // Send stringly-typed inputs for compatibility for backends that haven't + // been updated to support chunks. + + let mut inputs = String::new(); + inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize)); + if n_tokens == 0 { + // 1 request is enough to test vision heads. + // Sending images on other queries messes up easily with truncation. 
+ inputs.push_str(&format!( + "![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})", + )); + } + + requests.push(Request { + id: 0, + inputs, + input_chunks: Some(Input { + chunks: input_chunks, + }), + // We truncate the input on the server side to be sure that it has the correct size + truncate, + // Most request will have that + add_special_tokens: true, + // Blocks and slots will be set on the server side if we use paged attention + blocks: vec![], + slots: vec![], + prefix_len: 0, + // Set sampling parameters to also take these ops into account in the max memory + parameters: Some(NextTokenChooserParameters { + temperature: 0.9, + top_k: 10, + top_p: 0.9, + typical_p: 0.9, + do_sample: false, + seed: 0, + repetition_penalty: 1.2, + frequency_penalty: 0.1, + watermark: true, + grammar: String::new(), + grammar_type: GrammarType::None as i32, + }), + stopping_parameters: Some(StoppingCriteriaParameters { + max_new_tokens: max_total_tokens - truncate, + stop_sequences: vec![], + ignore_eos_token: true, + }), + prefill_logprobs: true, + top_n_tokens: 20, + adapter_id: None, + }); + n_tokens += max_input_length; + + // Check max_batch_size + if Some(requests.len()) == max_batch_size { + break; + } + } + + let batch = Batch { + id: 0, + size: requests.len() as u32, + requests, + max_tokens: max_input_length, + max_blocks: 0, + }; + + let request = tonic::Request::new(WarmupRequest { + batch: Some(batch), + max_input_length, + max_prefill_tokens, + max_total_tokens, + max_batch_total_tokens, + }) + .inject_context(); + let response = self.stub.warmup(request).await?.into_inner(); + Ok(response.max_supported_total_tokens) + } + + /// Generate one token for each request in the given batch + /// + /// Returns Generation for each request in batch + /// and the next cached batch + #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))] + pub async fn prefill( + &mut self, + batch: Batch, + ) -> Result<(Vec, Option, PrefillTimings)> { + let request = 
tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context(); + let response = self.stub.prefill(request).await?.into_inner(); + Ok(( + response.generations, + response.batch, + PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns), + )) + } + + /// Generate one token for each request in the given cached batches + /// + /// Returns Generation for each request in batches + /// and the next cached batch + #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::()))] + pub async fn decode( + &mut self, + batches: Vec, + ) -> Result<(Vec, Option, DecodeTimings)> { + let request = tonic::Request::new(DecodeRequest { batches }).inject_context(); + let response = self.stub.decode(request).await?.into_inner(); + Ok(( + response.generations, + response.batch, + DecodeTimings::new( + response.concat_ns, + response.forward_ns, + response.decode_ns, + response.total_ns, + ), + )) + } +} + +pub struct PrefillTimings { + pub forward: Duration, + pub decode: Duration, + pub total: Duration, +} + +impl PrefillTimings { + fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self { + Self { + forward: Duration::from_nanos(forward_ns), + decode: Duration::from_nanos(decode_ns), + total: Duration::from_nanos(total_ns), + } + } +} + +pub struct DecodeTimings { + pub concat: Option, + pub forward: Duration, + pub decode: Duration, + pub total: Duration, +} + +impl DecodeTimings { + fn new(concat_ns: Option, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self { + Self { + concat: concat_ns.map(Duration::from_nanos), + forward: Duration::from_nanos(forward_ns), + decode: Duration::from_nanos(decode_ns), + total: Duration::from_nanos(total_ns), + } + } +} diff --git a/backends/client/src/v3/mod.rs b/backends/client/src/v3/mod.rs new file mode 100644 index 00000000..4a1296a2 --- /dev/null +++ b/backends/client/src/v3/mod.rs @@ -0,0 +1,13 @@ +#[allow(clippy::derive_partial_eq_without_eq)] +mod pb; + +mod client; +mod 
sharded_client; + +pub use client::Client; +pub use pb::generate::v3::{ + input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, + HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request, + StoppingCriteriaParameters, Tokens, +}; +pub use sharded_client::ShardedClient; diff --git a/backends/client/src/v3/sharded_client.rs b/backends/client/src/v3/sharded_client.rs new file mode 100644 index 00000000..9709dfcb --- /dev/null +++ b/backends/client/src/v3/sharded_client.rs @@ -0,0 +1,263 @@ +/// Multi shard Client +use crate::{v3, Health, ShardInfo}; +use crate::{ClientError, Result}; + +use crate::v3::{Chunk, InfoResponse, Input}; +use async_trait::async_trait; +use futures::future::join_all; +use tonic::transport::Uri; +use tracing::instrument; +use v3::client::{DecodeTimings, PrefillTimings}; +use v3::{ + Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse, + NextTokenChooserParameters, Request, StoppingCriteriaParameters, +}; + +#[derive(Debug, Clone)] +/// Text Generation Inference gRPC multi client +pub struct ShardedClient { + clients: Vec, +} + +impl ShardedClient { + fn new(clients: Vec) -> Self { + Self { clients } + } + + /// Create a new ShardedClient from a master client. The master client will communicate with + /// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method. 
+ async fn from_master_client(mut master_client: Client) -> Result { + // Get all uris/unix sockets from the master client + let uris = master_client.service_discovery().await?; + let futures = uris.into_iter().map(Client::connect_uds); + let clients: Result> = join_all(futures).await.into_iter().collect(); + Ok(Self::new(clients?)) + } + + /// Returns a client connected to the given uri + pub async fn connect(uri: Uri) -> Result { + let master_client = Client::connect(uri).await?; + Self::from_master_client(master_client).await + } + + /// Returns a client connected to the given unix socket + pub async fn connect_uds(path: String) -> Result { + let master_client = Client::connect_uds(path).await?; + Self::from_master_client(master_client).await + } + + /// Get the model info + #[instrument(skip(self))] + pub async fn info(&mut self) -> Result { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| client.info()) + .collect(); + join_all(futures).await.pop().unwrap().map(ShardInfo::from) + } + + /// GRPC health check + #[instrument(skip(self))] + pub async fn health(&mut self) -> Result { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| client.health()) + .collect(); + join_all(futures).await.pop().unwrap() + } + + /// Clear the past generations cache + #[instrument(skip(self))] + pub async fn clear_cache(&mut self, batch_id: Option) -> Result<()> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| client.clear_cache(batch_id)) + .collect(); + join_all(futures).await.into_iter().collect() + } + + /// Filter a cached batch + #[instrument(skip(self))] + pub async fn filter_batch( + &mut self, + batch_id: u64, + request_ids: Vec, + ) -> Result> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone()))) + .collect(); + // all shards return the same message + join_all(futures).await.pop().unwrap() + } + + /// Warmup on a max size batch + 
/// + /// Returns the maximum amount of tokens supported by the hardware + #[instrument(skip(self))] + pub async fn warmup( + &mut self, + max_input_length: u32, + max_prefill_tokens: u32, + max_total_tokens: u32, + max_batch_total_tokens: u32, + max_batch_size: Option, + ) -> Result> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| { + Box::pin(client.warmup( + max_input_length, + max_prefill_tokens, + max_total_tokens, + max_batch_total_tokens, + max_batch_size, + )) + }) + .collect(); + // Take the minimum value + let results = join_all(futures) + .await + .into_iter() + .collect::>>>()?; + Ok(results.into_iter().flatten().min()) + } + + /// Generate one token for each request in the given batch + /// + /// Returns Generation for each request in batch + /// and the next cached batch + #[instrument(skip_all, fields(id = & batch.id, size = & batch.size))] + pub async fn prefill( + &mut self, + batch: Batch, + ) -> Result<(Vec, Option, PrefillTimings)> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.prefill(batch.clone()))) + .collect(); + #[allow(clippy::type_complexity)] + let results: Result, Option, PrefillTimings)>> = + join_all(futures).await.into_iter().collect(); + let mut results = results?; + + let (mut generations, next_batch, mut timings) = + results.pop().ok_or(ClientError::EmptyResults)?; + + // Merge generations from different model shards + for (mut shard_generations, _, shard_timings) in results.into_iter() { + generations.append(&mut shard_generations); + // Return the timings of the slowest shard + if shard_timings.total > timings.total { + timings = shard_timings; + } + } + Ok((generations, next_batch, timings)) + } + + /// Generate one token for each request in the given cached batches + /// + /// Returns Generation for each request in batches + /// and the next cached batch + #[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))] + pub 
async fn decode( + &mut self, + batches: Vec, + ) -> Result<(Vec, Option, DecodeTimings)> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.decode(batches.clone()))) + .collect(); + #[allow(clippy::type_complexity)] + let results: Result, Option, DecodeTimings)>> = + join_all(futures).await.into_iter().collect(); + let mut results = results?; + + let (mut generations, next_batch, mut timings) = + results.pop().ok_or(ClientError::EmptyResults)?; + + // Merge generations from different model shards + for (mut shard_generations, _, shard_timings) in results.into_iter() { + generations.append(&mut shard_generations); + // Return the timings of the slowest shard + if shard_timings.total > timings.total { + timings = shard_timings; + } + } + Ok((generations, next_batch, timings)) + } +} + +impl From for ShardInfo { + fn from(value: InfoResponse) -> Self { + Self { + requires_padding: value.requires_padding, + dtype: value.dtype, + device_type: value.device_type, + window_size: value.window_size, + speculate: value.speculate, + } + } +} + +#[async_trait] +impl Health for ShardedClient { + async fn device_health(&self) -> Result<()> { + self.clone().health().await?; + Ok(()) + } + + async fn model_health(&self) -> Result<()> { + // Dummy batch of 1 token and 1 generated token + let liveness_request = Request { + id: u64::MAX, + inputs: "liveness".to_string(), + input_chunks: Some(Input { + chunks: vec![Chunk::Text("liveness".into()).into()], + }), + truncate: 10, + add_special_tokens: true, + prefill_logprobs: false, + parameters: Some(NextTokenChooserParameters { + temperature: 1.0, + top_k: 0, + top_p: 1.0, + typical_p: 1.0, + do_sample: false, + seed: 0, + repetition_penalty: 1.0, + frequency_penalty: 0.0, + watermark: false, + grammar: String::new(), + grammar_type: GrammarType::None as i32, + }), + stopping_parameters: Some(StoppingCriteriaParameters { + max_new_tokens: 1, + stop_sequences: vec![], + ignore_eos_token: false, + }), 
+ top_n_tokens: 0, + // Block 0 is reserved for health checks + blocks: vec![0], + slots: (0..16).collect(), + prefix_len: 0, + adapter_id: None, + }; + let batch = Batch { + id: u64::MAX, + requests: vec![liveness_request], + size: 1, + max_tokens: 2, + max_blocks: 1, + }; + self.clone().prefill(batch).await?; + Ok(()) + } +} diff --git a/router/grpc-metadata/Cargo.toml b/backends/grpc-metadata/Cargo.toml similarity index 100% rename from router/grpc-metadata/Cargo.toml rename to backends/grpc-metadata/Cargo.toml diff --git a/router/grpc-metadata/src/lib.rs b/backends/grpc-metadata/src/lib.rs similarity index 100% rename from router/grpc-metadata/src/lib.rs rename to backends/grpc-metadata/src/lib.rs diff --git a/backends/trtllm/CMakeLists.txt b/backends/trtllm/CMakeLists.txt new file mode 100644 index 00000000..425b2d7b --- /dev/null +++ b/backends/trtllm/CMakeLists.txt @@ -0,0 +1,63 @@ +cmake_minimum_required(VERSION 3.20) + +project(tgi-trtllm-backend VERSION 1.0.0) +set(CMAKE_CXX_STANDARD 20) + +include(FetchContent) +include(ExternalProject) + +option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF) +option(TGI_TRTLLM_BACKEND_BUILD_EXAMPLES "Enable building the examples suite" OFF) +set(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST "89-real" CACHE STRING "List of CUDA architectures to support") +set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING "Path where TensorRT libraries and headers are located") +set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include" CACHE STRING "Path where TensorRT headers are located") +set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE STRING "Path where TensorRT libraries are located") + +# We are using nvidia-ml to query at runtime device information to enable some architecture-specific features +find_package(CUDAToolkit 12.5 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml) + +#### External dependencies #### +include(cmake/fmt.cmake) 
+include(cmake/json.cmake) +include(cmake/spdlog.cmake) +include(cmake/trtllm.cmake) + +# Let's build TRTLLM as part of CMake +add_subdirectory("${trtllm_SOURCE_DIR}/cpp" "${trtllm_SOURCE_DIR}/..") + +# Tell CMake not to try to override the RPATH for executorWorker as it has no information on how to do so +set_target_properties(executorWorker PROPERTIES SKIP_BUILD_RPATH TRUE) + +# TGI TRTLLM Backend definition +add_library(tgi_trtllm_backend_impl STATIC include/backend.h lib/backend.cpp include/hardware.h) +include_directories(${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR}) +target_include_directories(tgi_trtllm_backend_impl PRIVATE + $ + $ +) +target_include_directories(tgi_trtllm_backend_impl PUBLIC "${trtllm_SOURCE_DIR}/cpp/include") +target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper CUDA::cudart CUDA::nvml) +target_link_libraries(tgi_trtllm_backend_impl PUBLIC nlohmann_json::nlohmann_json spdlog::spdlog fmt::fmt) + +# This installs all the artifacts in CMAKE_INSTALL_PREFIX under include/ lib/ bin/ to make them easy to link / find back +install(TARGETS tgi_trtllm_backend_impl tensorrt_llm nvinfer_plugin_tensorrt_llm decoder_attention executorWorker) +install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB) + +#### Unit Tests #### +if (${TGI_TRTLLM_BACKEND_BUILD_TESTS}) + message(STATUS "Building tests") + FetchContent_Declare( + Catch2 + GIT_REPOSITORY https://github.com/catchorg/Catch2 + GIT_TAG v3.6.0 + ) + FetchContent_MakeAvailable(Catch2) + + # add_executable(tgi_trtllm_backend_tests tests/infer_test.cpp) + # target_link_libraries(tgi_trtllm_backend_tests PRIVATE tgi_trtllm_backend_impl Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog fmt::fmt CUDA::cudart CUDA::nvml) + + list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras) + include(CTest) + include(Catch) + # catch_discover_tests(tgi_trtllm_backend_tests) +endif () diff --git 
a/backends/trtllm/Cargo.toml b/backends/trtllm/Cargo.toml new file mode 100644 index 00000000..43a114ba --- /dev/null +++ b/backends/trtllm/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "text-generation-backends-trtllm" +version.workspace = true +edition.workspace = true +authors.workspace = true +homepage.workspace = true + +[dependencies] +async-trait = "0.1" +async-stream = "0.3" +clap = { version = "4.5", features = ["derive"] } +cxx = "1.0" +log = { version = "0.4", features = [] } +text-generation-router = { path = "../../router" } +tokenizers = { version = "0.19", features = ["hf-hub"] } +tokio = { version = "1.38", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } +tokio-stream = "0.1.15" +thiserror = "1.0.62" +tracing = "0.1" +tracing-opentelemetry = "0.24" +tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] } +parking_lot = "0.12" + +[build-dependencies] +cmake = "0.1" +cxx-build = { version = "1.0", features = ["parallel"] } +pkg-config = "0.3" diff --git a/backends/trtllm/Dockerfile b/backends/trtllm/Dockerfile new file mode 100644 index 00000000..5fd2f89f --- /dev/null +++ b/backends/trtllm/Dockerfile @@ -0,0 +1,101 @@ +ARG CUDA_ARCH_LIST="75-real;80-real;86-real;89-real;90-real" +ARG OMPI_VERSION="4.1.6" + +# Build dependencies resolver stage +FROM lukemathwalker/cargo-chef:latest AS chef +WORKDIR /usr/src/text-generation-inference/backends/trtllm + +FROM chef AS planner +COPY . . 
+RUN cargo chef prepare --recipe-path recipe.json + +# CUDA dependent dependencies resolver stage +FROM nvidia/cuda:12.5.1-cudnn-devel-ubuntu22.04 AS cuda-builder + +RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ + --mount=type=cache,target=/var/lib/apt,sharing=locked \ + apt update && apt install -y \ + build-essential \ + cmake \ + curl \ + gcc \ + g++ \ + git \ + git-lfs \ + libssl-dev \ + ninja-build \ + pkg-config \ + python3 \ + python3-setuptools \ + tar \ + wget + +ENV TGI_INSTALL_PREFIX=/usr/local/tgi +ENV TENSORRT_INSTALL_PREFIX=/usr/local/tensorrt + +# Install OpenMPI +FROM cuda-builder AS mpi-builder +ARG OMPI_VERSION + +ENV OMPI_TARBALL_FILENAME="openmpi-$OMPI_VERSION.tar.bz2" +RUN wget "https://download.open-mpi.org/release/open-mpi/v4.1/$OMPI_TARBALL_FILENAME" -P /opt/src && \ + mkdir /usr/src/mpi && \ + tar -xf "/opt/src/$OMPI_TARBALL_FILENAME" -C /usr/src/mpi --strip-components=1 && \ + cd /usr/src/mpi && \ + ./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda && \ + make -j all && \ + make install && \ + rm -rf "/opt/src/$OMPI_TARBALL_FILENAME" + +# Install TensorRT +FROM cuda-builder AS trt-builder +COPY backends/trtllm/scripts/install_tensorrt.sh /opt/install_tensorrt.sh +RUN chmod +x /opt/install_tensorrt.sh && \ + /opt/install_tensorrt.sh + +# Build Backend +FROM cuda-builder AS tgi-builder +WORKDIR /usr/src/text-generation-inference + +# Install Rust +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y && \ + chmod -R a+w /root/.rustup && \ + chmod -R a+w /root/.cargo + +ENV PATH="/root/.cargo/bin:$PATH" +RUN cargo install cargo-chef + +# Cache dependencies +COPY --from=planner /usr/src/text-generation-inference/backends/trtllm/recipe.json . 
+RUN cargo chef cook --release --recipe-path recipe.json + +# Build actual TGI +ARG CUDA_ARCH_LIST +ENV CMAKE_PREFIX_PATH="/usr/local/mpi:/usr/local/tensorrt:$CMAKE_PREFIX_PATH" +ENV LD_LIBRARY_PATH="/usr/local/mpi/lib:$LD_LIBRARY_PATH" +ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig:$PKG_CONFIG_PATH" + +COPY . . +COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt +COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi +RUN mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$TGI_INSTALL_PREFIX/lib" && \ + cd backends/trtllm && \ + CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release + +FROM nvidia/cuda:12.5.1-cudnn-runtime-ubuntu22.04 AS runtime +WORKDIR /usr/local/tgi/bin + +ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH" + +COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi +COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt +COPY --from=tgi-builder /usr/local/tgi /usr/local/tgi +COPY --from=tgi-builder /usr/src/text-generation-inference/target/release/text-generation-backends-trtllm /usr/local/tgi/bin/text-generation-launcher + +FROM runtime + +LABEL co.huggingface.vendor="Hugging Face Inc." 
+LABEL org.opencontainers.image.authors="hardware@hf.co" + +ENTRYPOINT ["./text-generation-launcher"] +CMD ["--executor-worker", "/usr/local/tgi/bin/executorWorker"] diff --git a/backends/trtllm/README.md b/backends/trtllm/README.md new file mode 100644 index 00000000..94064504 --- /dev/null +++ b/backends/trtllm/README.md @@ -0,0 +1,46 @@ +# Text Generation Inference - TensorRT-LLM Backend Implementation + +## Description + +This folder provides the sources of the TensorRT-LLM backend implementation, powered by the new TensorRT-LLM Executor API + +## Simplified Request Sequence + +```mermaid +sequenceDiagram + actor User + participant TextGenerationInference.HttpServer + participant TextGenerationInference.TensorRtLlmBackend + participant TextGenerationInference.TensorRtLlmWorkerThread + participant TensorRtLlm.Executor + participant Nvidia.Gpu + User ->> TextGenerationInference.HttpServer: POST /generate + TextGenerationInference.HttpServer ->> TextGenerationInference.TensorRtLlmBackend: Validate and forward inputs & parameters + TextGenerationInference.TensorRtLlmBackend ->> TextGenerationInference.TensorRtLlmWorkerThread: Allocate a new context and spawn a new thread to handle the request + TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Submit the request to the In-Flight Batcher + activate Nvidia.Gpu + TensorRtLlm.Executor ->> Nvidia.Gpu: Add the request to the pool for execution + TensorRtLlm.Executor -->> TextGenerationInference.TensorRtLlmWorkerThread: Respond with a unique request identifier + rect rgb(10, 92, 54) + loop every 100us + rect rgb(15, 81, 50) + alt Acquire lock to query executor + TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Poll request number of new token(s) generated + else There are new generated tokens + TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Retrieve newly generated tokens + TensorRtLlm.Executor -->> TextGenerationInference.TensorRtLlmWorkerThread: 
Return decoded token information and potential error (omitted) + rect rgb(11, 110, 79) + alt Generated token is final + TensorRtLlm.Executor ->> Nvidia.Gpu: Remove request from the scheduler and from the GPU + TextGenerationInference.TensorRtLlmWorkerThread -->> User: Stream the remaining decoded tokens and flush the connection + else Generated token is not final + TextGenerationInference.TensorRtLlmWorkerThread -->> User: Stream token back to the user as they get decoded + end + end + end + end + deactivate Nvidia.Gpu + end + end + +``` diff --git a/backends/trtllm/build.rs b/backends/trtllm/build.rs new file mode 100644 index 00000000..08638262 --- /dev/null +++ b/backends/trtllm/build.rs @@ -0,0 +1,150 @@ +use cxx_build::CFG; +use pkg_config; +use std::env; +use std::env::consts::ARCH; +use std::path::{absolute, PathBuf}; + +const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"]; +const CUDA_ARCH_LIST: Option<&str> = option_env!("CUDA_ARCH_LIST"); +const CUDA_REQUIRED_VERSION: &str = "12.5"; +const MPI_REQUIRED_VERSION: &str = "4.1"; +const INSTALL_PREFIX: Option<&str> = option_env!("CMAKE_INSTALL_PREFIX"); +const TENSORRT_ROOT_DIR: Option<&str> = option_env!("TENSORRT_ROOT_DIR"); +const NCCL_ROOT_DIR: Option<&str> = option_env!("NCCL_ROOT_DIR"); + +// Dependencies +const BACKEND_DEPS: [&str; 2] = ["tgi_trtllm_backend_impl", "tgi_trtllm_backend"]; +const CUDA_TRANSITIVE_DEPS: [&str; 4] = ["cuda", "cudart", "cublas", "nvidia-ml"]; +const TENSORRT_LLM_TRANSITIVE_DEPS: [(&str, &str); 5] = [ + ("dylib", "tensorrt_llm"), + ("static", "tensorrt_llm_executor_static"), + ("dylib", "tensorrt_llm_nvrtc_wrapper"), + ("dylib", "nvinfer_plugin_tensorrt_llm"), + ("dylib", "decoder_attention"), +]; + +macro_rules! 
probe { + ($name: expr, $version: expr) => { + if let Err(_) = pkg_config::probe_library($name) { + pkg_config::probe_library(&format!("{}-{}", $name, $version)) + .expect(&format!("Failed to locate {}", $name)); + } + }; +} + +fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf, PathBuf) { + // Build the backend implementation through CMake + let install_path = INSTALL_PREFIX.unwrap_or("/usr/local/tgi"); + let tensorrt_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt"); + let cuda_arch_list = CUDA_ARCH_LIST.unwrap_or("90-real"); // Hopper by default + + let mut install_path = PathBuf::from(install_path); + if !install_path.is_absolute() { + install_path = absolute(out_dir).expect("cannot happen").join(install_path); + } + + let _ = cmake::Config::new(".") + .uses_cxx11() + .generator("Ninja") + .profile(match is_debug { + true => "Debug", + false => "Release", + }) + .env("OPT_LEVEL", opt_level) + .define("CMAKE_INSTALL_PREFIX", &install_path) + .define("CMAKE_CUDA_COMPILER", "/usr/local/cuda/bin/nvcc") + .define("TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST", cuda_arch_list) + .define("TGI_TRTLLM_BACKEND_TRT_ROOT", tensorrt_path) + .build(); + + // Additional transitive CMake dependencies + let deps_folder = out_dir.join("build").join("_deps"); + for dependency in ADDITIONAL_BACKEND_LINK_LIBRARIES { + let dep_name = match is_debug { + true => format!("{}d", dependency), + false => String::from(dependency), + }; + let dep_path = deps_folder.join(format!("{}-build", dependency)); + println!("cargo:rustc-link-search={}", dep_path.display()); + println!("cargo:rustc-link-lib=static={}", dep_name); + } + + // Emit linkage information from the artifacts we just built + let install_lib_path = install_path.join("lib"); + + println!( + r"cargo:warning=Adding link search path: {}", + install_lib_path.display() + ); + println!(r"cargo:rustc-link-search={}", install_lib_path.display()); + + (PathBuf::from(install_path), deps_folder) +} + +fn 
build_ffi_layer(deps_folder: &PathBuf) { + CFG.include_prefix = "backends/trtllm"; + cxx_build::bridge("src/lib.rs") + .static_flag(true) + .include(deps_folder.join("fmt-src").join("include")) + .include(deps_folder.join("spdlog-src").join("include")) + .include(deps_folder.join("json-src").join("include")) + .include(deps_folder.join("trtllm-src").join("cpp").join("include")) + .include("/usr/local/cuda/include") + .include("/usr/local/tensorrt/include") + .file("src/ffi.cpp") + .std("c++20") + .compile("tgi_trtllm_backend"); + + println!("cargo:rerun-if-changed=CMakeLists.txt"); + println!("cargo:rerun-if-changed=include/backend.h"); + println!("cargo:rerun-if-changed=lib/backend.cpp"); + println!("cargo:rerun-if-changed=include/ffi.h"); + println!("cargo:rerun-if-changed=src/ffi.cpp"); +} + +fn main() { + // Misc variables + let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); + let build_profile = env::var("PROFILE").unwrap(); + let (is_debug, opt_level) = match build_profile.as_ref() { + "debug" => (true, "0"), + _ => (false, "3"), + }; + + // Build the backend + let (_backend_path, deps_folder) = build_backend(is_debug, opt_level, &out_dir); + + // Build the FFI layer calling the backend above + build_ffi_layer(&deps_folder); + + // Emit linkage search path + probe!("ompi", MPI_REQUIRED_VERSION); + + // Probe CUDA & co. 
with pkg-config + CUDA_TRANSITIVE_DEPS.iter().for_each(|name| { + probe!(name, CUDA_REQUIRED_VERSION); + }); + + // NCCL is slightly trickier because it might not have a pkgconfig installed + let nccl_library_path_default = format!("/usr/local/{}-linux-gnu", ARCH); + let nccl_library_path = NCCL_ROOT_DIR.unwrap_or(&nccl_library_path_default); + println!(r"cargo:rustc-link-search=native={}", nccl_library_path); + println!("cargo:rustc-link-lib=dylib=nccl"); + + // TensorRT: TENSORRT_ROOT_DIR is the install root, the libraries live in its lib/ folder + let tensort_library_path = format!("{}/lib", TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt")); + println!(r"cargo:rustc-link-search=native={}", tensort_library_path); + println!("cargo:rustc-link-lib=dylib=nvinfer"); + + // TensorRT-LLM + TENSORRT_LLM_TRANSITIVE_DEPS + .iter() + .for_each(|(link_type, name)| { + println!("cargo:rustc-link-lib={}={}", link_type, name); + }); + + // Backend + BACKEND_DEPS.iter().for_each(|name| { + println!("cargo:rustc-link-lib=static={}", name); + }); +} diff --git a/backends/trtllm/cmake/fmt.cmake b/backends/trtllm/cmake/fmt.cmake new file mode 100644 index 00000000..f94a9c56 --- /dev/null +++ b/backends/trtllm/cmake/fmt.cmake @@ -0,0 +1,6 @@ +FetchContent_Declare( + fmt + GIT_REPOSITORY https://github.com/fmtlib/fmt + GIT_TAG 11.0.1 +) +FetchContent_MakeAvailable(fmt) diff --git a/backends/trtllm/cmake/json.cmake b/backends/trtllm/cmake/json.cmake new file mode 100644 index 00000000..29e5753b --- /dev/null +++ b/backends/trtllm/cmake/json.cmake @@ -0,0 +1,5 @@ +fetchcontent_declare( + json + URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz +) +fetchcontent_makeavailable(json) diff --git a/backends/trtllm/cmake/spdlog.cmake b/backends/trtllm/cmake/spdlog.cmake new file mode 100644 index 00000000..c4ee5c97 --- /dev/null +++ b/backends/trtllm/cmake/spdlog.cmake @@ -0,0 +1,17 @@ +set(SPDLOG_USE_FMT ON) +set(SPDLOG_BUILD_SHARED OFF) +set(SPDLOG_FMT_EXTERNAL ON) + +# Define the level at which SPDLOG_ compilation level is defined +if 
(${CMAKE_BUILD_TYPE} STREQUAL "Debug") + add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG) +else () + add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_INFO) +endif () + +fetchcontent_declare( + spdlog + GIT_REPOSITORY https://github.com/gabime/spdlog.git + GIT_TAG v1.14.1 +) +fetchcontent_makeavailable(spdlog) diff --git a/backends/trtllm/cmake/trtllm.cmake b/backends/trtllm/cmake/trtllm.cmake new file mode 100644 index 00000000..e59ad4cf --- /dev/null +++ b/backends/trtllm/cmake/trtllm.cmake @@ -0,0 +1,42 @@ +set(TRT_INCLUDE_DIR ${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR}) +set(TRT_LIB_DIR ${TGI_TRTLLM_BACKEND_TRT_LIB_DIR}) + +set(USE_CXX11_ABI ON) +set(BUILD_PYT OFF) +set(BUILD_PYBIND OFF) +set(BUILD_MICRO_BENCHMARKS OFF) +set(BUILD_BENCHMARKS OFF) +set(BUILD_TESTS OFF) +set(CMAKE_CUDA_ARCHITECTURES ${TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST}) + +message(STATUS "Building for CUDA Architectures: ${CMAKE_CUDA_ARCHITECTURES}") + +if (${CMAKE_BUILD_TYPE} STREQUAL "Debug") + set(FAST_BUILD ON) + set(NVTX_DISABLE OFF) +else () + set(FAST_BUILD OFF) + set(FAST_MATH ON) + set(NVTX_DISABLE ON) +endif () + +fetchcontent_declare( + trtllm + GIT_REPOSITORY https://github.com/NVIDIA/TensorRT-LLM.git + GIT_TAG a681853d3803ee5893307e812530b5e7004bb6e1 + GIT_SHALLOW FALSE +) +fetchcontent_makeavailable(trtllm) + +message(STATUS "Found TensorRT-LLM: ${trtllm_SOURCE_DIR}") +execute_process(COMMAND git lfs install WORKING_DIRECTORY "${trtllm_SOURCE_DIR}/") +execute_process(COMMAND git lfs pull WORKING_DIRECTORY "${trtllm_SOURCE_DIR}/") + +# TRTLLM use a JIT based *precompiled* library to generate some specific kernels, we are generating the path to this one here +set(TRTLLM_NVRTC_LIBRARY_NAME "${CMAKE_SHARED_LIBRARY_PREFIX}tensorrt_llm_nvrtc_wrapper${CMAKE_SHARED_LIBRARY_SUFFIX}" CACHE INTERNAL "nvrtc wrapper library name") +set(TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH 
"${trtllm_SOURCE_DIR}/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/${CMAKE_LIBRARY_ARCHITECTURE}/${TRTLLM_NVRTC_LIBRARY_NAME}" + CACHE INTERNAL "nvrtc wrapper library path") + +# The same Executor Static library +set(TRTLLM_EXECUTOR_STATIC_LIBRARY_NAME "${CMAKE_SHARED_LIBRARY_PREFIX}tensorrt_llm_executor_static${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE INTERNAL "executor_static library name") +set(TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH "${trtllm_SOURCE_DIR}/cpp/tensorrt_llm/executor/${CMAKE_LIBRARY_ARCHITECTURE}/${TRTLLM_EXECUTOR_STATIC_LIBRARY_NAME}" CACHE INTERNAL "executor_static library path") diff --git a/backends/trtllm/cmake/utils/detect_cuda_arch.cu b/backends/trtllm/cmake/utils/detect_cuda_arch.cu new file mode 100644 index 00000000..e69de29b diff --git a/backends/trtllm/include/backend.h b/backends/trtllm/include/backend.h new file mode 100644 index 00000000..7990e76b --- /dev/null +++ b/backends/trtllm/include/backend.h @@ -0,0 +1,121 @@ +// +// Created by Morgan Funtowicz on 6/30/24. +// + +#ifndef TGI_TRTLLM_BACKEND_H +#define TGI_TRTLLM_BACKEND_H + +#include +#include +#include +#include + +#include + +#include +#include +#include + +using json = nlohmann::json; +namespace tle = tensorrt_llm::executor; + +namespace huggingface::tgi::backends { + using RequestId = tle::IdType; + using TokenId = tle::TokenIdType; + + /** + * Initialize all the components required by TRTLLM. 
+ * It is required to call this function before attempting to load any engine + */ + void InitializeBackend(); + + /** + * + * @param config TensorRT-LLM configuration object + * @param workerPath Path to the "executorWorker" provided by TensorRT-LLM when using orchestrator mode + * @return + */ + tle::ExecutorConfig GetExecutorConfig(const json &config, const std::string &workerPath); + + /** + * Get the sampling configuration from the parameters provided by TGI + * @param topK + * @param topP + * @param temperature + * @param repetition_penalty + * @param frequency_penalty + * @param seed + * @return + */ + tle::SamplingConfig GetSamplingConfig( + uint32_t topK, + float_t topP, + float_t temperature, + float_t repetition_penalty, + float_t frequency_penalty, + uint64_t seed + ); + + /** + * + */ + class TensorRtLlmBackend { + private: + const json config; + tle::Executor executor; + + public: + explicit TensorRtLlmBackend( + const std::filesystem::path &engineFolder, + const std::filesystem::path &executorWorker + ); + + /** + * Indicate if the backend is ready to accept incoming request + * @return true if ready, false otherwise + */ + [[nodiscard]] bool IsReady() const; + + /** + * Query the executor for the number of token available for pulling + * @return + */ + [[nodiscard]] size_t NumResponsesReady() const; + + /** + * Submit a new generation task to the executor + * @param tokens + * @param topK + * @param topP + * @param temperature + * @param repetition_penalty + * @param frequency_penalty + * @param seed + * @return Request id related to this generation for reference + */ + [[nodiscard]] RequestId Submit( + const std::vector &tokens, + int32_t topK, + float_t topP, + float_t temperature, + float_t repetition_penalty, + float_t frequency_penalty, + uint64_t seed + ); + + /** + * + * @param requestId The request id to poll the generation results + * @return + */ + std::vector Poll(RequestId requestId); + + /** + * Stop the underlying executor + */ + void 
Shutdown(); + }; +} + + +#endif //TGI_TRTLLM_BACKEND_H diff --git a/backends/trtllm/include/ffi.h b/backends/trtllm/include/ffi.h new file mode 100644 index 00000000..fe0be9fc --- /dev/null +++ b/backends/trtllm/include/ffi.h @@ -0,0 +1,75 @@ +// +// Created by mfuntowicz on 7/11/24. +// + +#ifndef TGI_TRTLLM_BACKEND_FFI_H +#define TGI_TRTLLM_BACKEND_FFI_H + +#include +#include "backend.h" + +namespace huggingface::tgi::backends { + class TensorRtLlmBackendImpl; +} + +#include "backends/trtllm/src/lib.rs.h" + + +namespace huggingface::tgi::backends { + +// struct GenerationContext; + + class TensorRtLlmBackendImpl : public TensorRtLlmBackend { + public: + /*** + * + * @param engineFolder + * @param executorWorker + */ + TensorRtLlmBackendImpl(const std::string_view &engineFolder, const std::string_view &executorWorker); + + /*** + * + * @return + */ + bool IsReady() const; + + /*** + * + * @param tokens + * @param topK + * @param topP + * @param temperature + * @param repetition_penalty + * @param frequency_penalty + * @param seed + * @return + */ + [[nodiscard("returned request id should be used to refer to the request's generation result later on")]] + uint64_t + Submit(rust::Slice tokens, int32_t topK, float_t topP, float_t temperature, + float_t repetition_penalty, float_t frequency_penalty, uint64_t seed); + + /*** + * + * @param requestId + * @param ctx + * @param callback + * @return + */ + size_t StreamTokens( + const RequestId requestId, + huggingface::tgi::backends::GenerationContext *ctx, + rust::Fn callback); + }; + + /*** + * + * @param engineFolder + * @return + */ + std::unique_ptr CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker); +} + +#endif //TGI_TRTLLM_BACKEND_FFI_H diff --git a/backends/trtllm/include/hardware.h b/backends/trtllm/include/hardware.h new file mode 100644 index 00000000..da0bf4f3 --- /dev/null +++ b/backends/trtllm/include/hardware.h @@ -0,0 +1,59 @@ +// +// Created by mfuntowicz on 7/23/24. 
+// + +#ifndef TGI_TRTLLM_BACKEND_HARDWARE_H +#define TGI_TRTLLM_BACKEND_HARDWARE_H + +#include +#include +#include +#include +#include + +namespace huggingface::hardware::cuda { + +#define AMPERE_SM_MAJOR 8 +#define HOPPER_SM_MAJOR 9 // Hopper is compute capability 9.x (8.x is Ampere/Ada) + + /** + * Store information about the version of the CUDA Compute Capabilities detected on the device + */ + struct CudaComputeCapabilities { + int32_t major; + int32_t minor; + + [[nodiscard]] constexpr bool isPostAmpere() const { return major >= AMPERE_SM_MAJOR; } + + [[nodiscard]] constexpr bool isPostHopper() const { return major >= HOPPER_SM_MAJOR; } + }; + + CudaComputeCapabilities GetCudaComputeCapabilities() { + // Get the compute capabilities of the current hardware + nvmlDevice_t device; + CudaComputeCapabilities capabilities{0, 0}; + if (nvmlDeviceGetHandleByIndex_v2(0, &device) == NVML_SUCCESS) { + SPDLOG_DEBUG("Successfully acquired nvmlDevice_t = 0"); + if (nvmlDeviceGetCudaComputeCapability(device, &capabilities.major, &capabilities.minor) == NVML_SUCCESS) { + SPDLOG_INFO("Detected sm_{:d}{:d} compute capabilities", capabilities.major, capabilities.minor); + } + } + + return capabilities; + } + + /** + * Return the number of GPU detected. 
If the NVML device count query fails, return std::nullopt + * @return + */ + std::optional GetNumDevices() { + uint32_t numGpus = 0; + if (nvmlDeviceGetCount_v2(&numGpus) == NVML_SUCCESS) { + return std::optional(numGpus); + } else { + return std::nullopt; + } + } +} + +#endif //TGI_TRTLLM_BACKEND_HARDWARE_H diff --git a/backends/trtllm/lib/backend.cpp b/backends/trtllm/lib/backend.cpp new file mode 100644 index 00000000..c066a6d6 --- /dev/null +++ b/backends/trtllm/lib/backend.cpp @@ -0,0 +1,146 @@ +#include + +#include +#include +#include + +#include "backend.h" +#include "hardware.h" + +void huggingface::tgi::backends::InitializeBackend() { + SPDLOG_INFO("Initializing Backend..."); + nvmlInit_v2(); + initTrtLlmPlugins(); + + const auto numGpus = huggingface::hardware::cuda::GetNumDevices(); + if (numGpus.has_value()) { + SPDLOG_INFO("Detected {:d} Nvidia GPU(s)", numGpus.value()); + } else { + SPDLOG_WARN("Failed to detected Nvidia GPU(s) on the system"); + } +} + +[[nodiscard]] +tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &config, const std::string &workerPath) { + tle::ExecutorConfig execConfig(1); + + // Retrieve the compute capabilities to enable some options at runtime + const auto computeCapabilities = huggingface::hardware::cuda::GetCudaComputeCapabilities(); + + // Single engine (TP = PP = 1) -> using leader mode (no MPI involved) + if (config["/pretrained_config/mapping/world_size"_json_pointer].get() == 1) { + SPDLOG_INFO("Detected single engine deployment, using leader mode"); + execConfig.setParallelConfig(tle::ParallelConfig( + tle::CommunicationType::kMPI, + tle::CommunicationMode::kLEADER, + std::nullopt, + std::nullopt, + std::nullopt + )); + } else { // Multiple engines -> using orchestrator mode (MPI involved) + SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode"); + execConfig.setParallelConfig(tle::ParallelConfig( + tle::CommunicationType::kMPI, + tle::CommunicationMode::kORCHESTRATOR, + std::nullopt, 
+ std::nullopt, + tle::OrchestratorConfig(true, workerPath, nullptr, true) + )); + } + + // Define some configuration variables + execConfig.setKvCacheConfig(tle::KvCacheConfig(true)); + execConfig.setEnableChunkedContext(computeCapabilities.isPostAmpere()); + return execConfig; +} + +tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig( + uint32_t topK, + float_t topP, + float_t temperature, + float_t repetition_penalty, + float_t frequency_penalty, + uint64_t seed) { + return tle::SamplingConfig( + 1, // TGI only use a single beam + topK, + topP, + std::nullopt, + std::nullopt, + std::nullopt, + seed, + temperature, + std::nullopt, // minimum length: not exposed by TGI, leave unset (NOTE(review): positional slot per the TRT-LLM executor API, confirm on upgrade) + std::nullopt, + repetition_penalty, + std::nullopt, + frequency_penalty + ); +} + +huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend( + const std::filesystem::path &enginesFolder, + const std::filesystem::path &executorWorker +) : + config(json::parse(std::ifstream(enginesFolder / "config.json"))), + executor( + enginesFolder, + tensorrt_llm::executor::ModelType::kDECODER_ONLY, + GetExecutorConfig(config, executorWorker.string() + )) { + SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref()); +} + +bool huggingface::tgi::backends::TensorRtLlmBackend::IsReady() const { + return executor.canEnqueueRequests(); +} + +[[nodiscard("Returned number of requests needs to be consumed")]] +size_t huggingface::tgi::backends::TensorRtLlmBackend::NumResponsesReady() const { + return executor.getNumResponsesReady(); +} + +[[nodiscard("Returned request id needs to be provided back to gather generated tokens")]] +tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit( + const std::vector &tokens, + const int32_t topK, + const float_t topP, + const float_t temperature, + const float_t repetition_penalty, + const float_t frequency_penalty, + const uint64_t seed +) { +#ifdef NDEBUG + SPDLOG_DEBUG( + FMT_STRING("Submitting inference over {:d} tokens to the executor ({:d} already 
in-flight)"), + tokens.size(), + executor.getLatestIterationStats().back().numActiveRequests + ); +#else + SPDLOG_DEBUG( + FMT_STRING("Submitting inference [{}] to the executor ({:d} already in-flight)"), + fmt::join(tokens, ", "), + executor.getLatestIterationStats().front().numActiveRequests + ); +#endif + + const auto maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get(); + const auto maxNewTokens = static_cast(std::max(1ul, maxNumTokens - tokens.size())); + + const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed); + const auto output = tle::OutputConfig(true, false, false, true, false); + return executor.enqueueRequest( + tle::Request{tokens, maxNewTokens, true, sampling, output}); +} + +[[nodiscard("Generated tokens result must be used")]] +std::vector huggingface::tgi::backends::TensorRtLlmBackend::Poll(const tle::IdType requestId) { + SPDLOG_DEBUG(FMT_STRING("Polling status for request {:d}"), requestId); + return executor.awaitResponses(requestId); +} + + +void huggingface::tgi::backends::TensorRtLlmBackend::Shutdown() { + SPDLOG_INFO("Shutting down executor"); + executor.shutdown(); +} diff --git a/backends/trtllm/scripts/install_tensorrt.sh b/backends/trtllm/scripts/install_tensorrt.sh new file mode 100755 index 00000000..e0e2dd17 --- /dev/null +++ b/backends/trtllm/scripts/install_tensorrt.sh @@ -0,0 +1,111 @@ +#!/bin/bash + +set -ex + +TRT_VER="10.2.0.19" +CUDA_VER="12.5" +CUDNN_VER="9.2.1.18-1" +NCCL_VER="2.22.3-1+cuda12.5" +CUBLAS_VER="12.5.3.2-1" +NVRTC_VER="12.5.82-1" + +for i in "$@"; do + case $i in + --TRT_VER=?*) TRT_VER="${i#*=}";; + --CUDA_VER=?*) CUDA_VER="${i#*=}";; + --CUDNN_VER=?*) CUDNN_VER="${i#*=}";; + --NCCL_VER=?*) NCCL_VER="${i#*=}";; + --CUBLAS_VER=?*) CUBLAS_VER="${i#*=}";; + *) ;; + esac + shift +done + +NVCC_VERSION_OUTPUT=$(nvcc --version) +if [[ $(echo $NVCC_VERSION_OUTPUT | grep -oP "\d+\.\d+" | head -n 1) != ${CUDA_VER} ]]; then + echo "The version of 
pre-installed CUDA is not equal to ${CUDA_VER}." + exit 1 +fi + +install_ubuntu_requirements() { + apt-get update && apt-get install -y --no-install-recommends gnupg2 curl ca-certificates + ARCH=$(uname -m) + if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi + if [ "$ARCH" = "aarch64" ];then ARCH="sbsa";fi + curl -fsSLO https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/${ARCH}/cuda-keyring_1.0-1_all.deb + dpkg -i cuda-keyring_1.0-1_all.deb + + apt-get update + if [[ $(apt list --installed | grep libcudnn9) ]]; then + apt-get remove --purge -y --allow-change-held-packages libcudnn9* + fi + if [[ $(apt list --installed | grep libnccl) ]]; then + apt-get remove --purge -y --allow-change-held-packages libnccl* + fi + if [[ $(apt list --installed | grep libcublas) ]]; then + apt-get remove --purge -y --allow-change-held-packages libcublas* + fi + if [[ $(apt list --installed | grep cuda-nvrtc-dev) ]]; then + apt-get remove --purge -y --allow-change-held-packages cuda-nvrtc-dev* + fi + CUBLAS_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g') + apt-get install -y --no-install-recommends libcudnn9-cuda-12=${CUDNN_VER} libcudnn9-dev-cuda-12=${CUDNN_VER} + apt-get install -y --no-install-recommends libnccl2=${NCCL_VER} libnccl-dev=${NCCL_VER} + apt-get install -y --no-install-recommends libcublas-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER} libcublas-dev-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER} + # NVRTC static library doesn't exist in NGC PyTorch container. 
+ NVRTC_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g') + apt-get install -y --no-install-recommends cuda-nvrtc-dev-${NVRTC_CUDA_VERSION}=${NVRTC_VER} + apt-get clean + rm -rf /var/lib/apt/lists/* +} + +install_centos_requirements() { + CUBLAS_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g') + yum -y update + yum -y install epel-release + yum remove -y libnccl* && yum -y install libnccl-${NCCL_VER} libnccl-devel-${NCCL_VER} + yum remove -y libcublas* && yum -y install libcublas-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER} libcublas-devel-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER} + yum clean all +} + +install_tensorrt() { + #PY_VERSION=$(python3 -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))') + #PARSED_PY_VERSION=$(echo "${PY_VERSION//./}") + TRT_CUDA_VERSION="12.5" + + if [ -z "$RELEASE_URL_TRT" ];then + ARCH=${TRT_TARGETARCH} + if [ -z "$ARCH" ];then ARCH=$(uname -m);fi + if [ "$ARCH" = "arm64" ];then ARCH="aarch64";fi + if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi + if [ "$ARCH" = "x86_64" ];then DIR_NAME="x64-agnostic"; else DIR_NAME=${ARCH};fi + if [ "$ARCH" = "aarch64" ];then OS1="Ubuntu22_04" && OS2="Ubuntu-22.04" && OS="ubuntu-22.04"; else OS1="Linux" && OS2="Linux" && OS="linux";fi + RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.2.0/tars/TensorRT-${TRT_VER}.${OS2}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz + fi + wget --no-verbose ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar + tar -xf /tmp/TensorRT.tar -C /usr/local/ + mv /usr/local/TensorRT-${TRT_VER} /usr/local/tensorrt + # pip3 install /usr/local/tensorrt/python/tensorrt-*-cp${PARSED_PY_VERSION}-*.whl + rm -rf /tmp/TensorRT.tar +} + +# Install base packages depending on the base OS +ID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"') +case "$ID" in + debian) + install_ubuntu_requirements + install_tensorrt + ;; + ubuntu) + install_ubuntu_requirements + install_tensorrt + ;; + centos) + install_centos_requirements + install_tensorrt + ;; + *) + echo 
"Unable to determine OS..." + exit 1 + ;; +esac diff --git a/backends/trtllm/src/backend.rs b/backends/trtllm/src/backend.rs new file mode 100644 index 00000000..b23aa6c0 --- /dev/null +++ b/backends/trtllm/src/backend.rs @@ -0,0 +1,330 @@ +use std::future::Future; +use std::path::Path; +use std::pin::{pin, Pin}; +use std::str::FromStr; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, OnceLock}; +use std::task::{Context, Poll}; +use std::time::Duration; + +use async_trait::async_trait; +use cxx::UniquePtr; +use log::{error, warn}; +use tokenizers::Tokenizer; +use tokio::sync::mpsc::{unbounded_channel, UnboundedSender}; +use tokio::time::{sleep, Instant}; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tokio_stream::{Stream, StreamExt}; +use tracing::{instrument, span, Level}; + +// use tokio::sync::RwLock; +use parking_lot::RwLock; +use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; +use text_generation_router::validation::ValidationError::UnsupportedModality; +use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidationError}; +use text_generation_router::{FinishReason, Token}; + +use crate::errors::TensorRtLlmBackendError; +use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl}; + +// Value used to poll the state of the generation stream +static POLLING_INTERVAL_US: OnceLock = OnceLock::new(); + +type InferResult = Result; + +pub(crate) struct Generation { + executor: Arc>>, + done: Arc, +} + +/// Holds the user provided input to be executed along with a channel allowing +/// to bubble up all the generated tokens for that tokens the to end stream. 
+pub struct GenerationContext { + sender: UnboundedSender>, + tokenizer: Arc, + tokens: Vec, + done: Arc, + queued: Instant, + start: Option, +} + +impl Stream for Generation { + type Item = usize; + + fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll> { + let interval = POLLING_INTERVAL_US.get_or_init(|| { + u64::from_str(option_env!("TRTLLM_BACKEND_POLLING_INTERVAL_US").unwrap_or("100")) + .expect("Invalid value provided for envvar POLLING_INTERVAL_US") + }); + + if !self.done.load(Ordering::Relaxed) { + let backend = pin!(self.executor.read()); + let status = match backend.poll(ctx) { + Poll::Ready(executor_r) => { + let ready = executor_r.num_responses_ready(); + if ready == 0 { + Poll::Pending + } else { + Poll::Ready(Some(ready)) + } + } + Poll::Pending => Poll::Pending, + }; + + let waker = ctx.waker().clone(); + tokio::spawn(async { + sleep(Duration::from_micros(*interval)).await; + waker.wake(); + }); + + status + } else { + Poll::Ready(None) // end of stream + } + } + + fn size_hint(&self) -> (usize, Option) { + (1, None) + } +} + +unsafe impl Send for TensorRtLlmBackendImpl {} +unsafe impl Sync for TensorRtLlmBackendImpl {} + +/// Implements the logic to execute generation with TensorRT-LLM executor API in background +pub struct TensorRtLlmBackend { + tokenizer: Arc, + + // Backing the backend behind a RwLock to allow concurrent read access to retrieve + // the number of available tokens (read only) in the Generation stream + backend: Arc>>, +} + +impl TensorRtLlmBackend { + pub fn new + Send + 'static, PP: AsRef + Send + 'static>( + tokenizer: Tokenizer, + engine_folder: P, + executor_worker_path: PP, + ) -> Result { + Ok(TensorRtLlmBackend { + tokenizer: Arc::new(tokenizer), + backend: Arc::new(RwLock::new(create_tensorrt_llm_backend( + engine_folder.as_ref().to_str().unwrap(), + executor_worker_path.as_ref().to_str().unwrap(), + ))), + }) + } + + fn validate(request: &ValidGenerateRequest) -> InferResult<&String> { + if 
request.top_n_tokens > 1 { + return Err(InferError::ValidationError( + ValidationError::TopNTokensDisabled, + )); + } + + // TODO: Is it really needed? How can it be validated before? + if request.parameters.grammar.is_some() { + return Err(InferError::ValidationError(ValidationError::Grammar)); + } + + match request.inputs.len() { + 0 => Err(InferError::ValidationError(ValidationError::EmptyInput)), + 2.. => Err(InferError::GenerationError( + "TensorRT-LLM backend don't support multi-chunk".into(), + )), + 1 => match request.inputs.first().expect("Single item-chunk") { + Chunk::Text(text) => Ok(text), + Chunk::Image(_) => Err(InferError::ValidationError(UnsupportedModality("image"))), + }, + } + } + + fn generate( + &self, + sender: UnboundedSender>, + tokens: Vec, + top_k: u32, + top_p: f32, + temperature: f32, + repetition_penalty: f32, + frequency_penalty: f32, + seed: u64, + ) { + let tokenizer = Arc::clone(&self.tokenizer); + let executor = Arc::clone(&self.backend); + + // Let's push this in async context + tokio::spawn(async move { + // Define the generation state + let mut generation = Generation { + executor: executor.clone(), + done: Arc::new(AtomicBool::new(false)), + }; + + // Define the context over the generation + // TODO(asap): Do we really need so many shared-ownership? + let ctx = Box::new(GenerationContext { + sender: sender.clone(), + tokenizer, + tokens: vec![], + done: Arc::clone(&generation.done), + start: None, + queued: Instant::now(), + }); + + // We are leaking the context on-purpose to avoid the box being dropped while there are + // still computation ongoing + // TODO(asap): Can we achieve the same with an Arc> without the need to go unsafe? 
+ let ctx_ = Box::leak(ctx); + + // Submit the request to the batcher + let request_id = span!(Level::DEBUG, "submit") + .in_scope(|| async { + let mut handle = executor.write().await; + let request_id = handle.pin_mut().submit( + &tokens, + top_k as i32, + top_p, + temperature, + repetition_penalty, + frequency_penalty, + seed, + ); + + request_id + }) + .await; + + while let Some(_) = generation.next().await { + let mut executor_w = executor.write().await; + let executor = executor_w.pin_mut(); + + span!(Level::DEBUG, "decode") + .in_scope(|| async { + unsafe { + executor.stream_tokens( + request_id, + ctx_, + |ctx: *mut GenerationContext, step: GenerationStep| { + let inner_ctx = &mut *ctx; + + // Update the timestamp at which the request started effectively + // Can be a bit off, would need to be before the callback, let's see + inner_ctx.start.get_or_insert(Instant::now()); + inner_ctx.done.store(step.is_final, Ordering::Relaxed); + + // Ensure we are not running into errors + let parcel = if !step.has_error { + // Insert the latest generated token to the tracker + inner_ctx.tokens.push(step.token_id); + + // Decode the token + let text = inner_ctx + .tokenizer + .decode(&[step.token_id], true) + .expect("Failed to decode token"); + + let special = inner_ctx + .tokenizer + .get_added_vocabulary() + .is_special_token(&text); + + // Create the structure holding the token + let token = Token { + id: step.token_id, + text, + logprob: step.log_prob, + special, + }; + + if step.is_final { + let generated_text = inner_ctx + .tokenizer + .decode(&inner_ctx.tokens, true) + .expect("Failed to decode generated_tokens"); + + Ok(InferStreamResponse::End { + token, + top_tokens: vec![], + generated_text: GeneratedText { + text: generated_text, + generated_tokens: inner_ctx.tokens.len() as u32, + finish_reason: FinishReason::EndOfSequenceToken, + seed: None, + }, + start: inner_ctx.start.unwrap_or(Instant::now()), + queued: inner_ctx.queued, + }) + } else { + 
Ok(InferStreamResponse::Intermediate { + token, + top_tokens: vec![], + }) + } + } else { + error!("Error caught while decoding: {}", &step.error_msg); + Err(InferError::GenerationError(step.error_msg)) + }; + + // Send the parcel to the client + inner_ctx + .sender + .send(parcel) + .expect("Failed to sent msg through the channel"); + }, + ); + } + }) + .await; + } + + // "Properly" free the shared context... + // TODO: clean that piece of sh** asap + unsafe { + let _ = Box::from_raw(ctx_); + } + }); + } +} + +#[async_trait] +impl Backend for TensorRtLlmBackend { + #[instrument(skip_all)] + fn schedule( + &self, + request: ValidGenerateRequest, + ) -> InferResult>> { + // Let's add a few more validation + let input = TensorRtLlmBackend::validate(&request)?; + + // Channel to stream the generated token as they come from the worker thread back to the transport layer + let (sender, receiver) = unbounded_channel(); + + // Unpack parameters + let params = &request.parameters; + + // Preprocess the inputs to send to TRTLLM backend + let encoding = self + .tokenizer + .encode(input.as_str(), true) + .map_err(|e| InferError::GenerationError(e.to_string()))?; + + // Generate the response + self.generate( + sender, + Vec::from(encoding.get_ids()), + params.top_k, + params.top_p, + params.temperature, + params.repetition_penalty, + params.frequency_penalty, + params.seed, + ); + + Ok(UnboundedReceiverStream::new(receiver)) + } + + async fn health(&self, _current_health: bool) -> bool { + true + } +} diff --git a/backends/trtllm/src/errors.rs b/backends/trtllm/src/errors.rs new file mode 100644 index 00000000..a672d2a4 --- /dev/null +++ b/backends/trtllm/src/errors.rs @@ -0,0 +1,15 @@ +use thiserror::Error; + +use text_generation_router::server; + +#[derive(Debug, Error)] +pub enum TensorRtLlmBackendError { + #[error("Tokenizer error: {0}")] + Tokenizer(String), + #[error("Argument validation error: {0}")] + ArgumentValidation(String), + #[error("WebServer error: {0}")] + 
WebServer(#[from] server::WebServerError), + #[error("Tokio runtime failed to start: {0}")] + Tokio(#[from] std::io::Error), +} diff --git a/backends/trtllm/src/ffi.cpp b/backends/trtllm/src/ffi.cpp new file mode 100644 index 00000000..d6317a68 --- /dev/null +++ b/backends/trtllm/src/ffi.cpp @@ -0,0 +1,84 @@ +// +// Created by mfuntowicz on 6/30/24. +// +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include "backends/trtllm/include/ffi.h" + + +huggingface::tgi::backends::TensorRtLlmBackendImpl::TensorRtLlmBackendImpl( + const std::string_view &engineFolder, + const std::string_view &executorWorker +) : TensorRtLlmBackend(engineFolder, executorWorker) {} + + +bool huggingface::tgi::backends::TensorRtLlmBackendImpl::IsReady() const { + return TensorRtLlmBackend::IsReady(); +} + +uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit( + rust::Slice tokens, int32_t topK, float_t topP, float_t temperature, float_t repetition_penalty, + float_t frequency_penalty, uint64_t seed) { + + // This will copy all the items from the initial slice + std::vector tokens_(std::make_move_iterator(tokens.begin()), std::make_move_iterator(tokens.end())); + return TensorRtLlmBackend::Submit( + std::move(tokens_), topK, topP, temperature, repetition_penalty, frequency_penalty, seed); +} + +size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens( + const uint64_t requestId, + huggingface::tgi::backends::GenerationContext *ctx, + rust::Fn callback) { + + size_t numTokens = 0; + for (const auto &item: Poll(requestId)) { + GenerationStep step; + if (!item.hasError()) { + SPDLOG_DEBUG("\tStreamTokens -> Decoding token..."); + const auto decoded = item.getResult(); + + const auto token = decoded.outputTokenIds[0][0]; + const auto isFinal = decoded.isFinal; + const auto logProb = decoded.logProbs.value()[0][0]; + + ++numTokens; + + SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, 
isFinal); + step = huggingface::tgi::backends::GenerationStep{ + static_cast(token), logProb, isFinal, false, std::move(std::string()) + }; + SPDLOG_DEBUG("\tStreamTokens -> Post callback"); + } else { + // TODO : Return rest::Result with error + const auto what = item.getErrorMsg(); + SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", what); + step = huggingface::tgi::backends::GenerationStep{ + std::numeric_limits::max(), 0.0, true, true, std::move(what) + }; + } + + callback(std::move(ctx), std::move(step)); + } + + return numTokens; +} + +std::unique_ptr +huggingface::tgi::backends::CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker) { + // Unconditionally call this to initialize and discover TRTLLM plugins + InitializeBackend(); + + const auto enginePath = std::string_view(engineFolder.begin(), engineFolder.end()); + const auto executorPath = std::string_view(executorWorker.begin(), executorWorker.end()); + return std::make_unique(std::move(enginePath), std::move(executorPath)); +} diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs new file mode 100644 index 00000000..1a804f88 --- /dev/null +++ b/backends/trtllm/src/lib.rs @@ -0,0 +1,78 @@ +pub use backend::{GenerationContext, TensorRtLlmBackend}; + +mod backend; +pub mod errors; + +#[cxx::bridge(namespace = "huggingface::tgi::backends")] +mod ffi { + + /// Struct used as shared type between rust and C++ to represent the result + /// of a single decoding iteration + pub struct GenerationStep { + token_id: u32, + log_prob: f32, + is_final: bool, + has_error: bool, + error_msg: String, + } + + extern "Rust" { + type GenerationContext; + } + + unsafe extern "C++" { + include!("backends/trtllm/src/ffi.cpp"); + + /// Represent an instance of the underlying TensorRT-LLM backend + type TensorRtLlmBackendImpl; + + /// Create an instance backed behind a std::unique_ptr to manage the lifespan of the backend + /// + /// # Arguments + /// + /// * `engine_folder`: Path to 
the folder containing all the TRTLLM engines + /// * `executor_worker`: Path to the TRTLLM executor worker + /// + /// returns: + /// + /// # Examples + /// + /// ``` + /// + /// ``` + #[rust_name = "create_tensorrt_llm_backend"] + fn CreateTensorRtLlmBackend( + engine_folder: &str, + executor_worker: &str, + ) -> UniquePtr; + + // #[rust_name = "is_ready"] + // fn IsReady(self: &TensorRtLlmBackendImpl) -> bool; + + #[rust_name = "num_responses_ready"] + fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize; + + #[rust_name = "submit"] + fn Submit( + self: Pin<&mut TensorRtLlmBackendImpl>, + tokens: &[u32], + top_k: i32, + top_p: f32, + temperature: f32, + repetition_penalty: f32, + frequency_penalty: f32, + seed: u64, + ) -> u64; + + #[rust_name = "stream_tokens"] + unsafe fn StreamTokens( + self: Pin<&mut TensorRtLlmBackendImpl>, + request_id: u64, + ctx: *mut GenerationContext, + cb: unsafe fn(*mut GenerationContext, GenerationStep), + ) -> usize; + + // #[rust_name = "shutdown"] + // fn Shutdown(self: Pin<&mut TensorRtLlmBackendImpl>); + } +} diff --git a/backends/trtllm/src/main.rs b/backends/trtllm/src/main.rs new file mode 100644 index 00000000..e0ba46c7 --- /dev/null +++ b/backends/trtllm/src/main.rs @@ -0,0 +1,166 @@ +use clap::Parser; +use std::collections::HashMap; +use std::path::PathBuf; +use text_generation_backends_trtllm::errors::TensorRtLlmBackendError; +use text_generation_backends_trtllm::TensorRtLlmBackend; +use text_generation_router::server; +use tokenizers::{FromPretrainedParameters, Tokenizer}; + +/// App Configuration +#[derive(Parser, Debug)] +#[clap(author, version, about, long_about = None)] +struct Args { + #[clap(default_value = "128", long, env)] + max_concurrent_requests: usize, + #[clap(default_value = "2", long, env)] + max_best_of: usize, + #[clap(default_value = "4", long, env)] + max_stop_sequences: usize, + #[clap(default_value = "5", long, env)] + max_top_n_tokens: u32, + #[clap(default_value = "1024", long, env)] + 
max_input_tokens: usize, + #[clap(default_value = "2048", long, env)] + max_total_tokens: usize, + #[clap(default_value = "4096", long, env)] + max_batch_prefill_tokens: u32, + #[clap(long, env)] + max_batch_total_tokens: Option, + #[clap(default_value = "0.0.0.0", long, env)] + hostname: String, + #[clap(default_value = "3000", long, short, env)] + port: u16, + #[clap(long, env, required = true)] + tokenizer_name: String, + #[clap(long, env)] + tokenizer_config_path: Option, + #[clap(long, env)] + revision: Option, + #[clap(long, env)] + model_id: String, + #[clap(default_value = "2", long, env)] + validation_workers: usize, + #[clap(long, env)] + json_output: bool, + #[clap(long, env)] + otlp_endpoint: Option, + #[clap(default_value = "text-generation-inference.router", long, env)] + otlp_service_name: String, + #[clap(long, env)] + cors_allow_origin: Option>, + #[clap(long, env, default_value_t = false)] + messages_api_enabled: bool, + #[clap(default_value = "4", long, env)] + max_client_batch_size: usize, + #[clap(long, env)] + auth_token: Option, + #[clap(long, env, help = "Path to the TensorRT-LLM Orchestrator worker")] + executor_worker: PathBuf, +} + +#[tokio::main] +async fn main() -> Result<(), TensorRtLlmBackendError> { + // Get args + let args = Args::parse(); + // Pattern match configuration + let Args { + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + max_batch_prefill_tokens, + max_batch_total_tokens, + hostname, + port, + tokenizer_name, + tokenizer_config_path, + revision, + model_id, + validation_workers, + json_output, + otlp_endpoint, + otlp_service_name, + cors_allow_origin, + messages_api_enabled, + max_client_batch_size, + auth_token, + executor_worker, + } = args; + + // Launch Tokio runtime + text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output); + + // Validate args + if max_input_tokens >= max_total_tokens { + return 
Err(TensorRtLlmBackendError::ArgumentValidation( + "`max_input_tokens` must be < `max_total_tokens`".to_string(), + )); + } + if max_input_tokens as u32 > max_batch_prefill_tokens { + return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}"))); + } + + if validation_workers == 0 { + return Err(TensorRtLlmBackendError::ArgumentValidation( + "`validation_workers` must be > 0".to_string(), + )); + } + + if let Some(ref max_batch_total_tokens) = max_batch_total_tokens { + if max_batch_prefill_tokens > *max_batch_total_tokens { + return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); + } + if max_total_tokens as u32 > *max_batch_total_tokens { + return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. 
Given: {max_total_tokens} and {max_batch_total_tokens}"))); + } + } + + if !executor_worker.exists() { + return Err(TensorRtLlmBackendError::ArgumentValidation(format!( + "`executor_work` specified path doesn't exists: {}", + executor_worker.display() + ))); + } + + // Run server + let tokenizer = Tokenizer::from_pretrained( + tokenizer_name.clone(), + Some(FromPretrainedParameters { + revision: revision.clone().unwrap_or(String::from("main")), + user_agent: HashMap::new(), + auth_token, + }), + ) + .map_err(|e| TensorRtLlmBackendError::Tokenizer(e.to_string()))?; + + let backend = TensorRtLlmBackend::new(tokenizer, model_id, executor_worker)?; + server::run( + backend, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + validation_workers, + None, + tokenizer_name, + tokenizer_config_path, + revision, + hostname, + port, + cors_allow_origin, + false, + None, + None, + messages_api_enabled, + true, + max_client_batch_size, + false, + false, + ) + .await?; + Ok(()) +} diff --git a/backends/trtllm/tests/infer_test.cpp b/backends/trtllm/tests/infer_test.cpp new file mode 100644 index 00000000..8520065a --- /dev/null +++ b/backends/trtllm/tests/infer_test.cpp @@ -0,0 +1,14 @@ +// +// Created by mfuntowicz on 7/2/24. 
+// +#include +#include +#include "../include/backend.h" + +TEST_CASE("Load TRTLLM Engine on the TGI Backend", "[trtllm][engine][load]") { + const auto engines = std::filesystem::path("/home/mfuntowicz/.cache/huggingface/assets/trtllm/0.11.0.dev2024062500/meta-llama--Meta-Llama-3-8B-Instruct/4090/engines/"); + const auto executor = std::filesystem::path("/home/mfuntowicz/Workspace/text-generation-inference/backends/trtllm/cmake-build-debug/cmake-build-debug/_deps/trtllm-src/cpp/tensorrt_llm/executor_worker/executorWorker"); + + spdlog::info("Loading config from: {}", absolute(engines).string()); + huggingface::tgi::backends::TensorRtLlmBackend backend(engines, executor); +} diff --git a/backends/v2/Cargo.toml b/backends/v2/Cargo.toml new file mode 100644 index 00000000..4d32474e --- /dev/null +++ b/backends/v2/Cargo.toml @@ -0,0 +1,75 @@ +[package] +name = "text-generation-router-v2" +description = "Text Generation Webserver" +version.workspace = true +edition.workspace = true +authors.workspace = true +homepage.workspace = true + +[lib] +path = "src/lib.rs" + +[[bin]] +name = "text-generation-router-v2" +path = "src/main.rs" + +[dependencies] +async-trait = "0.1.74" +async-stream = "0.3.5" +axum = { version = "0.7", features = ["json"] } +axum-tracing-opentelemetry = "0.16" +text-generation-router = { path = "../../router" } +clap = { version = "4.4.5", features = ["derive", "env"] } +grpc-metadata = { path = "../grpc-metadata" } +futures = "0.3.28" +hf-hub = { workspace = true } +jsonschema = { version = "0.17.1", features = ["draft202012"] } +metrics = { workspace = true } +metrics-exporter-prometheus = { workspace = true } +nohash-hasher = "0.2.0" +opentelemetry = { version = "0.20.0", features = ["rt-tokio"] } +opentelemetry-otlp = "0.13.0" +rand = "0.8.5" +reqwest = { version = "0.11.20", features = [] } +serde = "1.0.188" +serde_json = "1.0.107" +slotmap = "1.0.7" +thiserror = "1.0.48" +tokenizers = { workspace = true } +tokio = { version = "1.32.0", 
features = [ + "rt", + "rt-multi-thread", + "parking_lot", + "signal", + "sync", +] } +tokio-stream = "0.1.14" +tower-http = { version = "0.5.1", features = ["cors"] } +tracing = "0.1.37" +tracing-opentelemetry = "0.21.0" +tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] } +utoipa = { version = "4.2.0", features = ["axum_extras"] } +utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } +init-tracing-opentelemetry = { version = "0.14.1", features = [ + "opentelemetry-otlp", +] } +minijinja = { workspace = true } +minijinja-contrib = { workspace = true } +futures-util = "0.3.30" +regex = "1.10.3" +once_cell = "1.19.0" +image = "0.25.1" +base64 = { workspace = true } +prost = "^0.12" +tonic = "^0.10" +tower = "^0.4" + +[build-dependencies] +tonic-build = "0.10.1" +prost-build = "0.12.1" + +[features] +default = ["ngrok"] +ngrok = ["text-generation-router/ngrok"] +google = ["text-generation-router/google"] +kserve = ["text-generation-router/kserve"] diff --git a/router/client/build.rs b/backends/v2/build.rs similarity index 76% rename from router/client/build.rs rename to backends/v2/build.rs index 497be545..f1d85dc7 100644 --- a/router/client/build.rs +++ b/backends/v2/build.rs @@ -1,16 +1,16 @@ use std::fs; fn main() -> Result<(), Box> { - println!("cargo:rerun-if-changed=../../proto/generate.proto"); - fs::create_dir("src/pb").unwrap_or(()); + println!("cargo:rerun-if-changed=../../proto/"); + fs::create_dir_all("src/client/pb").unwrap_or(()); let mut config = prost_build::Config::new(); config.protoc_arg("--experimental_allow_proto3_optional"); tonic_build::configure() .build_client(true) .build_server(false) - .out_dir("src/pb") + .out_dir("src/client/pb") .include_file("mod.rs") .compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"]) .unwrap_or_else(|e| panic!("protobuf compilation failed: {e}")); diff --git a/backends/v2/src/backend.rs b/backends/v2/src/backend.rs new file mode 100644 index 
00000000..1ab582eb --- /dev/null +++ b/backends/v2/src/backend.rs @@ -0,0 +1,517 @@ +use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient}; +/// Batching and inference logic +use crate::queue::{Entry, Queue}; +use async_trait::async_trait; +use nohash_hasher::IntMap; +use std::sync::Arc; +use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; +use text_generation_router::validation::ValidGenerateRequest; +use text_generation_router::{Attention, FinishReason, PrefillToken, Token}; +use tokio::sync::mpsc::error::SendError; +use tokio::sync::{mpsc, Notify}; +use tokio::time::Instant; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tracing::{info_span, instrument, Instrument, Span}; + +pub struct BackendV2 { + /// Request queue + queue: Queue, + /// Notify batcher on queue appends + batching_task_notifier: Arc, + /// Client clone, used for health checks to skip the queue + client: ShardedClient, +} + +impl BackendV2 { + #[allow(clippy::too_many_arguments)] + pub(crate) fn new( + client: ShardedClient, + waiting_served_ratio: f32, + max_input_tokens: u32, + max_total_tokens: u32, + max_batch_prefill_tokens: u32, + max_batch_total_tokens: u32, + max_waiting_tokens: usize, + max_batch_size: Option, + requires_padding: bool, + window_size: Option, + speculate: u32, + ) -> Self { + // Infer shared state + let attention = if let Ok(attention) = std::env::var("ATTENTION") { + attention + .parse() + .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`")) + } else { + Attention::Paged + }; + let block_size = if attention == Attention::FlashDecoding { + 256 + } else { + 16 + }; + + let queue = Queue::new( + requires_padding, + block_size, + window_size, + speculate, + max_input_tokens, + max_total_tokens, + ); + + let batching_task_notifier = Arc::new(Notify::new()); + + // Spawn batching background task that contains all the inference logic + tokio::spawn(batching_task( + 
client.clone(), + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + queue.clone(), + batching_task_notifier.clone(), + )); + + Self { + queue, + batching_task_notifier, + client, + } + } +} + +#[async_trait] +impl Backend for BackendV2 { + #[instrument(skip_all)] + fn schedule( + &self, + request: ValidGenerateRequest, + ) -> Result>, InferError> { + // MPSC channel to communicate with the background batching task + let (response_tx, response_rx) = mpsc::unbounded_channel(); + + // Append the request to the queue + self.queue.append(Entry { + request, + response_tx, + span: Span::current(), + temp_span: None, + queue_time: Instant::now(), + batch_time: None, + }); + + // Notify the background task that we have a new entry in the queue that needs + // to be batched + self.batching_task_notifier.notify_one(); + + // Return stream + Ok(UnboundedReceiverStream::new(response_rx)) + } + + async fn health(&self, current_health: bool) -> bool { + if current_health { + // Generation is healthy, we only check that the shards can allocate on device + self.client.device_health().await + } else { + self.client.model_health().await + } + .is_ok() + } +} + +/// Batching logic +/// Will be launched in a background Tokio task +/// +/// Batches requests and sends them to the inference server +#[allow(clippy::too_many_arguments)] +pub(crate) async fn batching_task( + mut client: ShardedClient, + waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, + max_batch_total_tokens: u32, + max_waiting_tokens: usize, + max_batch_size: Option, + queue: Queue, + notifier: Arc, +) { + // Infinite loop + loop { + // Wait for a notification from the Infer struct + notifier.notified().await; + + // Get the next batch from the queue + // This batch might be smaller than the maximum batch size if there are not enough requests + // waiting in the queue + while let Some((mut entries, batch, span)) = queue + .next_batch( + None, + 
max_batch_size, + max_batch_prefill_tokens, + max_batch_total_tokens, + ) + .await + { + let mut cached_batch = prefill(&mut client, batch, &mut entries) + .instrument(span) + .await; + let mut waiting_tokens = 1; + + // We loop until we do not receive any cached batch from the inference server (== until + // all requests have met their stopping criteria) + while let Some(batch) = cached_batch { + // Get current batch info + let batch_size = batch.size; + let batch_max_tokens = batch.max_tokens; + let mut batches = vec![batch]; + metrics::gauge!("tgi_batch_current_size").set(batch_size as f64); + metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64); + + let min_size = if waiting_tokens >= max_waiting_tokens { + // If we didn't onboard any new requests since >= max_waiting_tokens, we try + // to add a new batch even though its size might be small + None + } else { + // Minimum batch size + Some((batch_size as f32 * waiting_served_ratio).floor() as usize) + }; + + let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); + let max_size = + max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize)); + // Try to get a new batch + if let Some((mut new_entries, new_batch, span)) = queue + .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget) + .await + { + // Tracking metrics + if min_size.is_some() { + metrics::counter!("tgi_batch_concat", "reason" => "backpressure") + .increment(1); + } else { + metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded") + .increment(1); + } + + entries.iter_mut().for_each(|(_, entry)| { + // Create a new span to add the info that this entry is waiting + // because a new batch is being computed + let entry_waiting_span = info_span!(parent: &entry.span, "waiting"); + // Add relationships + span.follows_from(&entry_waiting_span); + entry_waiting_span.follows_from(&span); + // Update entry + entry.temp_span = Some(entry_waiting_span); + }); + + // Generate 
one token for this new batch to have the attention past in cache + let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries) + .instrument(span) + .await; + // Reset waiting counter + waiting_tokens = 1; + // Extend current batch with the new batch + if let Some(new_cached_batch) = new_cached_batch { + entries.extend(new_entries); + batches.push(new_cached_batch); + } + } + + // Create span for this batch to add context to inference calls + let next_batch_size = entries.len(); + let next_batch_span = + info_span!(parent: None, "batch", batch_size = next_batch_size); + entries.iter_mut().for_each(|(_, entry)| { + // Create a new span to link the batch back to this entry + let entry_batch_span = info_span!(parent: &entry.span, "infer"); + // Add relationships + next_batch_span.follows_from(&entry_batch_span); + entry_batch_span.follows_from(&next_batch_span); + // Update entry + entry.temp_span = Some(entry_batch_span); + }); + + cached_batch = decode(&mut client, batches, &mut entries) + .instrument(next_batch_span) + .await; + waiting_tokens += 1; + } + metrics::gauge!("tgi_batch_current_size").set(0.0); + metrics::gauge!("tgi_batch_current_max_tokens").set(0.0); + } + } +} + +#[instrument(skip_all)] +async fn prefill( + client: &mut ShardedClient, + batch: Batch, + entries: &mut IntMap, +) -> Option { + let start_time = Instant::now(); + let batch_id = batch.id; + metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1); + + match client.prefill(batch).await { + Ok((generations, next_batch, timings)) => { + let start_filtering_time = Instant::now(); + // Send generated tokens and filter stopped entries + filter_send_generations(generations, entries); + + // Filter next batch and remove requests that were stopped + let next_batch = filter_batch(client, next_batch, entries).await; + + metrics::histogram!("tgi_batch_forward_duration","method" => "prefill") + .record(timings.forward.as_secs_f64()); + 
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill") + .record(timings.decode.as_secs_f64()); + metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill") + .record(start_filtering_time.elapsed().as_secs_f64()); + metrics::histogram!("tgi_batch_inference_duration","method" => "prefill") + .record(start_time.elapsed().as_secs_f64()); + metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1); + next_batch + } + // If we have an error, we discard the whole batch + Err(err) => { + let _ = client.clear_cache(Some(batch_id)).await; + send_errors(err, entries); + metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1); + None + } + } +} + +#[instrument(skip_all)] +async fn decode( + client: &mut ShardedClient, + batches: Vec, + entries: &mut IntMap, +) -> Option { + let start_time = Instant::now(); + let batch_ids: Vec = batches.iter().map(|b| b.id).collect(); + metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1); + + match client.decode(batches).await { + Ok((generations, next_batch, timings)) => { + let start_filtering_time = Instant::now(); + // Send generated tokens and filter stopped entries + filter_send_generations(generations, entries); + + // Filter next batch and remove requests that were stopped + let next_batch = filter_batch(client, next_batch, entries).await; + + if let Some(concat_duration) = timings.concat { + metrics::histogram!("tgi_batch_concat_duration", "method" => "decode") + .record(concat_duration.as_secs_f64()); + } + metrics::histogram!("tgi_batch_forward_duration", "method" => "decode") + .record(timings.forward.as_secs_f64()); + metrics::histogram!("tgi_batch_decode_duration", "method" => "decode") + .record(timings.decode.as_secs_f64()); + metrics::histogram!("tgi_batch_filter_duration", "method" => "decode") + .record(start_filtering_time.elapsed().as_secs_f64()); + metrics::histogram!("tgi_batch_inference_duration", 
"method" => "decode") + .record(start_time.elapsed().as_secs_f64()); + metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1); + next_batch + } + // If we have an error, we discard the whole batch + Err(err) => { + for id in batch_ids { + let _ = client.clear_cache(Some(id)).await; + } + send_errors(err, entries); + metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1); + None + } + } +} + +/// Filter a `batch` and remove all requests not present in `entries` +#[instrument(skip_all)] +async fn filter_batch( + client: &mut ShardedClient, + next_batch: Option, + entries: &IntMap, +) -> Option { + let mut batch = next_batch?; + + // No need to filter + if batch.size as usize == entries.len() { + return Some(batch); + } + + let id = batch.id; + + // Retain only requests that are still in entries + batch.request_ids.retain(|id| entries.contains_key(id)); + + if batch.request_ids.is_empty() { + // All requests have been filtered out + // Next batch is now empty + // Clear it from the Python shards cache + // We unwrap here as we need to panic since we cannot recover if this method fails + client.clear_cache(Some(id)).await.unwrap(); + None + } else { + // Filter Python shard cache + // We unwrap here as we need to panic since we cannot recover if this method fails + client.filter_batch(id, batch.request_ids).await.unwrap() + } +} + +/// Send one or multiple `InferStreamResponse` to Infer for all `entries` +/// and filter entries +#[instrument(skip_all)] +fn filter_send_generations(generations: Vec, entries: &mut IntMap) { + generations.into_iter().for_each(|generation| { + let id = generation.request_id; + // Get entry + // We can `expect` here as the request id should always be in the entries + let entry = entries + .get(&id) + .expect("ID not found in entries. 
This is a bug."); + + // Create and enter a span to link this function back to the entry + let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered(); + // Send generation responses back to the infer task + // If the receive an error from the Flume channel, it means that the client dropped the + // request and we need to stop generating hence why we unwrap_or(true) + let stopped = send_responses(generation, entry).inspect_err(|_err| { + tracing::error!("Entry response channel error."); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); + }).unwrap_or(true); + if stopped { + entries.remove(&id).expect("ID not found in entries. This is a bug."); + } + }); +} + +/// Send responses through the `entry` response channel +fn send_responses( + generation: Generation, + entry: &Entry, +) -> Result>>> { + // Return directly if the channel is disconnected + if entry.response_tx.is_closed() { + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); + return Ok(true); + } + + let mut stopped = false; + + if let Some(prefill_tokens) = generation.prefill_tokens { + // Create Token objects + // We do that here instead of in the Python code as Rust for loops are faster + let prefill_tokens = prefill_tokens + .ids + .into_iter() + .zip(prefill_tokens.logprobs) + .zip(prefill_tokens.texts) + .map(|((id, logprob), text)| PrefillToken { id, text, logprob }) + .collect(); + + // Send message + entry + .response_tx + .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?; + } + + // Create last Token + let tokens_ = generation.tokens.expect("Non empty tokens in generation"); + let n = tokens_.ids.len(); + metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64); + let mut iterator = tokens_ + .ids + .into_iter() + .zip(tokens_.logprobs) + .zip(tokens_.texts) + .zip(tokens_.is_special) + .enumerate() + .peekable(); + while let 
Some((i, (((id, logprob), text), special))) = iterator.next() { + let token = Token { + id, + text, + logprob, + special, + }; + let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) { + top_tokens_ + .ids + .iter() + .zip(top_tokens_.logprobs.iter()) + .zip(top_tokens_.texts.iter()) + .zip(top_tokens_.is_special.iter()) + .map(|(((&id, &logprob), text), &special)| Token { + id, + text: text.to_string(), + logprob, + special, + }) + .collect() + } else { + vec![] + }; + match (&generation.generated_text, iterator.peek()) { + (Some(generated_text), None) => { + // Generation has ended + stopped = true; + // Send message + entry.response_tx.send(Ok(InferStreamResponse::End { + token, + top_tokens, + generated_text: GeneratedText::from(generated_text.clone()), + queued: entry.queue_time, + start: entry.batch_time.unwrap(), + }))?; + } + _ => { + // Send message + entry + .response_tx + .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?; + } + } + } + + Ok(stopped) +} + +/// Send errors to Infer for all `entries` +#[instrument(skip_all)] +fn send_errors(error: ClientError, entries: &mut IntMap) { + entries.drain().for_each(|(_, entry)| { + // Create and enter a span to link this function back to the entry + let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered(); + let err = InferError::GenerationError(error.to_string()); + metrics::counter!("tgi_request_failure", "err" => "generation").increment(1); + tracing::error!("{err}"); + + // unwrap_or is valid here as we don't care if the receiver is gone. 
+ entry + .response_tx + .send(Err(err)) + .unwrap_or(()); + }); +} + +impl From for GeneratedText { + fn from(value: crate::client::GeneratedText) -> Self { + let v2_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap(); + let finish_reason = match v2_finish_reason { + crate::client::FinishReason::Length => FinishReason::Length, + crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken, + crate::client::FinishReason::StopSequence => FinishReason::StopSequence, + }; + + Self { + text: value.text, + generated_tokens: value.generated_tokens, + finish_reason, + seed: value.seed, + } + } +} diff --git a/backends/v2/src/client/grpc_client.rs b/backends/v2/src/client/grpc_client.rs new file mode 100644 index 00000000..a56b8a54 --- /dev/null +++ b/backends/v2/src/client/grpc_client.rs @@ -0,0 +1,259 @@ +/// Single shard Client +use crate::client::pb; +use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64}; +use grpc_metadata::InjectTelemetryContext; +use pb::generate::v2::text_generation_service_client::TextGenerationServiceClient; +use pb::generate::v2::*; +use std::cmp::min; +use std::time::Duration; +use tonic::transport::{Channel, Uri}; +use tracing::instrument; + +/// Text Generation Inference gRPC client +#[derive(Debug, Clone)] +pub struct Client { + stub: TextGenerationServiceClient, +} + +impl Client { + /// Returns a client connected to the given url + #[allow(dead_code)] + pub async fn connect(uri: Uri) -> Result { + let channel = Channel::builder(uri).connect().await?; + + Ok(Self { + stub: TextGenerationServiceClient::new(channel), + }) + } + + /// Returns a client connected to the given unix socket + pub async fn connect_uds(path: String) -> Result { + let channel = Channel::from_shared("http://[::]:50051".to_string()) + .unwrap() + .connect_with_connector(tower::service_fn(move |_: Uri| { + tokio::net::UnixStream::connect(path.clone()) + })) + .await?; + + Ok(Self { + stub: 
TextGenerationServiceClient::new(channel), + }) + } + + /// Returns a list of uris or unix sockets of all shards + #[instrument(skip(self))] + pub async fn service_discovery(&mut self) -> Result> { + let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context(); + let response = self.stub.service_discovery(request).await.map_err(|_| { + ClientError::Connection("Server does not support v2 interface".to_string()) + })?; + let urls = response + .into_inner() + .urls + .into_iter() + // Remove unix socket prefix + .map(|url| match url.strip_prefix("unix://") { + None => url, + Some(stripped_url) => stripped_url.to_string(), + }) + .collect(); + Ok(urls) + } + + /// Get model info + #[instrument(skip(self))] + pub async fn info(&mut self) -> Result { + let request = tonic::Request::new(InfoRequest {}).inject_context(); + let response = self.stub.info(request).await?.into_inner(); + Ok(response) + } + + /// Get model health + #[instrument(skip(self))] + pub async fn health(&mut self) -> Result { + let request = tonic::Request::new(HealthRequest {}).inject_context(); + let response = self.stub.health(request).await?.into_inner(); + Ok(response) + } + + /// Clear the past generations cache + #[instrument(skip(self))] + pub async fn clear_cache(&mut self, batch_id: Option) -> Result<()> { + let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context(); + self.stub.clear_cache(request).await?; + Ok(()) + } + + /// Filter a cached batch + #[instrument(skip(self))] + pub async fn filter_batch( + &mut self, + batch_id: u64, + request_ids: Vec, + ) -> Result> { + let request = tonic::Request::new(FilterBatchRequest { + batch_id, + request_ids, + }) + .inject_context(); + let filtered_batch = self.stub.filter_batch(request).await?.into_inner(); + Ok(filtered_batch.batch) + } + + /// Warmup on a max size batch + /// + /// Returns the maximum amount of tokens supported by the hardware + #[instrument(skip_all)] + pub async fn warmup( + &mut 
self, + max_input_length: u32, + max_prefill_tokens: u32, + max_total_tokens: u32, + max_batch_total_tokens: u32, + max_batch_size: Option, + ) -> Result> { + let mut n_tokens = 0; + let mut requests = Vec::new(); + // Create requests + while n_tokens < max_prefill_tokens { + let truncate = min(max_input_length, max_prefill_tokens - n_tokens); + + let mut inputs = String::new(); + inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize)); + if n_tokens == 0 { + // 1 request is enough to test vision heads. + // Sending images on other queries messes up easily with truncation. + inputs.push_str(&format!( + "![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})", + )); + } + + requests.push(Request { + id: 0, + inputs, + // We truncate the input on the server side to be sure that it has the correct size + truncate, + // Set sampling parameters to also take these ops into account in the max memory + parameters: Some(NextTokenChooserParameters { + temperature: 0.9, + top_k: 10, + top_p: 0.9, + typical_p: 0.9, + do_sample: false, + seed: 0, + repetition_penalty: 1.2, + frequency_penalty: 0.1, + watermark: true, + grammar: String::new(), + grammar_type: GrammarType::None as i32, + }), + stopping_parameters: Some(StoppingCriteriaParameters { + max_new_tokens: max_total_tokens - truncate, + stop_sequences: vec![], + ignore_eos_token: true, + }), + prefill_logprobs: true, + top_n_tokens: 20, + }); + n_tokens += max_input_length; + + // Check max_batch_size + if Some(requests.len()) == max_batch_size { + break; + } + } + + let batch = Batch { + id: 0, + size: requests.len() as u32, + requests, + max_tokens: 0, + }; + + let request = tonic::Request::new(WarmupRequest { + batch: Some(batch), + max_input_length, + max_prefill_tokens, + max_total_tokens, + max_batch_total_tokens, + }) + .inject_context(); + let response = self.stub.warmup(request).await?.into_inner(); + Ok(response.max_supported_total_tokens) + } + + /// Generate one token for each request in the given 
batch + /// + /// Returns Generation for each request in batch + /// and the next cached batch + #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))] + pub async fn prefill( + &mut self, + batch: Batch, + ) -> Result<(Vec, Option, PrefillTimings)> { + let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context(); + let response = self.stub.prefill(request).await?.into_inner(); + Ok(( + response.generations, + response.batch, + PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns), + )) + } + + /// Generate one token for each request in the given cached batches + /// + /// Returns Generation for each request in batches + /// and the next cached batch + #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::()))] + pub async fn decode( + &mut self, + batches: Vec, + ) -> Result<(Vec, Option, DecodeTimings)> { + let request = tonic::Request::new(DecodeRequest { batches }).inject_context(); + let response = self.stub.decode(request).await?.into_inner(); + Ok(( + response.generations, + response.batch, + DecodeTimings::new( + response.concat_ns, + response.forward_ns, + response.decode_ns, + response.total_ns, + ), + )) + } +} + +pub struct PrefillTimings { + pub forward: Duration, + pub decode: Duration, + pub total: Duration, +} + +impl PrefillTimings { + fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self { + Self { + forward: Duration::from_nanos(forward_ns), + decode: Duration::from_nanos(decode_ns), + total: Duration::from_nanos(total_ns), + } + } +} + +pub struct DecodeTimings { + pub concat: Option, + pub forward: Duration, + pub decode: Duration, + pub total: Duration, +} + +impl DecodeTimings { + fn new(concat_ns: Option, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self { + Self { + concat: concat_ns.map(Duration::from_nanos), + forward: Duration::from_nanos(forward_ns), + decode: Duration::from_nanos(decode_ns), + total: 
Duration::from_nanos(total_ns), + } + } +} diff --git a/backends/v2/src/client/mod.rs b/backends/v2/src/client/mod.rs new file mode 100644 index 00000000..fa9d4406 --- /dev/null +++ b/backends/v2/src/client/mod.rs @@ -0,0 +1,68 @@ +//! Text Generation gRPC client library + +use async_trait::async_trait; +use thiserror::Error; +use tonic::transport; +use tonic::Status; + +#[allow(clippy::derive_partial_eq_without_eq)] +mod pb; + +mod grpc_client; +mod sharded_client; + +pub use grpc_client::Client; +pub use pb::generate::v2::{ + Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, HealthResponse, + InfoResponse, NextTokenChooserParameters, Request, StoppingCriteriaParameters, +}; +pub use sharded_client::ShardedClient; + +#[async_trait] +pub trait Health { + /// Check if a generate server is healthy by asking it to allocate a tensor on device + async fn device_health(&self) -> Result<()>; + + /// Check if a generate server is healthy by doing a forward pass. + /// EXPENSIVE + async fn model_health(&self) -> Result<()>; +} + +#[derive(Debug)] +pub struct ShardInfo { + pub requires_padding: bool, + pub dtype: String, + pub device_type: String, + pub window_size: Option, + pub speculate: u32, +} + +#[derive(Error, Debug, Clone)] +pub enum ClientError { + #[error("Could not connect to Text Generation server: {0}")] + Connection(String), + #[error("Server error: {0}")] + Generation(String), + #[error("Sharded results are empty")] + EmptyResults, +} + +impl From for ClientError { + fn from(err: Status) -> Self { + let err = Self::Generation(err.message().to_string()); + tracing::error!("{err}"); + err + } +} + +impl From for ClientError { + fn from(err: transport::Error) -> Self { + let err = Self::Connection(err.to_string()); + tracing::error!("{err}"); + err + } +} + +static WARMUP_IMAGE_BASE64 :&str = 
"iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII="; + +pub type Result = std::result::Result; diff --git a/backends/v2/src/client/sharded_client.rs b/backends/v2/src/client/sharded_client.rs new file mode 100644 index 00000000..238bd773 --- /dev/null +++ b/backends/v2/src/client/sharded_client.rs @@ -0,0 +1,254 @@ +/// Multi shard Client +use crate::client::{ClientError, Result}; +use crate::client::{Health, ShardInfo}; + +use crate::client::grpc_client::{DecodeTimings, PrefillTimings}; +use crate::client::InfoResponse; +use crate::client::{ + Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse, + NextTokenChooserParameters, Request, StoppingCriteriaParameters, +}; +use async_trait::async_trait; +use futures::future::join_all; +use tonic::transport::Uri; +use tracing::instrument; + +#[derive(Debug, Clone)] +/// Text Generation Inference gRPC multi client +pub struct ShardedClient { + clients: Vec, +} + +impl ShardedClient { + fn new(clients: Vec) -> Self { + Self { clients } + } + + /// Create a new ShardedClient from a master client. The master client will communicate with + /// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method. 
+ async fn from_master_client(mut master_client: Client) -> Result { + // Get all uris/unix sockets from the master client + let uris = master_client.service_discovery().await?; + let futures = uris.into_iter().map(Client::connect_uds); + let clients: Result> = join_all(futures).await.into_iter().collect(); + Ok(Self::new(clients?)) + } + + /// Returns a client connected to the given uri + #[allow(dead_code)] + pub async fn connect(uri: Uri) -> Result { + let master_client = Client::connect(uri).await?; + Self::from_master_client(master_client).await + } + + /// Returns a client connected to the given unix socket + pub async fn connect_uds(path: String) -> Result { + let master_client = Client::connect_uds(path).await?; + Self::from_master_client(master_client).await + } + + /// Get the model info (returns `ClientError::EmptyResults` when there are no shards) + #[instrument(skip(self))] + pub async fn info(&mut self) -> Result { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| client.info()) + .collect(); + join_all(futures).await.pop().ok_or(ClientError::EmptyResults)?.map(ShardInfo::from) + } + + /// GRPC health check (returns `ClientError::EmptyResults` when there are no shards) + #[instrument(skip(self))] + pub async fn health(&mut self) -> Result { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| client.health()) + .collect(); + join_all(futures).await.pop().ok_or(ClientError::EmptyResults)? + } + + /// Clear the past generations cache + #[instrument(skip(self))] + pub async fn clear_cache(&mut self, batch_id: Option) -> Result<()> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| client.clear_cache(batch_id)) + .collect(); + join_all(futures).await.into_iter().collect() + } + + /// Filter a cached batch (returns `ClientError::EmptyResults` when there are no shards) + #[instrument(skip(self))] + pub async fn filter_batch( + &mut self, + batch_id: u64, + request_ids: Vec, + ) -> Result> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone()))) + .collect(); + // all shards return the same message, so the last shard's answer is used + join_all(futures).await.pop().ok_or(ClientError::EmptyResults)? + } + + /// Warmup on 
a max size batch + /// + /// Returns the maximum amount of tokens supported by the hardware + #[instrument(skip(self))] + pub async fn warmup( + &mut self, + max_input_length: u32, + max_prefill_tokens: u32, + max_total_tokens: u32, + max_batch_total_tokens: u32, + max_batch_size: Option, + ) -> Result> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| { + Box::pin(client.warmup( + max_input_length, + max_prefill_tokens, + max_total_tokens, + max_batch_total_tokens, + max_batch_size, + )) + }) + .collect(); + // Take the minimum value + let results = join_all(futures) + .await + .into_iter() + .collect::>>>()?; + Ok(results.into_iter().flatten().min()) + } + + /// Generate one token for each request in the given batch + /// + /// Returns Generation for each request in batch + /// and the next cached batch + #[instrument(skip_all, fields(id = & batch.id, size = & batch.size))] + pub async fn prefill( + &mut self, + batch: Batch, + ) -> Result<(Vec, Option, PrefillTimings)> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.prefill(batch.clone()))) + .collect(); + #[allow(clippy::type_complexity)] + let results: Result, Option, PrefillTimings)>> = + join_all(futures).await.into_iter().collect(); + let mut results = results?; + + let (mut generations, next_batch, mut timings) = + results.pop().ok_or(ClientError::EmptyResults)?; + + // Merge generations from different model shards + for (mut shard_generations, _, shard_timings) in results.into_iter() { + generations.append(&mut shard_generations); + // Return the timings of the slowest shard + if shard_timings.total > timings.total { + timings = shard_timings; + } + } + Ok((generations, next_batch, timings)) + } + + /// Generate one token for each request in the given cached batches + /// + /// Returns Generation for each request in batches + /// and the next cached batch + #[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< 
u32 > ()))] + pub async fn decode( + &mut self, + batches: Vec, + ) -> Result<(Vec, Option, DecodeTimings)> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.decode(batches.clone()))) + .collect(); + #[allow(clippy::type_complexity)] + let results: Result, Option, DecodeTimings)>> = + join_all(futures).await.into_iter().collect(); + let mut results = results?; + + let (mut generations, next_batch, mut timings) = + results.pop().ok_or(ClientError::EmptyResults)?; + + // Merge generations from different model shards + for (mut shard_generations, _, shard_timings) in results.into_iter() { + generations.append(&mut shard_generations); + // Return the timings of the slowest shard + if shard_timings.total > timings.total { + timings = shard_timings; + } + } + Ok((generations, next_batch, timings)) + } +} + +impl From for ShardInfo { + fn from(value: InfoResponse) -> Self { + Self { + requires_padding: value.requires_padding, + dtype: value.dtype, + device_type: value.device_type, + window_size: value.window_size, + speculate: value.speculate, + } + } +} + +#[async_trait] +impl Health for ShardedClient { + async fn device_health(&self) -> Result<()> { + self.clone().health().await?; + Ok(()) + } + + async fn model_health(&self) -> Result<()> { + // Dummy batch of 1 token and 1 generated token + let liveness_request = Request { + id: u64::MAX, + inputs: "liveness".to_string(), + truncate: 10, + prefill_logprobs: false, + parameters: Some(NextTokenChooserParameters { + temperature: 1.0, + top_k: 0, + top_p: 1.0, + typical_p: 1.0, + do_sample: false, + seed: 0, + repetition_penalty: 1.0, + frequency_penalty: 0.0, + watermark: false, + grammar: String::new(), + grammar_type: GrammarType::None as i32, + }), + stopping_parameters: Some(StoppingCriteriaParameters { + max_new_tokens: 1, + stop_sequences: vec![], + ignore_eos_token: false, + }), + top_n_tokens: 0, + }; + let batch = Batch { + id: u64::MAX, + requests: vec![liveness_request], + 
size: 1, + max_tokens: 2, + }; + self.clone().prefill(batch).await?; + Ok(()) + } +} diff --git a/backends/v2/src/lib.rs b/backends/v2/src/lib.rs new file mode 100644 index 00000000..90f03230 --- /dev/null +++ b/backends/v2/src/lib.rs @@ -0,0 +1,144 @@ +mod backend; +mod client; +mod queue; + +use crate::client::{ClientError, ShardedClient}; +pub(crate) use backend::BackendV2; +use serde::Serialize; +use thiserror::Error; +use utoipa::ToSchema; + +#[derive(Clone, Debug, Serialize, ToSchema)] +pub struct BackendInfo { + /// Mandatory + #[schema(example = "cuda")] + pub model_device_type: String, + #[schema(example = "torch.float16")] + pub model_dtype: String, + + /// Backend parameters + #[schema(example = "1")] + pub speculate: usize, + #[schema(example = "1.2")] + pub waiting_served_ratio: f32, + #[schema(example = "32000")] + pub max_batch_total_tokens: u32, + #[schema(example = "20")] + pub max_waiting_tokens: usize, + #[schema(nullable = true, example = "null")] + pub max_batch_size: Option, +} + +#[allow(clippy::too_many_arguments)] +pub async fn connect_backend( + max_input_tokens: usize, + max_total_tokens: usize, + master_shard_uds_path: String, + waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, + max_batch_total_tokens: Option, + max_waiting_tokens: usize, + max_batch_size: Option, +) -> Result<(BackendV2, BackendInfo), V2Error> { + // Helper function + let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option| { + match max_supported_batch_total_tokens { + // Older models do not support automatic max-batch-total-tokens + None => { + let max_batch_total_tokens = max_batch_total_tokens + .unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))); + tracing::warn!("Model does not support automatic max batch total tokens"); + Ok(max_batch_total_tokens) + } + // Flash attention models return their max supported total tokens + Some(max_supported_batch_total_tokens) => { + // Warn if user added his own 
max-batch-total-tokens as we will ignore it + if max_batch_total_tokens.is_some() { + tracing::warn!( + "`--max-batch-total-tokens` is deprecated for Flash \ + Attention models." + ); + tracing::warn!( + "Inferred max batch total tokens: {max_supported_batch_total_tokens}" + ); + } + if max_total_tokens as u32 > max_supported_batch_total_tokens { + return Err(V2Error::NotEnoughMemory(max_total_tokens)); + } + + Ok(max_supported_batch_total_tokens) + } + } + }; + + let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path) + .await + .map_err(V2Error::Connection)?; + + // server is running on v2 + // Clear the cache; useful if the webserver rebooted + sharded_client + .clear_cache(None) + .await + .map_err(V2Error::Cache)?; + // Get info from the shard + let shard_info = sharded_client.info().await.map_err(V2Error::Info)?; + + // Warmup model + tracing::info!("Warming up model"); + let max_batch_total_tokens = check_max_batch_total_tokens( + sharded_client + .warmup( + max_input_tokens as u32, + max_batch_prefill_tokens, + max_total_tokens as u32, + max_batch_total_tokens.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))), + max_batch_size, + ) + .await + .map_err(V2Error::Warmup)?, + )?; + tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}"); + + let backend_info = BackendInfo { + waiting_served_ratio, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + model_device_type: shard_info.device_type.clone(), + model_dtype: shard_info.dtype.clone(), + speculate: shard_info.speculate as usize, + }; + + let backend = BackendV2::new( + sharded_client, + waiting_served_ratio, + max_input_tokens as u32, + max_total_tokens as u32, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + shard_info.requires_padding, + shard_info.window_size, + shard_info.speculate, + ); + + tracing::info!("Using backend V2"); + + Ok((backend, backend_info)) +} + 
+#[derive(Debug, Error)] +pub enum V2Error { + #[error("Unable to clear the Python model shards cache: {0}")] + Cache(ClientError), + #[error("Unable to connect to the Python model shards: {0}")] + Connection(ClientError), + #[error("Unable to get the Python model shards info: {0}")] + Info(ClientError), + #[error("Unable to warmup the Python model shards: {0}")] + Warmup(ClientError), + #[error("Not enough memory to handle `max_total_tokens={0}`")] + NotEnoughMemory(usize), +} diff --git a/backends/v2/src/main.rs b/backends/v2/src/main.rs new file mode 100644 index 00000000..f53d898e --- /dev/null +++ b/backends/v2/src/main.rs @@ -0,0 +1,212 @@ +use clap::{Parser, Subcommand}; +use text_generation_router::{server, usage_stats}; +use text_generation_router_v2::{connect_backend, V2Error}; +use thiserror::Error; + +/// App Configuration +#[derive(Parser, Debug)] +#[clap(author, version, about, long_about = None)] +struct Args { + #[command(subcommand)] + command: Option, + + #[clap(default_value = "128", long, env)] + max_concurrent_requests: usize, + #[clap(default_value = "2", long, env)] + max_best_of: usize, + #[clap(default_value = "4", long, env)] + max_stop_sequences: usize, + #[clap(default_value = "5", long, env)] + max_top_n_tokens: u32, + #[clap(default_value = "1024", long, env)] + max_input_tokens: usize, + #[clap(default_value = "2048", long, env)] + max_total_tokens: usize, + #[clap(default_value = "1.2", long, env)] + waiting_served_ratio: f32, + #[clap(default_value = "4096", long, env)] + max_batch_prefill_tokens: u32, + #[clap(long, env)] + max_batch_total_tokens: Option, + #[clap(default_value = "20", long, env)] + max_waiting_tokens: usize, + #[clap(long, env)] + max_batch_size: Option, + #[clap(default_value = "0.0.0.0", long, env)] + hostname: String, + #[clap(default_value = "3000", long, short, env)] + port: u16, + #[clap(default_value = "/tmp/text-generation-server-0", long, env)] + master_shard_uds_path: String, + #[clap(default_value = 
"bigscience/bloom", long, env)] + tokenizer_name: String, + #[clap(long, env)] + tokenizer_config_path: Option, + #[clap(long, env)] + revision: Option, + #[clap(default_value = "2", long, env)] + validation_workers: usize, + #[clap(long, env)] + api_key: Option, + #[clap(long, env)] + json_output: bool, + #[clap(long, env)] + otlp_endpoint: Option, + #[clap(default_value = "text-generation-inference.router", long, env)] + otlp_service_name: String, + #[clap(long, env)] + cors_allow_origin: Option>, + #[clap(long, env)] + ngrok: bool, + #[clap(long, env)] + ngrok_authtoken: Option, + #[clap(long, env)] + ngrok_edge: Option, + #[clap(long, env, default_value_t = false)] + messages_api_enabled: bool, + #[clap(long, env, default_value_t = false)] + disable_grammar_support: bool, + #[clap(default_value = "4", long, env)] + max_client_batch_size: usize, + #[clap(default_value = "on", long, env)] + usage_stats: usage_stats::UsageStatsLevel, +} + +#[derive(Debug, Subcommand)] +enum Commands { + PrintSchema, +} + +#[tokio::main] +async fn main() -> Result<(), RouterError> { + // Get args + let args = Args::parse(); + // Pattern match configuration + let Args { + command, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + hostname, + port, + master_shard_uds_path, + tokenizer_name, + tokenizer_config_path, + revision, + validation_workers, + api_key, + json_output, + otlp_endpoint, + otlp_service_name, + cors_allow_origin, + ngrok, + ngrok_authtoken, + ngrok_edge, + messages_api_enabled, + disable_grammar_support, + max_client_batch_size, + usage_stats, + } = args; + + if let Some(Commands::PrintSchema) = command { + use utoipa::OpenApi; + let api_doc = text_generation_router::server::ApiDoc::openapi(); + let api_doc = serde_json::to_string_pretty(&api_doc).unwrap(); + println!("{}", 
api_doc); + std::process::exit(0); + }; + text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output); + + // Validate args + if max_input_tokens >= max_total_tokens { + return Err(RouterError::ArgumentValidation( + "`max_input_tokens` must be < `max_total_tokens`".to_string(), + )); + } + if max_input_tokens as u32 > max_batch_prefill_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}"))); + } + + if validation_workers == 0 { + return Err(RouterError::ArgumentValidation( + "`validation_workers` must be > 0".to_string(), + )); + } + + if let Some(ref max_batch_total_tokens) = max_batch_total_tokens { + if max_batch_prefill_tokens > *max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); + } + if max_total_tokens as u32 > *max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. 
Given: {max_total_tokens} and {max_batch_total_tokens}"))); + } + } + + if let Some(max_batch_size) = max_batch_size { + if max_batch_size == 0 { + return Err(RouterError::ArgumentValidation( + "`max_batch_size` must be > 0".to_string(), + )); + } + } + + let (backend, _backend_info) = connect_backend( + max_input_tokens, + max_total_tokens, + master_shard_uds_path, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + ) + .await?; + + // Run server + server::run( + backend, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + validation_workers, + api_key, + tokenizer_name, + tokenizer_config_path, + revision, + hostname, + port, + cors_allow_origin, + ngrok, + ngrok_authtoken, + ngrok_edge, + messages_api_enabled, + disable_grammar_support, + max_client_batch_size, + usage_stats, + ) + .await?; + Ok(()) +} + +#[derive(Debug, Error)] +enum RouterError { + #[error("Argument validation error: {0}")] + ArgumentValidation(String), + #[error("Backend failed: {0}")] + Backend(#[from] V2Error), + #[error("WebServer error: {0}")] + WebServer(#[from] server::WebServerError), + #[error("Tokio runtime failed to start: {0}")] + Tokio(#[from] std::io::Error), +} diff --git a/router/src/queue.rs b/backends/v2/src/queue.rs similarity index 76% rename from router/src/queue.rs rename to backends/v2/src/queue.rs index 11690bf7..5f793d09 100644 --- a/router/src/queue.rs +++ b/backends/v2/src/queue.rs @@ -1,15 +1,14 @@ -/// Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. 
- -use crate::infer::InferError; -use crate::infer::InferStreamResponse; -use crate::validation::ValidGenerateRequest; +use crate::client::{ + Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters, +}; use nohash_hasher::{BuildNoHashHasher, IntMap}; use std::cmp::min; -use std::cmp::{Eq, Ord, PartialEq, PartialOrd}; -use std::collections::BinaryHeap; -use std::env; -use std::time::Duration; -use text_generation_client::{Batch, Request}; +use std::collections::VecDeque; +use text_generation_router::infer::InferError; +use text_generation_router::infer::InferStreamResponse; +use text_generation_router::validation::{ + ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters, +}; use tokio::sync::{mpsc, oneshot}; use tokio::time::Instant; use tracing::{info_span, instrument, Span}; @@ -41,11 +40,11 @@ pub(crate) struct Queue { impl Queue { pub(crate) fn new( requires_padding: bool, - max_input_length: u32, - max_total_tokens: u32, block_size: u32, window_size: Option, speculate: u32, + max_input_tokens: u32, + max_total_tokens: u32, ) -> Self { // Create channel let (queue_sender, queue_receiver) = mpsc::unbounded_channel(); @@ -53,18 +52,17 @@ impl Queue { // Launch background queue task tokio::spawn(queue_task( requires_padding, - max_input_length, - max_total_tokens, block_size, window_size, speculate, + max_input_tokens, + max_total_tokens, queue_receiver, )); Self { queue_sender } } - /// Append an entry to the queue #[instrument(skip_all)] pub(crate) fn append(&self, entry: Entry) { // Send append command to the background task managing the state @@ -106,27 +104,27 @@ impl Queue { // Background task responsible of the queue state async fn queue_task( requires_padding: bool, - max_input_length: u32, - max_total_tokens: u32, block_size: u32, window_size: Option, speculate: u32, + max_input_tokens: u32, + max_total_tokens: u32, mut receiver: mpsc::UnboundedReceiver, ) { let mut state = State::new( 
requires_padding, - max_input_length, - max_total_tokens, block_size, window_size, - speculate + speculate, + max_input_tokens, + max_total_tokens, ); while let Some(cmd) = receiver.recv().await { match cmd { QueueCommand::Append(entry, span) => { span.in_scope(|| state.append(*entry)); - metrics::increment_gauge!("tgi_queue_size", 1.0); + metrics::gauge!("tgi_queue_size").increment(1.0); } QueueCommand::NextBatch { min_size, @@ -139,110 +137,17 @@ async fn queue_task( let next_batch = state.next_batch(min_size, max_size, prefill_token_budget, token_budget); response_sender.send(next_batch).unwrap(); - metrics::gauge!("tgi_queue_size", state.entries.len() as f64); + metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64); }), } } } -#[derive(Debug)] -struct IdentifiableEntry(u64, Entry); - -impl Eq for IdentifiableEntry {} - -impl PartialEq for IdentifiableEntry { - fn eq(&self, other: &Self) -> bool { - self.0 == other.0 - } -} - -impl Ord for IdentifiableEntry { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - let ordering = match self - .1 - .request - .input_length - .cmp(&other.1.request.input_length) - { - std::cmp::Ordering::Equal => self.0.cmp(&other.0), - any => any, - }; - - // inverse to get min heap - return ordering.reverse(); - } -} - -impl PartialOrd for IdentifiableEntry { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -#[derive(Debug)] -struct QueueImpl { - regular_entries: BinaryHeap, - overdue_entries: BinaryHeap, - overdue_threshold: Duration, -} - -impl QueueImpl { - fn new(capacity: usize, overdue_threshold: Duration) -> Self { - Self { - regular_entries: BinaryHeap::with_capacity(capacity), - overdue_entries: BinaryHeap::with_capacity(capacity), - overdue_threshold, - } - } - - fn update(&mut self) { - if self.regular_entries.is_empty() { - return; - } - - let mut left = BinaryHeap::with_capacity(self.regular_entries.capacity()); - - for entry in self.regular_entries.drain() { - if 
entry.1.queue_time.elapsed() > self.overdue_threshold { - self.overdue_entries.push(entry); - } else { - left.push(entry); - } - } - - self.regular_entries = left; - } - - fn push(&mut self, entry: IdentifiableEntry) { - if entry.1.queue_time.elapsed() > self.overdue_threshold { - self.overdue_entries.push(entry); - } else { - self.regular_entries.push(entry); - } - } - - fn pop(&mut self) -> Option { - if !self.overdue_entries.is_empty() { - self.overdue_entries.pop() - } else { - self.regular_entries.pop() - } - } - - fn is_empty(&self) -> bool { - self.regular_entries.is_empty() && self.overdue_entries.is_empty() - } - - fn len(&self) -> usize { - self.regular_entries.len() + self.overdue_entries.len() - } -} - /// Queue State #[derive(Debug)] struct State { - /// Queue entries - entries: QueueImpl, + /// Queue entries organized in a Vec + entries: VecDeque<(u64, Entry)>, /// Id of the next entry next_id: u64, @@ -253,12 +158,6 @@ struct State { /// Whether the model is using padding requires_padding: bool, - /// Maximum input length, required for padding scenario - max_input_length: u32, - - /// Maximum input and output length, required for padding scenario - max_total_tokens: u32, - /// Paged Attention block size block_size: u32, @@ -267,33 +166,33 @@ struct State { /// Speculation amount speculate: u32, + + /// max input tokens + max_input_tokens: u32, + + /// max total tokens, + max_total_tokens: u32, } impl State { fn new( requires_padding: bool, - max_input_length: u32, - max_total_tokens: u32, block_size: u32, window_size: Option, speculate: u32, + max_input_tokens: u32, + max_total_tokens: u32, ) -> Self { - let default_threshold: u64 = 120; - let threshold: u64 = match env::var("QUEUE_THRESHOLD_MS") { - Ok(val) => val.parse().unwrap_or(default_threshold), - Err(_) => default_threshold, - }; - Self { - entries: QueueImpl::new(128, Duration::from_millis(threshold)), + entries: VecDeque::with_capacity(128), next_id: 0, next_batch_id: 0, requires_padding, - 
max_input_length, - max_total_tokens, block_size, window_size, speculate, + max_input_tokens, + max_total_tokens, } } @@ -304,7 +203,7 @@ impl State { entry.temp_span = Some(queue_span); // Push entry in the queue - self.entries.push(IdentifiableEntry(self.next_id, entry)); + self.entries.push_back((self.next_id, entry)); self.next_id += 1; } @@ -329,11 +228,20 @@ impl State { } } - self.entries.update(); + if let Some(max_size) = max_size { + if max_size == 0 { + tracing::debug!("No capacity"); + return None; + } + } + + // Pad prefill_token_budget to be a multiple of block size + let prefill_token_budget = + ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size; // Create span for this batch to add context to inference calls let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty); - next_batch_span.follows_from(&Span::current()); + next_batch_span.follows_from(Span::current()); let mut batch_requests = Vec::with_capacity(self.entries.len()); let mut batch_entries = @@ -343,11 +251,11 @@ impl State { let mut decode_tokens: u32 = 0; // Pop entries starting from the front of the queue - while let Some(IdentifiableEntry(id, mut entry)) = self.entries.pop() { + while let Some((id, mut entry)) = self.entries.pop_front() { // Filter entries where the response receiver was dropped (== entries where the request // was dropped by the client) if entry.response_tx.is_closed() { - metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); tracing::debug!("Dropping entry"); continue; } @@ -355,7 +263,7 @@ impl State { if self.requires_padding { // We pad to max input length in the Python shards // We need to take these padding tokens into the equation - prefill_tokens = (batch_requests.len() + 1) as u32 * self.max_input_length; + prefill_tokens = (batch_requests.len() + 1) as u32 * self.max_input_tokens } else { // pad to 
block size prefill_tokens += ((entry.request.input_length + self.block_size - 1) @@ -364,9 +272,7 @@ impl State { } if self.requires_padding { - // We pad to max total tokens in the Python shards - // We need to take these padding tokens into the equation - decode_tokens = (batch_requests.len() + 1) as u32 * (self.max_total_tokens - self.max_input_length); + decode_tokens = (batch_requests.len() + 1) as u32 * (self.max_total_tokens - self.max_input_tokens); } else { let max_new_tokens = match self.window_size { None => entry.request.stopping_parameters.max_new_tokens, @@ -387,7 +293,7 @@ impl State { // Entry is over budget // Add it back to the front tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate); - self.entries.push(IdentifiableEntry(id, entry)); + self.entries.push_front((id, entry)); break; } @@ -403,10 +309,14 @@ impl State { batch_requests.push(Request { id, prefill_logprobs: entry.request.decoder_input_details, - inputs: entry.request.inputs.clone(), + inputs: entry.request.inputs.chunks_to_string(), truncate: entry.request.truncate, - parameters: Some(entry.request.parameters.clone()), - stopping_parameters: Some(entry.request.stopping_parameters.clone()), + parameters: Some(NextTokenChooserParameters::from( + entry.request.parameters.clone(), + )), + stopping_parameters: Some(StoppingCriteriaParameters::from( + entry.request.stopping_parameters.clone(), + )), top_n_tokens: entry.request.top_n_tokens, }); // Set batch_time @@ -422,7 +332,7 @@ impl State { // Empty batch if batch_requests.is_empty() { - tracing::debug!("Filterered out all entries"); + tracing::debug!("Filtered out all entries"); return None; } @@ -434,7 +344,7 @@ impl State { for r in batch_requests.into_iter().rev() { let id = r.id; let entry = batch_entries.remove(&id).unwrap(); - self.entries.push(IdentifiableEntry(id, entry)); + self.entries.push_front((id, entry)); } 
return None; @@ -454,7 +364,7 @@ impl State { // Increment batch id self.next_batch_id += 1; - metrics::histogram!("tgi_batch_next_size", batch.size as f64); + metrics::histogram!("tgi_batch_next_size").record(batch.size as f64); Some((batch_entries, batch, next_batch_span)) } @@ -475,26 +385,49 @@ enum QueueCommand { }, } +impl From for NextTokenChooserParameters { + fn from(value: ValidParameters) -> Self { + let (grammar, grammar_type) = match value.grammar { + None => (String::new(), GrammarType::None), + + Some(grammar) => match grammar { + ValidGrammar::Json(grammar_string) => (grammar_string, GrammarType::Json), + ValidGrammar::Regex(grammar_string) => (grammar_string, GrammarType::Regex), + }, + }; + + Self { + temperature: value.temperature, + top_k: value.top_k, + top_p: value.top_p, + typical_p: value.typical_p, + do_sample: value.do_sample, + seed: value.seed, + repetition_penalty: value.repetition_penalty, + frequency_penalty: value.frequency_penalty, + watermark: value.watermark, + grammar, + grammar_type: grammar_type.into(), + } + } +} + +impl From for StoppingCriteriaParameters { + fn from(value: ValidStoppingParameters) -> Self { + Self { + max_new_tokens: value.max_new_tokens, + stop_sequences: value.stop_sequences, + ignore_eos_token: value.ignore_eos_token, + } + } +} + #[cfg(test)] mod tests { use super::*; - use text_generation_client::{ - GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters, - }; + use std::sync::Arc; use tracing::info_span; - fn default_queue() -> Queue { - Queue::new( - true, 1, 2, 1, None, 0 - ) - } - - fn default_state() -> State { - State::new( - true, 1, 2, 1, None, 0 - ) - } - fn default_entry() -> ( Entry, mpsc::UnboundedReceiver>, @@ -503,11 +436,13 @@ mod tests { let entry = Entry { request: ValidGenerateRequest { - inputs: String::new(), + inputs: vec![], + input_ids: Some(Arc::new(vec![])), input_length: 0, + add_special_tokens: true, truncate: 0, decoder_input_details: false, - 
parameters: NextTokenChooserParameters { + parameters: ValidParameters { temperature: 0.0, top_k: 0, top_p: 0.0, @@ -517,15 +452,15 @@ mod tests { repetition_penalty: 0.0, frequency_penalty: 0.0, watermark: false, - grammar: String::new(), - grammar_type: ProtoGrammarType::None as i32, + grammar: None, }, - stopping_parameters: StoppingCriteriaParameters { + stopping_parameters: ValidStoppingParameters { ignore_eos_token: false, max_new_tokens: 1, stop_sequences: vec![], }, top_n_tokens: 0, + adapter_id: None, }, response_tx, span: info_span!("entry"), @@ -538,7 +473,7 @@ mod tests { #[test] fn test_append() { - let mut state = default_state(); + let mut state = State::new(false, 1, None, 0); let (entry, _guard) = default_entry(); assert_eq!(state.next_id, 0); @@ -548,13 +483,13 @@ mod tests { assert_eq!(state.next_id, 1); assert_eq!(state.entries.len(), 1); - let id = state.entries.pop().unwrap().0; + let (id, _) = state.entries.remove(0).unwrap(); assert_eq!(id, 0); } #[test] fn test_next_batch_empty() { - let mut state = default_state(); + let mut state = State::new(false, 1, None, 0); assert!(state.next_batch(None, None, 1, 1).is_none()); assert!(state.next_batch(Some(1), None, 1, 1).is_none()); @@ -562,13 +497,13 @@ mod tests { #[test] fn test_next_batch_min_size() { - let mut state = default_state(); + let mut state = State::new(false, 1, None, 0); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); state.append(entry2); - let (entries, batch, _) = state.next_batch(None, None, 2, 4).unwrap(); + let (entries, batch, _) = state.next_batch(None, None, 2, 2).unwrap(); assert_eq!(entries.len(), 2); assert!(entries.contains_key(&0)); assert!(entries.contains_key(&1)); @@ -588,13 +523,13 @@ mod tests { assert_eq!(state.next_id, 3); assert_eq!(state.entries.len(), 1); - let IdentifiableEntry(id, _) = state.entries.pop().unwrap(); + let (id, _) = state.entries.remove(0).unwrap(); assert_eq!(id, 2); } #[test] fn 
test_next_batch_max_size() { - let mut state = default_state(); + let mut state = State::new(false, 1, None, 0); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -614,13 +549,13 @@ mod tests { #[test] fn test_next_batch_token_budget() { - let mut state = default_state(); + let mut state = State::new(false, 1, None, 0); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); state.append(entry2); - let (entries, batch, _) = state.next_batch(None, None, 1, 2).unwrap(); + let (entries, batch, _) = state.next_batch(None, None, 1, 1).unwrap(); assert_eq!(entries.len(), 1); assert!(entries.contains_key(&0)); assert_eq!(batch.id, 0); @@ -633,7 +568,7 @@ mod tests { let (entry3, _guard3) = default_entry(); state.append(entry3); - let (entries, batch, _) = state.next_batch(None, None, 3, 6).unwrap(); + let (entries, batch, _) = state.next_batch(None, None, 3, 3).unwrap(); assert_eq!(entries.len(), 2); assert!(entries.contains_key(&1)); assert!(entries.contains_key(&2)); @@ -647,14 +582,14 @@ mod tests { #[tokio::test] async fn test_queue_append() { - let queue = default_queue(); + let queue = Queue::new(false, 1, None, 0); let (entry, _guard) = default_entry(); queue.append(entry); } #[tokio::test] async fn test_queue_next_batch_empty() { - let queue = default_queue(); + let queue = Queue::new(false, 1, None, 0); assert!(queue.next_batch(None, None, 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none()); @@ -662,13 +597,13 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_min_size() { - let queue = default_queue(); + let queue = Queue::new(false, 1, None, 0); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); queue.append(entry2); - let (entries, batch, _) = queue.next_batch(None, None, 2, 4).await.unwrap(); + let (entries, batch, _) = queue.next_batch(None, None, 2, 
2).await.unwrap(); assert_eq!(entries.len(), 2); assert!(entries.contains_key(&0)); assert!(entries.contains_key(&1)); @@ -685,7 +620,7 @@ mod tests { // Not enough token budget assert!(queue.next_batch(Some(1), None, 0, 0).await.is_none()); // Ok - let (entries2, batch2, _) = queue.next_batch(Some(1), None, 2, 4).await.unwrap(); + let (entries2, batch2, _) = queue.next_batch(Some(1), None, 2, 2).await.unwrap(); assert_eq!(entries2.len(), 1); assert!(entries2.contains_key(&2)); assert!(entries2.get(&2).unwrap().batch_time.is_some()); @@ -695,7 +630,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_max_size() { - let queue = default_queue(); + let queue = Queue::new(false, 1, None, 0); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -711,13 +646,13 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_budget() { - let queue = default_queue(); + let queue = Queue::new(false, 1, None, 0); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); queue.append(entry2); - let (entries, batch, _) = queue.next_batch(None, None, 1, 2).await.unwrap(); + let (entries, batch, _) = queue.next_batch(None, None, 1, 1).await.unwrap(); assert_eq!(entries.len(), 1); assert!(entries.contains_key(&0)); assert_eq!(batch.id, 0); @@ -726,7 +661,7 @@ mod tests { let (entry3, _guard3) = default_entry(); queue.append(entry3); - let (entries, batch, _) = queue.next_batch(None, None, 3, 6).await.unwrap(); + let (entries, batch, _) = queue.next_batch(None, None, 3, 3).await.unwrap(); assert_eq!(entries.len(), 2); assert!(entries.contains_key(&1)); assert!(entries.contains_key(&2)); @@ -736,7 +671,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_speculate() { - let queue = Queue::new(true, 1, 2, 1, None, 2); + let queue = Queue::new(false, 1, None, 2); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); 
queue.append(entry1); @@ -755,7 +690,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_dropped_receiver() { - let queue = default_queue(); + let queue = Queue::new(false, 1, None, 0); let (entry, _) = default_entry(); queue.append(entry); diff --git a/backends/v3/Cargo.toml b/backends/v3/Cargo.toml new file mode 100644 index 00000000..69dad072 --- /dev/null +++ b/backends/v3/Cargo.toml @@ -0,0 +1,83 @@ +[package] +name = "text-generation-router-v3" +description = "Text Generation Webserver" +version.workspace = true +edition.workspace = true +authors.workspace = true +homepage.workspace = true + +[lib] +path = "src/lib.rs" + +[[bin]] +name = "text-generation-router" +path = "src/main.rs" + +[dependencies] +async-trait = "0.1.74" +async-stream = "0.3.5" +axum = { version = "0.7", features = ["json"] } +axum-tracing-opentelemetry = "0.16" +text-generation-router = { path = "../../router" } +clap = { version = "4.4.5", features = ["derive", "env"] } +grpc-metadata = { path = "../grpc-metadata" } +futures = "0.3.28" +hf-hub = { workspace = true } +jsonschema = { version = "0.17.1", features = ["draft202012"] } +metrics = { workspace = true } +metrics-exporter-prometheus = { workspace = true } +nohash-hasher = "0.2.0" +opentelemetry = { version = "0.20.0", features = ["rt-tokio"] } +opentelemetry-otlp = "0.13.0" +rand = "0.8.5" +reqwest = { version = "0.11.20", features = [] } +serde = "1.0.188" +serde_json = "1.0.107" +slotmap = "1.0.7" +thiserror = "1.0.48" +tokenizers = { workspace = true } +tokio = { version = "1.32.0", features = [ + "rt", + "rt-multi-thread", + "parking_lot", + "signal", + "sync", +] } +tokio-stream = "0.1.14" +tower-http = { version = "0.5.1", features = ["cors"] } +tracing = "0.1.37" +tracing-opentelemetry = "0.21.0" +tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] } +utoipa = { version = "4.2.0", features = ["axum_extras"] } +utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } 
+init-tracing-opentelemetry = { version = "0.14.1", features = [ + "opentelemetry-otlp", +] } +minijinja = { workspace = true } +minijinja-contrib = { workspace = true } +futures-util = "0.3.30" +regex = "1.10.3" +once_cell = "1.19.0" +image = "0.25.1" +base64 = { workspace = true } +prost = "^0.12" +tonic = "^0.10" +tower = "^0.4" + +[build-dependencies] +tonic-build = "0.10.1" +prost-build = "0.12.1" + +[dev-dependencies] +criterion = "0.3" +itertools = "0.13" + +[features] +default = ["ngrok"] +ngrok = ["text-generation-router/ngrok"] +google = ["text-generation-router/google"] +kserve = ["text-generation-router/kserve"] + +[[bench]] +name = "prefix_cache" +harness = false diff --git a/backends/v3/benches/prefix_cache.rs b/backends/v3/benches/prefix_cache.rs new file mode 100644 index 00000000..d9df45b2 --- /dev/null +++ b/backends/v3/benches/prefix_cache.rs @@ -0,0 +1,47 @@ +use std::sync::Arc; + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use rand::Rng; + +use text_generation_router_v3::block_allocator::Allocator; +use text_generation_router_v3::radix::RadixAllocator; + +fn prefix_cache_benchmark(c: &mut Criterion) { + // let prefixes: Vec> = (0..8192) + // .chunks(256) + // .into_iter() + // .map(|c| c.collect()) + // .collect(); + + let mut cache = RadixAllocator::new(1, 262144, None); + + c.bench_function("Radix allocator", |b| { + b.iter_batched( + || { + //prefixes + // .choose_multiple(&mut rand::thread_rng(), 5) + // .fold(Vec::new(), |mut v, s| { + // v.extend(s); + // v + // }) + + (0..7936) + .map(|_| rand::thread_rng().gen_range(0..1024)) + .collect::>() + }, + |prefill| { + let alloc = cache.allocate( + prefill.len() as u32 + 13, + Some(Arc::new(black_box(prefill))), + ); + if let Some(alloc) = alloc { + cache.free(alloc.blocks.clone(), alloc.allocation_id); + } + }, + criterion::BatchSize::SmallInput, + ); + }); +} + +criterion_group!(benches, prefix_cache_benchmark); +criterion_main!(benches); diff --git 
a/backends/v3/build.rs b/backends/v3/build.rs new file mode 100644 index 00000000..6d702d14 --- /dev/null +++ b/backends/v3/build.rs @@ -0,0 +1,19 @@ +use std::fs; + +fn main() -> Result<(), Box> { + println!("cargo:rerun-if-changed=../../proto/"); + + fs::create_dir_all("src/client/pb").unwrap_or(()); + let mut config = prost_build::Config::new(); + config.protoc_arg("--experimental_allow_proto3_optional"); + + tonic_build::configure() + .build_client(true) + .build_server(false) + .out_dir("src/client/pb") + .include_file("mod.rs") + .compile_with_config(config, &["../../proto/v3/generate.proto"], &["../../proto"]) + .unwrap_or_else(|e| panic!("protobuf compilation failed: {e}")); + + Ok(()) +} diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs new file mode 100644 index 00000000..b1268152 --- /dev/null +++ b/backends/v3/src/backend.rs @@ -0,0 +1,518 @@ +use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient}; +/// Batching and inference logic +use crate::queue::{Entry, Queue}; +use async_trait::async_trait; +use nohash_hasher::IntMap; +use std::sync::Arc; +use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; +use text_generation_router::validation::ValidGenerateRequest; +use text_generation_router::{Attention, FinishReason, PrefillToken, Token}; +use tokio::sync::mpsc::error::SendError; +use tokio::sync::{mpsc, Notify}; +use tokio::time::Instant; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tracing::{info_span, instrument, Instrument, Span}; + +pub struct BackendV3 { + /// Request queue + queue: Queue, + /// Notify batcher on queue appends + batching_task_notifier: Arc, + /// Client clone, used for health checks to skip the queue + client: ShardedClient, +} + +impl BackendV3 { + #[allow(clippy::too_many_arguments)] + pub(crate) fn new( + client: ShardedClient, + waiting_served_ratio: f32, + max_input_tokens: u32, + max_total_tokens: u32, + 
max_batch_prefill_tokens: u32, + max_batch_total_tokens: u32, + max_waiting_tokens: usize, + max_batch_size: Option, + requires_padding: bool, + window_size: Option, + speculate: u32, + ) -> Self { + let prefix_caching = + std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var"); + let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1"); + let attention: String = std::env::var("ATTENTION").expect("attention env var"); + + let attention: Attention = attention + .parse() + .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`")); + let block_size = attention.block_size(); + + let queue = Queue::new( + requires_padding, + block_size, + prefix_caching, + window_size, + speculate, + max_input_tokens, + max_total_tokens, + max_batch_total_tokens, + ); + let batching_task_notifier = Arc::new(Notify::new()); + + // Spawn batching background task that contains all the inference logic + tokio::spawn(batching_task( + client.clone(), + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + queue.clone(), + batching_task_notifier.clone(), + )); + + Self { + queue, + batching_task_notifier, + client, + } + } +} + +#[async_trait] +impl Backend for BackendV3 { + #[instrument(skip_all)] + fn schedule( + &self, + request: ValidGenerateRequest, + ) -> Result>, InferError> { + // MPSC channel to communicate with the background batching task + let (response_tx, response_rx) = mpsc::unbounded_channel(); + + // Append the request to the queue + self.queue.append(Entry { + request, + response_tx, + span: Span::current(), + temp_span: None, + queue_time: Instant::now(), + batch_time: None, + block_allocation: None, + }); + + // Notify the background task that we have a new entry in the queue that needs + // to be batched + self.batching_task_notifier.notify_one(); + + // Return stream + Ok(UnboundedReceiverStream::new(response_rx)) + } + + async fn health(&self, 
current_health: bool) -> bool { + if current_health { + // Generation is healthy, we only check that the shards can allocate on device + self.client.device_health().await + } else { + self.client.model_health().await + } + .is_ok() + } +} + +/// Batching logic +/// Will be launched in a background Tokio task +/// +/// Batches requests and sends them to the inference server +#[allow(clippy::too_many_arguments)] +pub(crate) async fn batching_task( + mut client: ShardedClient, + waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, + max_batch_total_tokens: u32, + max_waiting_tokens: usize, + max_batch_size: Option, + queue: Queue, + notifier: Arc, +) { + // Infinite loop + loop { + // Wait for a notification from the Infer struct + notifier.notified().await; + + // Get the next batch from the queue + // This batch might be smaller than the maximum batch size if there are not enough requests + // waiting in the queue + while let Some((mut entries, batch, span)) = queue + .next_batch( + None, + max_batch_size, + max_batch_prefill_tokens, + max_batch_total_tokens, + ) + .await + { + let mut cached_batch = prefill(&mut client, batch, &mut entries) + .instrument(span) + .await; + let mut waiting_tokens = 1; + + // We loop until we do not receive any cached batch from the inference server (== until + // all requests have met their stopping criteria) + while let Some(batch) = cached_batch { + // Get current batch info + let batch_size = batch.size; + let batch_max_tokens = batch.max_tokens; + let mut batches = vec![batch]; + metrics::gauge!("tgi_batch_current_size").set(batch_size as f64); + metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64); + + let min_size = if waiting_tokens >= max_waiting_tokens { + // If we didn't onboard any new requests since >= max_waiting_tokens, we try + // to add a new batch even though its size might be small + None + } else { + // Minimum batch size + // TODO: temporarily disable to avoid incorrect deallocation 
+ + // reallocation when using prefix caching. + Some((batch_size as f32 * waiting_served_ratio).floor() as usize) + }; + + let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); + let max_size = + max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize)); + + // Try to get a new batch + if let Some((mut new_entries, new_batch, span)) = queue + .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget) + .await + { + // Tracking metrics + if min_size.is_some() { + metrics::counter!("tgi_batch_concat", "reason" => "backpressure") + .increment(1); + } else { + metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded") + .increment(1); + } + + entries.iter_mut().for_each(|(_, entry)| { + // Create a new span to add the info that this entry is waiting + // because a new batch is being computed + let entry_waiting_span = info_span!(parent: &entry.span, "waiting"); + // Add relationships + span.follows_from(&entry_waiting_span); + entry_waiting_span.follows_from(&span); + // Update entry + entry.temp_span = Some(entry_waiting_span); + }); + + // Generate one token for this new batch to have the attention past in cache + let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries) + .instrument(span) + .await; + // Reset waiting counter + waiting_tokens = 1; + // Extend current batch with the new batch + if let Some(new_cached_batch) = new_cached_batch { + entries.extend(new_entries); + batches.push(new_cached_batch); + } + } + + // Create span for this batch to add context to inference calls + let next_batch_size = entries.len(); + let next_batch_span = + info_span!(parent: None, "batch", batch_size = next_batch_size); + entries.iter_mut().for_each(|(_, entry)| { + // Create a new span to link the batch back to this entry + let entry_batch_span = info_span!(parent: &entry.span, "infer"); + // Add relationships + next_batch_span.follows_from(&entry_batch_span); + 
entry_batch_span.follows_from(&next_batch_span); + // Update entry + entry.temp_span = Some(entry_batch_span); + }); + + cached_batch = decode(&mut client, batches, &mut entries) + .instrument(next_batch_span) + .await; + waiting_tokens += 1; + } + metrics::gauge!("tgi_batch_current_size").set(0.0); + metrics::gauge!("tgi_batch_current_max_tokens").set(0.0); + } + } +} + +#[instrument(skip_all)] +async fn prefill( + client: &mut ShardedClient, + batch: Batch, + entries: &mut IntMap, +) -> Option { + let start_time = Instant::now(); + let batch_id = batch.id; + metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1); + + match client.prefill(batch).await { + Ok((generations, next_batch, timings)) => { + let start_filtering_time = Instant::now(); + // Send generated tokens and filter stopped entries + filter_send_generations(generations, entries); + + // Filter next batch and remove requests that were stopped + let next_batch = filter_batch(client, next_batch, entries).await; + + metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill") + .record(timings.forward.as_secs_f64()); + metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill") + .record(timings.decode.as_secs_f64()); + metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill") + .record(start_filtering_time.elapsed().as_secs_f64()); + metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill") + .record(start_time.elapsed().as_secs_f64()); + metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1); + next_batch + } + // If we have an error, we discard the whole batch + Err(err) => { + let _ = client.clear_cache(Some(batch_id)).await; + send_errors(err, entries); + metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1); + None + } + } +} + +#[instrument(skip_all)] +async fn decode( + client: &mut ShardedClient, + batches: Vec, + entries: &mut IntMap, +) -> Option { + 
let start_time = Instant::now(); + let batch_ids: Vec = batches.iter().map(|b| b.id).collect(); + metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1); + + match client.decode(batches).await { + Ok((generations, next_batch, timings)) => { + let start_filtering_time = Instant::now(); + // Send generated tokens and filter stopped entries + filter_send_generations(generations, entries); + + // Filter next batch and remove requests that were stopped + let next_batch = filter_batch(client, next_batch, entries).await; + + if let Some(concat_duration) = timings.concat { + metrics::histogram!("tgi_batch_concat_duration", "method" => "decode") + .record(concat_duration.as_secs_f64()); + } + metrics::histogram!("tgi_batch_forward_duration", "method" => "decode") + .record(timings.forward.as_secs_f64()); + metrics::histogram!("tgi_batch_decode_duration", "method" => "decode") + .record(timings.decode.as_secs_f64()); + metrics::histogram!("tgi_batch_filter_duration", "method" => "decode") + .record(start_filtering_time.elapsed().as_secs_f64()); + metrics::histogram!("tgi_batch_inference_duration", "method" => "decode") + .record(start_time.elapsed().as_secs_f64()); + metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1); + next_batch + } + // If we have an error, we discard the whole batch + Err(err) => { + for id in batch_ids { + let _ = client.clear_cache(Some(id)).await; + } + send_errors(err, entries); + metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1); + None + } + } +} + +/// Filter a `batch` and remove all requests not present in `entries` +#[instrument(skip_all)] +async fn filter_batch( + client: &mut ShardedClient, + next_batch: Option, + entries: &IntMap, +) -> Option { + let mut batch = next_batch?; + + // No need to filter + if batch.size as usize == entries.len() { + return Some(batch); + } + + let id = batch.id; + + // Retain only requests that are still in entries + 
batch.request_ids.retain(|id| entries.contains_key(id)); + + if batch.request_ids.is_empty() { + // All requests have been filtered out + // Next batch is now empty + // Clear it from the Python shards cache + // We unwrap here as we need to panic since we cannot recover if this method fails + client.clear_cache(Some(id)).await.unwrap(); + None + } else { + // Filter Python shard cache + // We unwrap here as we need to panic since we cannot recover if this method fails + client.filter_batch(id, batch.request_ids).await.unwrap() + } +} + +/// Send one or multiple `InferStreamResponse` to Infer for all `entries` +/// and filter entries +#[instrument(skip_all)] +fn filter_send_generations(generations: Vec, entries: &mut IntMap) { + generations.into_iter().for_each(|generation| { + let id = generation.request_id; + // Get entry + // We can `expect` here as the request id should always be in the entries + let entry = entries + .get(&id) + .expect("ID not found in entries. This is a bug."); + + // Create and enter a span to link this function back to the entry + let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered(); + // Send generation responses back to the infer task + // If the receive an error from the Flume channel, it means that the client dropped the + // request and we need to stop generating hence why we unwrap_or(true) + let stopped = send_responses(generation, entry).inspect_err(|_err| { + tracing::error!("Entry response channel error."); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); + }).unwrap_or(true); + if stopped { + entries.remove(&id).expect("ID not found in entries. 
This is a bug."); + } + }); +} + +/// Send responses through the `entry` response channel +fn send_responses( + generation: Generation, + entry: &Entry, +) -> Result>>> { + // Return directly if the channel is disconnected + if entry.response_tx.is_closed() { + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); + return Ok(true); + } + + let mut stopped = false; + + if let Some(prefill_tokens) = generation.prefill_tokens { + // Create Token objects + // We do that here instead of in the Python code as Rust for loops are faster + let prefill_tokens = prefill_tokens + .ids + .into_iter() + .zip(prefill_tokens.logprobs) + .zip(prefill_tokens.texts) + .map(|((id, logprob), text)| PrefillToken { id, text, logprob }) + .collect(); + + // Send message + entry + .response_tx + .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?; + } + + // Create last Token + let tokens_ = generation.tokens.expect("Non empty tokens in generation"); + let n = tokens_.ids.len(); + metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64); + let mut iterator = tokens_ + .ids + .into_iter() + .zip(tokens_.logprobs) + .zip(tokens_.texts) + .zip(tokens_.is_special) + .enumerate() + .peekable(); + while let Some((i, (((id, logprob), text), special))) = iterator.next() { + let token = Token { + id, + text, + logprob, + special, + }; + let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) { + top_tokens_ + .ids + .iter() + .zip(top_tokens_.logprobs.iter()) + .zip(top_tokens_.texts.iter()) + .zip(top_tokens_.is_special.iter()) + .map(|(((&id, &logprob), text), &special)| Token { + id, + text: text.to_string(), + logprob, + special, + }) + .collect() + } else { + vec![] + }; + match (&generation.generated_text, iterator.peek()) { + (Some(generated_text), None) => { + // Generation has ended + stopped = true; + // Send message + entry.response_tx.send(Ok(InferStreamResponse::End { + token, + top_tokens, + generated_text: 
GeneratedText::from(generated_text.clone()), + queued: entry.queue_time, + start: entry.batch_time.unwrap(), + }))?; + } + _ => { + // Send message + entry + .response_tx + .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?; + } + } + } + + Ok(stopped) +} + +/// Send errors to Infer for all `entries` +#[instrument(skip_all)] +fn send_errors(error: ClientError, entries: &mut IntMap) { + entries.drain().for_each(|(_, entry)| { + // Create and enter a span to link this function back to the entry + let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered(); + let err = InferError::GenerationError(error.to_string()); + metrics::counter!("tgi_request_failure", "err" => "generation").increment(1); + tracing::error!("{err}"); + + // unwrap_or is valid here as we don't care if the receiver is gone. + entry + .response_tx + .send(Err(err)) + .unwrap_or(()); + }); +} + +impl From for GeneratedText { + fn from(value: crate::client::GeneratedText) -> Self { + let v3_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap(); + let finish_reason = match v3_finish_reason { + crate::client::FinishReason::Length => FinishReason::Length, + crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken, + crate::client::FinishReason::StopSequence => FinishReason::StopSequence, + }; + + Self { + text: value.text, + generated_tokens: value.generated_tokens, + finish_reason, + seed: value.seed, + } + } +} diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs new file mode 100644 index 00000000..4fea172b --- /dev/null +++ b/backends/v3/src/block_allocator.rs @@ -0,0 +1,209 @@ +use std::sync::Arc; +use tokio::sync::{mpsc, oneshot}; + +use crate::radix::RadixAllocator; + +#[derive(Debug, Clone)] +pub struct BlockAllocation { + pub allocation_id: u64, + pub blocks: Vec, + pub slots: Vec, + + /// Prefix that was cached and for which 
the KV does not have to + /// be recomputed. + pub prefix_len: u32, + + pub(crate) block_allocator: Option, +} + +impl Drop for BlockAllocation { + fn drop(&mut self) { + if let Some(block_allocator) = self.block_allocator.as_mut() { + block_allocator.free(self.blocks.clone(), self.allocation_id) + } + } +} + +#[derive(Debug, Clone)] +pub struct BlockAllocator { + /// Channel to communicate with the background task + block_allocator: mpsc::UnboundedSender, +} + +impl BlockAllocator { + pub(crate) fn new( + max_batch_total_tokens: u32, + block_size: u32, + prefix_caching: bool, + window_size: Option, + ) -> Self { + // Create channel + let (sender, receiver) = mpsc::unbounded_channel(); + + // Launch background queue task + tokio::spawn(block_allocator_task( + max_batch_total_tokens / block_size, + block_size, + prefix_caching, + window_size, + receiver, + )); + + Self { + block_allocator: sender, + } + } + + pub(crate) async fn allocate( + &self, + tokens: u32, + prefill_tokens: Option>>, + ) -> Option { + let (response_sender, response_receiver) = oneshot::channel(); + self.block_allocator + .send(BlockAllocatorCommand::Allocate { + tokens, + prefill_tokens, + response_sender, + }) + .unwrap(); + + response_receiver.await.unwrap().map(|mut allocation| { + allocation.block_allocator = Some(self.clone()); + allocation + }) + } + + pub(crate) fn free(&self, blocks: Vec, allocation_id: u64) { + self.block_allocator + .send(BlockAllocatorCommand::Free { + allocation_id, + blocks, + }) + .unwrap(); + } +} + +async fn block_allocator_task( + blocks: u32, + block_size: u32, + prefix_caching: bool, + window_size: Option, + mut receiver: mpsc::UnboundedReceiver, +) { + let mut allocator: Box = if prefix_caching { + Box::new(RadixAllocator::new(block_size, blocks, window_size)) + } else { + Box::new(SimpleAllocator::new(blocks, block_size, window_size)) + }; + while let Some(cmd) = receiver.recv().await { + match cmd { + BlockAllocatorCommand::Free { + blocks, + 
allocation_id, + } => allocator.free(blocks, allocation_id), + BlockAllocatorCommand::Allocate { + tokens, + prefill_tokens, + response_sender, + } => { + response_sender + .send(allocator.allocate(tokens, prefill_tokens)) + .unwrap(); + } + } + } +} + +#[derive(Debug)] +enum BlockAllocatorCommand { + Free { + blocks: Vec, + allocation_id: u64, + }, + Allocate { + tokens: u32, + prefill_tokens: Option>>, + response_sender: oneshot::Sender>, + }, +} + +pub trait Allocator { + fn allocate( + &mut self, + tokens: u32, + prefill_tokens: Option>>, + ) -> Option; + + fn free(&mut self, blocks: Vec, allocation_id: u64); +} +pub struct SimpleAllocator { + free_blocks: Vec, + block_size: u32, + window_size: Option, +} + +impl SimpleAllocator { + fn new(blocks: u32, block_size: u32, window_size: Option) -> Self { + SimpleAllocator { + block_size, + // Block 0 is reserved for health checks + free_blocks: (1..blocks).collect(), + window_size, + } + } +} + +impl Allocator for SimpleAllocator { + fn allocate( + &mut self, + tokens: u32, + _prefill_tokens: Option>>, + ) -> Option { + // Apply window size + let (required_blocks, repeats) = { + let (tokens, repeats) = match self.window_size { + None => (tokens, 1), + Some(window_size) => { + let repeats = (tokens + window_size - 1) / window_size; + let tokens = core::cmp::min(tokens, window_size); + (tokens, repeats as usize) + } + }; + // Pad to a multiple of block size + let required_blocks = (tokens + self.block_size - 1) / self.block_size; + (required_blocks, repeats) + }; + + let tokens = tokens as usize; + if required_blocks > self.free_blocks.len() as u32 { + None + } else { + let blocks = self + .free_blocks + .split_off(self.free_blocks.len() - required_blocks as usize); + let mut slots = + Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize); + + 'slots: for block_id in blocks.repeat(repeats).iter() { + for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) { + 
slots.push(s); + if slots.len() == tokens { + break 'slots; + } + } + } + Some(BlockAllocation { + allocation_id: 0, + blocks, + slots, + prefix_len: 0, + block_allocator: None, + }) + } + } + + fn free(&mut self, blocks: Vec, _allocation_id: u64) { + self.free_blocks.extend(blocks) + } +} diff --git a/backends/v3/src/client/grpc_client.rs b/backends/v3/src/client/grpc_client.rs new file mode 100644 index 00000000..4508b92d --- /dev/null +++ b/backends/v3/src/client/grpc_client.rs @@ -0,0 +1,288 @@ +/// Single shard Client +use crate::client::{pb, Chunk}; +use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64}; +use base64::engine::general_purpose::STANDARD; +use base64::Engine; +use grpc_metadata::InjectTelemetryContext; +use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient; +use pb::generate::v3::*; +use std::cmp::min; +use std::time::Duration; +use tonic::transport::{Channel, Uri}; +use tracing::instrument; + +/// Text Generation Inference gRPC client +#[derive(Debug, Clone)] +pub struct Client { + stub: TextGenerationServiceClient, +} + +impl Client { + /// Returns a client connected to the given url + #[allow(dead_code)] + pub async fn connect(uri: Uri) -> Result { + let channel = Channel::builder(uri).connect().await?; + + Ok(Self { + stub: TextGenerationServiceClient::new(channel), + }) + } + + /// Returns a client connected to the given unix socket + pub async fn connect_uds(path: String) -> Result { + let channel = Channel::from_shared("http://[::]:50051".to_string()) + .unwrap() + .connect_with_connector(tower::service_fn(move |_: Uri| { + tokio::net::UnixStream::connect(path.clone()) + })) + .await?; + + Ok(Self { + stub: TextGenerationServiceClient::new(channel), + }) + } + + /// Returns a list of uris or unix sockets of all shards + #[instrument(skip(self))] + pub async fn service_discovery(&mut self) -> Result> { + let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context(); + let response = 
self.stub.service_discovery(request).await.map_err(|_| { + ClientError::Connection("Server does not support v3 interface".to_string()) + })?; + let urls = response + .into_inner() + .urls + .into_iter() + // Remove unix socket prefix + .map(|url| match url.strip_prefix("unix://") { + None => url, + Some(stripped_url) => stripped_url.to_string(), + }) + .collect(); + Ok(urls) + } + + /// Get model info + #[instrument(skip(self))] + pub async fn info(&mut self) -> Result { + let request = tonic::Request::new(InfoRequest {}).inject_context(); + let response = self.stub.info(request).await?.into_inner(); + Ok(response) + } + + /// Get model health + #[instrument(skip(self))] + pub async fn health(&mut self) -> Result { + let request = tonic::Request::new(HealthRequest {}).inject_context(); + let response = self.stub.health(request).await?.into_inner(); + Ok(response) + } + + /// Clear the past generations cache + #[instrument(skip(self))] + pub async fn clear_cache(&mut self, batch_id: Option) -> Result<()> { + let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context(); + self.stub.clear_cache(request).await?; + Ok(()) + } + + /// Filter a cached batch + #[instrument(skip(self))] + pub async fn filter_batch( + &mut self, + batch_id: u64, + request_ids: Vec, + ) -> Result> { + let request = tonic::Request::new(FilterBatchRequest { + batch_id, + request_ids, + }) + .inject_context(); + let filtered_batch = self.stub.filter_batch(request).await?.into_inner(); + Ok(filtered_batch.batch) + } + + /// Warmup on a max size batch + /// + /// Returns the maximum amount of tokens supported by the hardware + #[instrument(skip_all)] + pub async fn warmup( + &mut self, + max_input_length: u32, + max_prefill_tokens: u32, + max_total_tokens: u32, + max_batch_total_tokens: u32, + max_batch_size: Option, + ) -> Result> { + let mut n_tokens = 0; + let mut requests = Vec::new(); + // Create requests + while n_tokens < max_prefill_tokens { + let truncate = 
min(max_input_length, max_prefill_tokens - n_tokens); + + let mut input_chunks = Vec::new(); + input_chunks + .push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into()); + if n_tokens == 0 { + input_chunks.push( + Chunk::Image(Image { + // Safe unwrap, because we control the data. + data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(), + mimetype: "image/jpeg;base64".to_string(), + }) + .into(), + ); + } + + // Send stringly-typed inputs for compatibility for backends that haven't + // been updated to support chunks. + + let mut inputs = String::new(); + inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize)); + if n_tokens == 0 { + // 1 request is enough to test vision heads. + // Sending images on other queries messes up easily with truncation. + inputs.push_str(&format!( + "![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})", + )); + } + + requests.push(Request { + id: 0, + inputs, + add_special_tokens: true, + input_chunks: Some(Input { + chunks: input_chunks, + }), + // We truncate the input on the server side to be sure that it has the correct size + truncate, + // Blocks and slots will be set on the server side if we use paged attention + blocks: vec![], + slots: vec![], + prefix_len: 0, + // Set sampling parameters to also take these ops into account in the max memory + parameters: Some(NextTokenChooserParameters { + temperature: 0.9, + top_k: 10, + top_p: 0.9, + typical_p: 0.9, + do_sample: false, + seed: 0, + repetition_penalty: 1.2, + frequency_penalty: 0.1, + watermark: true, + grammar: String::new(), + grammar_type: GrammarType::None as i32, + }), + stopping_parameters: Some(StoppingCriteriaParameters { + max_new_tokens: max_total_tokens - truncate, + stop_sequences: vec![], + ignore_eos_token: true, + }), + prefill_logprobs: true, + top_n_tokens: 20, + adapter_id: None, + }); + n_tokens += max_input_length; + + // Check max_batch_size + if Some(requests.len()) == max_batch_size { + break; + } + } + + let batch = 
Batch { + id: 0, + size: requests.len() as u32, + requests, + max_tokens: max_input_length, + max_blocks: 0, + }; + + let request = tonic::Request::new(WarmupRequest { + batch: Some(batch), + max_input_length, + max_prefill_tokens, + max_total_tokens, + max_batch_total_tokens, + }) + .inject_context(); + let response = self.stub.warmup(request).await?.into_inner(); + Ok(response.max_supported_total_tokens) + } + + /// Generate one token for each request in the given batch + /// + /// Returns Generation for each request in batch + /// and the next cached batch + #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))] + pub async fn prefill( + &mut self, + batch: Batch, + ) -> Result<(Vec, Option, PrefillTimings)> { + let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context(); + let response = self.stub.prefill(request).await?.into_inner(); + Ok(( + response.generations, + response.batch, + PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns), + )) + } + + /// Generate one token for each request in the given cached batches + /// + /// Returns Generation for each request in batches + /// and the next cached batch + #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::()))] + pub async fn decode( + &mut self, + batches: Vec, + ) -> Result<(Vec, Option, DecodeTimings)> { + let request = tonic::Request::new(DecodeRequest { batches }).inject_context(); + let response = self.stub.decode(request).await?.into_inner(); + Ok(( + response.generations, + response.batch, + DecodeTimings::new( + response.concat_ns, + response.forward_ns, + response.decode_ns, + response.total_ns, + ), + )) + } +} + +pub struct PrefillTimings { + pub forward: Duration, + pub decode: Duration, + pub total: Duration, +} + +impl PrefillTimings { + fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self { + Self { + forward: Duration::from_nanos(forward_ns), + decode: 
Duration::from_nanos(decode_ns), + total: Duration::from_nanos(total_ns), + } + } +} + +pub struct DecodeTimings { + pub concat: Option, + pub forward: Duration, + pub decode: Duration, + pub total: Duration, +} + +impl DecodeTimings { + fn new(concat_ns: Option, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self { + Self { + concat: concat_ns.map(Duration::from_nanos), + forward: Duration::from_nanos(forward_ns), + decode: Duration::from_nanos(decode_ns), + total: Duration::from_nanos(total_ns), + } + } +} diff --git a/backends/v3/src/client/mod.rs b/backends/v3/src/client/mod.rs new file mode 100644 index 00000000..755431f4 --- /dev/null +++ b/backends/v3/src/client/mod.rs @@ -0,0 +1,76 @@ +//! Text Generation gRPC client library + +use async_trait::async_trait; +use thiserror::Error; +use tonic::transport; +use tonic::Status; + +#[allow(clippy::derive_partial_eq_without_eq)] +mod pb; + +mod grpc_client; +mod sharded_client; + +pub use grpc_client::Client; +pub use pb::generate::v3::{ + input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, + HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request, + StoppingCriteriaParameters, +}; +pub use sharded_client::ShardedClient; + +#[async_trait] +pub trait Health { + /// Check if a generate server is healthy by asking it to allocate a tensor on device + async fn device_health(&self) -> Result<()>; + + /// Check if a generate server is healthy by doing a forward pass. 
+ /// EXPENSIVE + async fn model_health(&self) -> Result<()>; +} + +#[derive(Debug)] +pub struct ShardInfo { + pub requires_padding: bool, + pub dtype: String, + pub device_type: String, + pub window_size: Option, + pub speculate: u32, +} + +#[derive(Error, Debug, Clone)] +pub enum ClientError { + #[error("Could not connect to Text Generation server: {0}")] + Connection(String), + #[error("Server error: {0}")] + Generation(String), + #[error("Sharded results are empty")] + EmptyResults, +} + +impl From for ClientError { + fn from(err: Status) -> Self { + let err = Self::Generation(err.message().to_string()); + tracing::error!("{err}"); + err + } +} + +impl From for ClientError { + fn from(err: transport::Error) -> Self { + let err = Self::Connection(err.to_string()); + tracing::error!("{err}"); + err + } +} + +// Small convenience re-wrapping of `Chunk`. +impl From for InputChunk { + fn from(chunk: Chunk) -> Self { + InputChunk { chunk: Some(chunk) } + } +} + +static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII="; + +pub type Result = std::result::Result; diff --git a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs new file mode 100644 index 00000000..0b6e916e --- /dev/null +++ b/backends/v3/src/client/sharded_client.rs @@ 
-0,0 +1,264 @@ +use crate::client::{ClientError, Result}; +/// Multi shard Client +use crate::client::{Health, ShardInfo}; + +use crate::client::grpc_client::{DecodeTimings, PrefillTimings}; +use crate::client::{ + Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse, + NextTokenChooserParameters, Request, StoppingCriteriaParameters, +}; +use crate::client::{Chunk, InfoResponse, Input}; +use async_trait::async_trait; +use futures::future::join_all; +use tonic::transport::Uri; +use tracing::instrument; + +#[derive(Debug, Clone)] +/// Text Generation Inference gRPC multi client +pub struct ShardedClient { + clients: Vec, +} + +impl ShardedClient { + fn new(clients: Vec) -> Self { + Self { clients } + } + + /// Create a new ShardedClient from a master client. The master client will communicate with + /// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method. + async fn from_master_client(mut master_client: Client) -> Result { + // Get all uris/unix sockets from the master client + let uris = master_client.service_discovery().await?; + let futures = uris.into_iter().map(Client::connect_uds); + let clients: Result> = join_all(futures).await.into_iter().collect(); + Ok(Self::new(clients?)) + } + + /// Returns a client connected to the given uri + #[allow(dead_code)] + pub async fn connect(uri: Uri) -> Result { + let master_client = Client::connect(uri).await?; + Self::from_master_client(master_client).await + } + + /// Returns a client connected to the given unix socket + pub async fn connect_uds(path: String) -> Result { + let master_client = Client::connect_uds(path).await?; + Self::from_master_client(master_client).await + } + + /// Get the model info + #[instrument(skip(self))] + pub async fn info(&mut self) -> Result { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| client.info()) + .collect(); + join_all(futures).await.pop().unwrap().map(ShardInfo::from) + } + + /// GRPC health check + 
#[instrument(skip(self))] + pub async fn health(&mut self) -> Result { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| client.health()) + .collect(); + join_all(futures).await.pop().unwrap() + } + + /// Clear the past generations cache + #[instrument(skip(self))] + pub async fn clear_cache(&mut self, batch_id: Option) -> Result<()> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| client.clear_cache(batch_id)) + .collect(); + join_all(futures).await.into_iter().collect() + } + + /// Filter a cached batch + #[instrument(skip(self))] + pub async fn filter_batch( + &mut self, + batch_id: u64, + request_ids: Vec, + ) -> Result> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone()))) + .collect(); + // all shards return the same message + join_all(futures).await.pop().unwrap() + } + + /// Warmup on a max size batch + /// + /// Returns the maximum amount of tokens supported by the hardware + #[instrument(skip(self))] + pub async fn warmup( + &mut self, + max_input_length: u32, + max_prefill_tokens: u32, + max_total_tokens: u32, + max_batch_total_tokens: u32, + max_batch_size: Option, + ) -> Result> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| { + Box::pin(client.warmup( + max_input_length, + max_prefill_tokens, + max_total_tokens, + max_batch_total_tokens, + max_batch_size, + )) + }) + .collect(); + // Take the minimum value + let results = join_all(futures) + .await + .into_iter() + .collect::>>>()?; + Ok(results.into_iter().flatten().min()) + } + + /// Generate one token for each request in the given batch + /// + /// Returns Generation for each request in batch + /// and the next cached batch + #[instrument(skip_all, fields(id = & batch.id, size = & batch.size))] + pub async fn prefill( + &mut self, + batch: Batch, + ) -> Result<(Vec, Option, PrefillTimings)> { + let futures: Vec<_> = self + .clients + .iter_mut() + 
.map(|client| Box::pin(client.prefill(batch.clone()))) + .collect(); + #[allow(clippy::type_complexity)] + let results: Result, Option, PrefillTimings)>> = + join_all(futures).await.into_iter().collect(); + let mut results = results?; + + let (mut generations, next_batch, mut timings) = + results.pop().ok_or(ClientError::EmptyResults)?; + + // Merge generations from different model shards + for (mut shard_generations, _, shard_timings) in results.into_iter() { + generations.append(&mut shard_generations); + // Return the timings of the slowest shard + if shard_timings.total > timings.total { + timings = shard_timings; + } + } + Ok((generations, next_batch, timings)) + } + + /// Generate one token for each request in the given cached batches + /// + /// Returns Generation for each request in batches + /// and the next cached batch + #[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))] + pub async fn decode( + &mut self, + batches: Vec, + ) -> Result<(Vec, Option, DecodeTimings)> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.decode(batches.clone()))) + .collect(); + #[allow(clippy::type_complexity)] + let results: Result, Option, DecodeTimings)>> = + join_all(futures).await.into_iter().collect(); + let mut results = results?; + + let (mut generations, next_batch, mut timings) = + results.pop().ok_or(ClientError::EmptyResults)?; + + // Merge generations from different model shards + for (mut shard_generations, _, shard_timings) in results.into_iter() { + generations.append(&mut shard_generations); + // Return the timings of the slowest shard + if shard_timings.total > timings.total { + timings = shard_timings; + } + } + Ok((generations, next_batch, timings)) + } +} + +impl From for ShardInfo { + fn from(value: InfoResponse) -> Self { + Self { + requires_padding: value.requires_padding, + dtype: value.dtype, + device_type: value.device_type, + window_size: value.window_size, + 
speculate: value.speculate, + } + } +} + +#[async_trait] +impl Health for ShardedClient { + async fn device_health(&self) -> Result<()> { + self.clone().health().await?; + Ok(()) + } + + async fn model_health(&self) -> Result<()> { + // Dummy batch of 1 token and 1 generated token + let liveness_request = Request { + id: u64::MAX, + inputs: "liveness".to_string(), + input_chunks: Some(Input { + chunks: vec![Chunk::Text("liveness".into()).into()], + }), + truncate: 10, + add_special_tokens: true, + prefill_logprobs: false, + parameters: Some(NextTokenChooserParameters { + temperature: 1.0, + top_k: 0, + top_p: 1.0, + typical_p: 1.0, + do_sample: false, + seed: 0, + repetition_penalty: 1.0, + frequency_penalty: 0.0, + watermark: false, + grammar: String::new(), + grammar_type: GrammarType::None as i32, + }), + stopping_parameters: Some(StoppingCriteriaParameters { + max_new_tokens: 1, + stop_sequences: vec![], + ignore_eos_token: false, + }), + top_n_tokens: 0, + // Block 0 is reserved for health checks + blocks: vec![0], + slots: (0..16).collect(), + prefix_len: 0, + adapter_id: None, + }; + let batch = Batch { + id: u64::MAX, + requests: vec![liveness_request], + size: 1, + max_tokens: 2, + max_blocks: 1, + }; + self.clone().prefill(batch).await?; + Ok(()) + } +} diff --git a/backends/v3/src/lib.rs b/backends/v3/src/lib.rs new file mode 100644 index 00000000..8913e40b --- /dev/null +++ b/backends/v3/src/lib.rs @@ -0,0 +1,147 @@ +mod backend; +pub mod block_allocator; +mod client; +mod queue; +pub mod radix; + +use crate::client::{ClientError, ShardedClient}; +pub(crate) use backend::BackendV3; +use serde::Serialize; +use thiserror::Error; +use utoipa::ToSchema; + +#[derive(Clone, Debug, Serialize, ToSchema)] +pub struct BackendInfo { + /// Mandatory + #[schema(example = "cuda")] + pub model_device_type: String, + #[schema(example = "torch.float16")] + pub model_dtype: String, + + /// Backend parameters + #[schema(example = "1")] + pub speculate: usize, + 
#[schema(example = "1.2")] + pub waiting_served_ratio: f32, + #[schema(example = "32000")] + pub max_batch_total_tokens: u32, + #[schema(example = "20")] + pub max_waiting_tokens: usize, + #[schema(nullable = true, example = "null")] + pub max_batch_size: Option, +} + +#[allow(clippy::too_many_arguments)] +pub async fn connect_backend( + max_input_tokens: usize, + max_total_tokens: usize, + master_shard_uds_path: String, + waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, + max_batch_total_tokens: Option, + max_waiting_tokens: usize, + max_batch_size: Option, +) -> Result<(BackendV3, BackendInfo), V3Error> { + // Helper function + let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option| { + match max_supported_batch_total_tokens { + // Older models do not support automatic max-batch-total-tokens + None => { + let max_batch_total_tokens = max_batch_total_tokens + .unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))); + tracing::warn!("Model does not support automatic max batch total tokens"); + Ok(max_batch_total_tokens) + } + // Flash attention models return their max supported total tokens + Some(max_supported_batch_total_tokens) => { + // Warn if user added his own max-batch-total-tokens as we will ignore it + if max_batch_total_tokens.is_some() { + tracing::warn!( + "`--max-batch-total-tokens` is deprecated for Flash \ + Attention models." 
+ ); + tracing::warn!( + "Inferred max batch total tokens: {max_supported_batch_total_tokens}" + ); + } + if max_total_tokens as u32 > max_supported_batch_total_tokens { + return Err(V3Error::NotEnoughMemory(max_total_tokens)); + } + + Ok(max_supported_batch_total_tokens) + } + } + }; + + let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path) + .await + .map_err(V3Error::Connection)?; + + // server is running on v3 + // Clear the cache; useful if the webserver rebooted + sharded_client + .clear_cache(None) + .await + .map_err(V3Error::Cache)?; + // Get info from the shard + let shard_info = sharded_client.info().await.map_err(V3Error::Info)?; + + // Warmup model + tracing::info!("Warming up model"); + let max_batch_total_tokens = check_max_batch_total_tokens( + sharded_client + .warmup( + max_input_tokens as u32, + max_batch_prefill_tokens, + max_total_tokens as u32, + max_batch_total_tokens.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))), + max_batch_size, + ) + .await + .map_err(V3Error::Warmup)?, + )?; + tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}"); + metrics::gauge!("tgi_batch_max_total_tokens").set(max_batch_total_tokens); + + let backend_info = BackendInfo { + waiting_served_ratio, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + model_device_type: shard_info.device_type.clone(), + model_dtype: shard_info.dtype.clone(), + speculate: shard_info.speculate as usize, + }; + + let backend = BackendV3::new( + sharded_client, + waiting_served_ratio, + max_input_tokens as u32, + max_total_tokens as u32, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + shard_info.requires_padding, + shard_info.window_size, + shard_info.speculate, + ); + + tracing::info!("Using backend V3"); + + Ok((backend, backend_info)) +} + +#[derive(Debug, Error)] +pub enum V3Error { + #[error("Unable to clear the Python model shards cache: {0}")] + 
Cache(ClientError), + #[error("Unable to connect to the Python model shards: {0}")] + Connection(ClientError), + #[error("Unable to get the Python model shards info: {0}")] + Info(ClientError), + #[error("Unable to warmup the Python model shards: {0}")] + Warmup(ClientError), + #[error("Not enough memory to handle `max_total_tokens={0}`")] + NotEnoughMemory(usize), +} diff --git a/backends/v3/src/main.rs b/backends/v3/src/main.rs new file mode 100644 index 00000000..471ddb5a --- /dev/null +++ b/backends/v3/src/main.rs @@ -0,0 +1,212 @@ +use clap::{Parser, Subcommand}; +use text_generation_router::{server, usage_stats}; +use text_generation_router_v3::{connect_backend, V3Error}; +use thiserror::Error; + +/// App Configuration +#[derive(Parser, Debug)] +#[clap(author, version, about, long_about = None)] +struct Args { + #[command(subcommand)] + command: Option, + + #[clap(default_value = "128", long, env)] + max_concurrent_requests: usize, + #[clap(default_value = "2", long, env)] + max_best_of: usize, + #[clap(default_value = "4", long, env)] + max_stop_sequences: usize, + #[clap(default_value = "5", long, env)] + max_top_n_tokens: u32, + #[clap(default_value = "1024", long, env)] + max_input_tokens: usize, + #[clap(default_value = "2048", long, env)] + max_total_tokens: usize, + #[clap(default_value = "1.2", long, env)] + waiting_served_ratio: f32, + #[clap(default_value = "4096", long, env)] + max_batch_prefill_tokens: u32, + #[clap(long, env)] + max_batch_total_tokens: Option, + #[clap(default_value = "20", long, env)] + max_waiting_tokens: usize, + #[clap(long, env)] + max_batch_size: Option, + #[clap(default_value = "0.0.0.0", long, env)] + hostname: String, + #[clap(default_value = "3000", long, short, env)] + port: u16, + #[clap(default_value = "/tmp/text-generation-server-0", long, env)] + master_shard_uds_path: String, + #[clap(default_value = "bigscience/bloom", long, env)] + tokenizer_name: String, + #[clap(long, env)] + tokenizer_config_path: Option, + 
#[clap(long, env)] + revision: Option, + #[clap(default_value = "2", long, env)] + validation_workers: usize, + #[clap(long, env)] + api_key: Option, + #[clap(long, env)] + json_output: bool, + #[clap(long, env)] + otlp_endpoint: Option, + #[clap(default_value = "text-generation-inference.router", long, env)] + otlp_service_name: String, + #[clap(long, env)] + cors_allow_origin: Option>, + #[clap(long, env)] + ngrok: bool, + #[clap(long, env)] + ngrok_authtoken: Option, + #[clap(long, env)] + ngrok_edge: Option, + #[clap(long, env, default_value_t = false)] + messages_api_enabled: bool, + #[clap(long, env, default_value_t = false)] + disable_grammar_support: bool, + #[clap(default_value = "4", long, env)] + max_client_batch_size: usize, + #[clap(default_value = "on", long, env)] + usage_stats: usage_stats::UsageStatsLevel, +} + +#[derive(Debug, Subcommand)] +enum Commands { + PrintSchema, +} + +#[tokio::main] +async fn main() -> Result<(), RouterError> { + // Get args + let args = Args::parse(); + // Pattern match configuration + let Args { + command, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + hostname, + port, + master_shard_uds_path, + tokenizer_name, + tokenizer_config_path, + revision, + validation_workers, + api_key, + json_output, + otlp_endpoint, + otlp_service_name, + cors_allow_origin, + ngrok, + ngrok_authtoken, + ngrok_edge, + messages_api_enabled, + disable_grammar_support, + max_client_batch_size, + usage_stats, + } = args; + + if let Some(Commands::PrintSchema) = command { + use utoipa::OpenApi; + let api_doc = text_generation_router::server::ApiDoc::openapi(); + let api_doc = serde_json::to_string_pretty(&api_doc).unwrap(); + println!("{}", api_doc); + std::process::exit(0); + }; + text_generation_router::logging::init_logging(otlp_endpoint, 
otlp_service_name, json_output); + + // Validate args + if max_input_tokens >= max_total_tokens { + return Err(RouterError::ArgumentValidation( + "`max_input_tokens` must be < `max_total_tokens`".to_string(), + )); + } + if max_input_tokens as u32 > max_batch_prefill_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}"))); + } + + if validation_workers == 0 { + return Err(RouterError::ArgumentValidation( + "`validation_workers` must be > 0".to_string(), + )); + } + + if let Some(ref max_batch_total_tokens) = max_batch_total_tokens { + if max_batch_prefill_tokens > *max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); + } + if max_total_tokens as u32 > *max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. 
Given: {max_total_tokens} and {max_batch_total_tokens}"))); + } + } + + if let Some(max_batch_size) = max_batch_size { + if max_batch_size == 0 { + return Err(RouterError::ArgumentValidation( + "`max_batch_size` must be > 0".to_string(), + )); + } + } + + let (backend, _backend_info) = connect_backend( + max_input_tokens, + max_total_tokens, + master_shard_uds_path, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + ) + .await?; + + // Run server + server::run( + backend, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + validation_workers, + api_key, + tokenizer_name, + tokenizer_config_path, + revision, + hostname, + port, + cors_allow_origin, + ngrok, + ngrok_authtoken, + ngrok_edge, + messages_api_enabled, + disable_grammar_support, + max_client_batch_size, + usage_stats, + ) + .await?; + Ok(()) +} + +#[derive(Debug, Error)] +enum RouterError { + #[error("Argument validation error: {0}")] + ArgumentValidation(String), + #[error("Backend failed: {0}")] + Backend(#[from] V3Error), + #[error("WebServer error: {0}")] + WebServer(#[from] server::WebServerError), + #[error("Tokio runtime failed to start: {0}")] + Tokio(#[from] std::io::Error), +} diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs new file mode 100644 index 00000000..4ce54a79 --- /dev/null +++ b/backends/v3/src/queue.rs @@ -0,0 +1,824 @@ +use crate::block_allocator::{BlockAllocation, BlockAllocator}; +use crate::client; +use crate::client::{ + Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters, +}; +use nohash_hasher::{BuildNoHashHasher, IntMap}; +use std::cmp::{max, min}; +use std::collections::VecDeque; +use text_generation_router::infer::InferError; +use text_generation_router::infer::InferStreamResponse; +use text_generation_router::validation::{ + Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, 
ValidParameters, + ValidStoppingParameters, +}; +use tokio::sync::{mpsc, oneshot}; +use tokio::time::Instant; +use tracing::{info_span, instrument, Instrument, Span}; + +/// Queue entry +#[derive(Debug)] +pub(crate) struct Entry { + /// Request + pub request: ValidGenerateRequest, + /// Response sender to communicate between the Infer struct and the batching_task + pub response_tx: mpsc::UnboundedSender>, + /// Span that will live as long as entry + pub span: Span, + /// Temporary span used as a guard when logging inference, wait times... + pub temp_span: Option, + /// Instant when this entry was queued + pub queue_time: Instant, + /// Instant when this entry was added to a batch + pub batch_time: Option, + /// Block Allocation + pub block_allocation: Option, +} + +/// Request Queue +#[derive(Debug, Clone)] +pub(crate) struct Queue { + /// Channel to communicate with the background queue task + queue_sender: mpsc::UnboundedSender, +} + +impl Queue { + pub(crate) fn new( + requires_padding: bool, + block_size: u32, + prefix_caching: bool, + window_size: Option, + speculate: u32, + max_input_tokens: u32, + max_total_tokens: u32, + max_batch_total_tokens: u32, + ) -> Self { + // Create channel + let (queue_sender, queue_receiver) = mpsc::unbounded_channel(); + + // Launch background queue task + tokio::spawn(queue_task( + requires_padding, + block_size, + prefix_caching, + window_size, + speculate, + max_input_tokens, + max_total_tokens, + max_batch_total_tokens, + queue_receiver, + )); + + Self { queue_sender } + } + + /// Append an entry to the queue + #[instrument(skip_all)] + pub(crate) fn append(&self, entry: Entry) { + // Send append command to the background task managing the state + // Unwrap is safe here + self.queue_sender + .send(QueueCommand::Append(Box::new(entry), Span::current())) + .unwrap(); + } + + // Get the next batch + #[instrument(skip(self))] + pub(crate) async fn next_batch( + &self, + min_size: Option, + max_size: Option, + 
prefill_token_budget: u32, + token_budget: u32, + ) -> Option { + // Create response channel + let (response_sender, response_receiver) = oneshot::channel(); + // Send next batch command to the background task managing the state + // Unwrap is safe here + self.queue_sender + .send(QueueCommand::NextBatch { + min_size, + max_size, + prefill_token_budget, + token_budget, + response_sender, + span: Span::current(), + }) + .unwrap(); + // Await on response channel + // Unwrap is safe here + response_receiver.await.unwrap() + } +} + +// Background task responsible of the queue state +async fn queue_task( + requires_padding: bool, + block_size: u32, + prefix_caching: bool, + window_size: Option, + speculate: u32, + max_input_tokens: u32, + max_total_tokens: u32, + max_batch_total_tokens: u32, + mut receiver: mpsc::UnboundedReceiver, +) { + let mut state = State::new( + requires_padding, + block_size, + prefix_caching, + window_size, + speculate, + max_input_tokens, + max_total_tokens, + max_batch_total_tokens, + ); + + while let Some(cmd) = receiver.recv().await { + match cmd { + QueueCommand::Append(entry, span) => { + span.in_scope(|| state.append(*entry)); + metrics::gauge!("tgi_queue_size").increment(1.0); + } + QueueCommand::NextBatch { + min_size, + max_size, + prefill_token_budget, + token_budget, + response_sender, + span, + } => { + let next_batch = state + .next_batch(min_size, max_size, prefill_token_budget, token_budget) + .instrument(span) + .await; + response_sender.send(next_batch).unwrap(); + metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64); + } + } + } +} + +/// Queue State +#[derive(Debug)] +struct State { + /// Queue entries organized in a Vec + entries: VecDeque<(u64, Entry)>, + + /// Id of the next entry + next_id: u64, + + /// Id of the next batch + next_batch_id: u64, + + /// Paged Attention block size + block_size: u32, + + /// Sliding window + window_size: Option, + + /// Speculation amount + speculate: u32, + + /// Paged 
Attention Block Allocation + block_allocator: Option, + + /// Require padding + requires_padding: bool, + + /// max input tokens + max_input_tokens: u32, + + /// max total tokens, + max_total_tokens: u32, +} + +impl State { + fn new( + requires_padding: bool, + block_size: u32, + prefix_caching: bool, + window_size: Option, + speculate: u32, + max_input_tokens: u32, + max_total_tokens: u32, + max_batch_total_tokens: u32, + ) -> Self { + let block_allocator = (!requires_padding).then(|| { + BlockAllocator::new( + max_batch_total_tokens, + block_size, + prefix_caching, + window_size, + ) + }); + + Self { + entries: VecDeque::with_capacity(128), + next_id: 0, + next_batch_id: 0, + block_size, + window_size, + speculate, + block_allocator, + requires_padding, + max_input_tokens, + max_total_tokens, + } + } + + /// Append an entry to the queue + fn append(&mut self, mut entry: Entry) { + // Create a span that will live as long as the entry is in the queue waiting to be batched + let queue_span = info_span!(parent: &entry.span, "queued"); + entry.temp_span = Some(queue_span); + + // Push entry in the queue + self.entries.push_back((self.next_id, entry)); + self.next_id += 1; + } + + // Get the next batch + async fn next_batch( + &mut self, + min_size: Option, + max_size: Option, + prefill_token_budget: u32, + token_budget: u32, + ) -> Option { + if self.entries.is_empty() { + tracing::debug!("No queue"); + return None; + } + + // Check if we have enough entries + if let Some(min_size) = min_size { + if self.entries.len() < min_size { + tracing::debug!("Not enough entries"); + return None; + } + } + + if let Some(max_size) = max_size { + if max_size == 0 { + tracing::debug!("No capacity"); + return None; + } + } + + // Pad prefill_token_budget to be a multiple of block size + let prefill_token_budget = + ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size; + + // Create span for this batch to add context to inference calls + let 
next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty); + next_batch_span.follows_from(Span::current()); + + let mut batch = Vec::with_capacity(self.entries.len()); + let mut max_input_length = 0; + let mut prefill_tokens: u32 = 0; + let mut decode_tokens: u32 = 0; + let mut max_blocks = 0; + + // Pop entries starting from the front of the queue + 'entry_loop: while let Some((id, entry)) = self.entries.pop_front() { + // Filter entries where the response receiver was dropped (== entries where the request + // was dropped by the client) + if entry.response_tx.is_closed() { + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); + tracing::debug!("Dropping entry"); + continue; + } + + let block_allocation = match &self.block_allocator { + None => { + // We pad to max input length in the Python shards + // We need to take these padding tokens into the equation + if self.requires_padding { + prefill_tokens = (batch.len() + 1) as u32 * self.max_input_tokens; + } else{ + max_input_length = max_input_length.max(entry.request.input_length); + prefill_tokens = (batch.len() + 1) as u32 * max_input_length; + } + + if self.requires_padding { + decode_tokens = (batch.len() + 1) as u32 * (self.max_total_tokens - self.max_input_tokens); + } else { + decode_tokens += entry.request.stopping_parameters.max_new_tokens; + } + + let total_tokens = prefill_tokens + decode_tokens + self.speculate; + + if prefill_tokens > prefill_token_budget || total_tokens > token_budget { + // Entry is over budget + // Add it back to the front + tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate); + self.entries.push_front((id, entry)); + break 'entry_loop; + } + None + } + Some(_block_allocator) => { + prefill_tokens += entry.request.input_length; + let max_new_tokens = match self.window_size { + None => 
entry.request.stopping_parameters.max_new_tokens, + Some(window_size) => min( + window_size.saturating_sub(entry.request.input_length), + entry.request.stopping_parameters.max_new_tokens, + ), + }; + decode_tokens += max_new_tokens; + + if prefill_tokens > prefill_token_budget + || (prefill_tokens + decode_tokens + self.speculate) > token_budget + { + // Entry is over budget + // Add it back to the front + tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate); + self.entries.push_front((id, entry)); + break; + } + + let tokens = entry.request.input_length + + entry.request.stopping_parameters.max_new_tokens + + self.speculate + - 1; + + // If users wants the prefill logprobs, we cannot reuse the cache. + // So no input_ids for the radix tree. + let input_ids = if entry.request.decoder_input_details { + None + } else { + entry.request.input_ids.clone() + }; + + Some((tokens, input_ids)) + } + }; + batch.push((id, entry, block_allocation)); + if Some(batch.len()) == max_size { + break; + } + } + + // Empty batch + if batch.is_empty() { + tracing::debug!("Filterered out all entries"); + return None; + } + + // XXX We haven't allocated yet, so we're allowed to ditch the results. 
+ // Check if our batch is big enough + if let Some(min_size) = min_size { + // Batch is too small + if batch.len() < min_size { + // Add back entries to the queue in the correct order + for (id, entry, _) in batch.into_iter().rev() { + self.entries.push_front((id, entry)); + } + return None; + } + } + + let mut batch_requests = Vec::with_capacity(self.entries.len()); + let mut batch_entries = + IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default()); + + for (id, mut entry, block_allocation) in batch { + let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) = + (block_allocation, &self.block_allocator) + { + tracing::debug!("Allocating {tokens} with {input_ids:?}"); + match block_allocator.allocate(tokens, input_ids).await { + None => { + // Entry is over budget + // Add it back to the front + tracing::debug!("Over budget: not enough free blocks"); + self.entries.push_front((id, entry)); + continue; + } + Some(block_allocation) => { + tracing::debug!("Allocation: {block_allocation:?}"); + max_blocks = max(max_blocks, block_allocation.blocks.len() as u32); + Some(block_allocation) + } + } + } else { + None + }; + tracing::debug!("Accepting entry"); + // Create a new span to link the batch back to this entry + let entry_batch_span = info_span!(parent: &entry.span, "infer"); + // Add relationships + next_batch_span.follows_from(&entry_batch_span); + entry_batch_span.follows_from(&next_batch_span); + // Update entry + entry.temp_span = Some(entry_batch_span); + + let (blocks, slots, prefix_len) = match &block_allocation { + None => (Vec::new(), Vec::new(), 0), + Some(block_allocation) => ( + block_allocation.blocks.clone(), + block_allocation.slots.clone(), + block_allocation.prefix_len, + ), + }; + + entry.block_allocation = block_allocation; + + batch_requests.push(Request { + id, + prefill_logprobs: entry.request.decoder_input_details, + input_chunks: Some(client::Input { + chunks: entry + .request + .inputs + 
.clone() + .into_iter() + .map(|c| client::InputChunk { + chunk: Some(match c { + Chunk::Text(text) => client::Chunk::Text(text), + Chunk::Image(image) => client::Chunk::Image(client::Image { + data: image.data, + mimetype: image.mimetype, + }), + }), + }) + .collect(), + }), + inputs: entry.request.inputs.chunks_to_string(), + truncate: entry.request.truncate, + add_special_tokens: entry.request.add_special_tokens, + parameters: Some(NextTokenChooserParameters::from( + entry.request.parameters.clone(), + )), + stopping_parameters: Some(StoppingCriteriaParameters::from( + entry.request.stopping_parameters.clone(), + )), + top_n_tokens: entry.request.top_n_tokens, + blocks, + slots, + prefix_len, + adapter_id: entry.request.adapter_id.clone(), + }); + // Set batch_time + entry.batch_time = Some(Instant::now()); + // Insert in batch_entries IntMap + batch_entries.insert(id, entry); + } + + // Empty batch + if batch_requests.is_empty() { + tracing::debug!("Filterered out all entries"); + return None; + } + + // Final batch size + let size = batch_requests.len() as u32; + next_batch_span.record("batch_size", size); + + let batch = Batch { + id: self.next_batch_id, + requests: batch_requests, + size, + max_tokens: (prefill_tokens + decode_tokens), + max_blocks, + }; + // Increment batch id + self.next_batch_id += 1; + + metrics::histogram!("tgi_batch_next_size").record(batch.size as f64); + + Some((batch_entries, batch, next_batch_span)) + } +} + +type NextBatch = (IntMap, Batch, Span); + +#[derive(Debug)] +enum QueueCommand { + Append(Box, Span), + NextBatch { + min_size: Option, + max_size: Option, + prefill_token_budget: u32, + token_budget: u32, + response_sender: oneshot::Sender>, + span: Span, + }, +} + +impl From for NextTokenChooserParameters { + fn from(value: ValidParameters) -> Self { + let (grammar, grammar_type) = match value.grammar { + None => (String::new(), GrammarType::None), + + Some(grammar) => match grammar { + ValidGrammar::Json(grammar_string) => 
(grammar_string, GrammarType::Json), + ValidGrammar::Regex(grammar_string) => (grammar_string, GrammarType::Regex), + }, + }; + + Self { + temperature: value.temperature, + top_k: value.top_k, + top_p: value.top_p, + typical_p: value.typical_p, + do_sample: value.do_sample, + seed: value.seed, + repetition_penalty: value.repetition_penalty, + frequency_penalty: value.frequency_penalty, + watermark: value.watermark, + grammar, + grammar_type: grammar_type.into(), + } + } +} + +impl From for StoppingCriteriaParameters { + fn from(value: ValidStoppingParameters) -> Self { + Self { + max_new_tokens: value.max_new_tokens, + stop_sequences: value.stop_sequences, + ignore_eos_token: value.ignore_eos_token, + } + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use super::*; + use tracing::info_span; + + fn default_entry() -> ( + Entry, + mpsc::UnboundedReceiver>, + ) { + let (response_tx, receiver_tx) = mpsc::unbounded_channel(); + + let entry = Entry { + request: ValidGenerateRequest { + inputs: vec![], + input_ids: Some(Arc::new(vec![])), + input_length: 0, + add_special_tokens: true, + truncate: 0, + decoder_input_details: false, + parameters: ValidParameters { + temperature: 0.0, + top_k: 0, + top_p: 0.0, + typical_p: 0.0, + do_sample: false, + seed: 0, + repetition_penalty: 0.0, + frequency_penalty: 0.0, + watermark: false, + grammar: None, + }, + stopping_parameters: ValidStoppingParameters { + ignore_eos_token: false, + max_new_tokens: 1, + stop_sequences: vec![], + }, + top_n_tokens: 0, + adapter_id: None, + }, + response_tx, + span: info_span!("entry"), + temp_span: None, + queue_time: Instant::now(), + batch_time: None, + block_allocation: None, + }; + (entry, receiver_tx) + } + + #[tokio::test] + async fn test_append() { + let mut state = State::new(false, 1, false, None, 0, 16); + let (entry, _guard) = default_entry(); + + assert_eq!(state.next_id, 0); + assert_eq!(state.entries.len(), 0); + + state.append(entry); + + assert_eq!(state.next_id, 1); + 
assert_eq!(state.entries.len(), 1); + let (id, _) = state.entries.remove(0).unwrap(); + assert_eq!(id, 0); + } + + #[tokio::test] + async fn test_next_batch_empty() { + let mut state = State::new(false, 1, false, None, 0, 16); + + assert!(state.next_batch(None, None, 1, 1).await.is_none()); + assert!(state.next_batch(Some(1), None, 1, 1).await.is_none()); + } + + #[tokio::test] + async fn test_next_batch_min_size() { + let mut state = State::new(false, 1, false, None, 0, 16); + let (entry1, _guard1) = default_entry(); + let (entry2, _guard2) = default_entry(); + state.append(entry1); + state.append(entry2); + + let (entries, batch, _) = state.next_batch(None, None, 2, 2).await.unwrap(); + assert_eq!(entries.len(), 2); + assert!(entries.contains_key(&0)); + assert!(entries.contains_key(&1)); + assert!(entries.get(&0).unwrap().batch_time.is_some()); + assert!(entries.get(&1).unwrap().batch_time.is_some()); + assert_eq!(batch.id, 0); + assert_eq!(batch.size, 2); + + assert_eq!(state.next_id, 2); + assert_eq!(state.entries.len(), 0); + assert_eq!(state.next_batch_id, 1); + + let (entry3, _guard3) = default_entry(); + state.append(entry3); + + assert!(state.next_batch(Some(2), None, 2, 2).await.is_none()); + + assert_eq!(state.next_id, 3); + assert_eq!(state.entries.len(), 1); + let (id, _) = state.entries.remove(0).unwrap(); + assert_eq!(id, 2); + } + + #[tokio::test] + async fn test_next_batch_max_size() { + let mut state = State::new(false, 1, false, None, 0, 16); + let (entry1, _guard1) = default_entry(); + let (entry2, _guard2) = default_entry(); + state.append(entry1); + state.append(entry2); + + let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).await.unwrap(); + assert_eq!(entries.len(), 1); + assert!(entries.contains_key(&0)); + assert!(entries.get(&0).unwrap().batch_time.is_some()); + assert_eq!(batch.id, 0); + assert_eq!(batch.size, 1); + + assert_eq!(state.next_id, 2); + assert_eq!(state.entries.len(), 1); + assert_eq!(state.next_batch_id, 1); + 
} + + #[tokio::test] + async fn test_next_batch_token_budget() { + let mut state = State::new(false, 1, false, None, 0, 2); + let (entry1, _guard1) = default_entry(); + let (entry2, _guard2) = default_entry(); + state.append(entry1); + state.append(entry2); + + let (entries, batch, _) = state.next_batch(None, None, 1, 1).await.unwrap(); + assert_eq!(entries.len(), 1); + assert!(entries.contains_key(&0)); + assert_eq!(batch.id, 0); + assert_eq!(batch.size, 1); + + assert_eq!(state.next_id, 2); + assert_eq!(state.entries.len(), 1); + assert_eq!(state.next_batch_id, 1); + + let (entry3, _guard3) = default_entry(); + state.append(entry3); + + let (entries, batch, _) = state.next_batch(None, None, 3, 3).await.unwrap(); + assert_eq!(entries.len(), 2); + assert!(entries.contains_key(&1)); + assert!(entries.contains_key(&2)); + assert_eq!(batch.id, 1); + assert_eq!(batch.size, 2); + + assert_eq!(state.next_id, 3); + assert_eq!(state.entries.len(), 0); + assert_eq!(state.next_batch_id, 2); + } + + #[tokio::test] + async fn test_queue_append() { + let queue = Queue::new(false, 1, false, None, 0, 16); + let (entry, _guard) = default_entry(); + queue.append(entry); + } + + #[tokio::test] + async fn test_queue_next_batch_empty() { + let queue = Queue::new(false, 1, false, None, 0, 16); + + assert!(queue.next_batch(None, None, 1, 1).await.is_none()); + assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none()); + } + + #[tokio::test] + async fn test_queue_next_batch_min_size() { + let queue = Queue::new(false, 1, false, None, 0, 16); + let (entry1, _guard1) = default_entry(); + let (entry2, _guard2) = default_entry(); + queue.append(entry1); + queue.append(entry2); + + let (entries, batch, _) = queue.next_batch(None, None, 2, 2).await.unwrap(); + assert_eq!(entries.len(), 2); + assert!(entries.contains_key(&0)); + assert!(entries.contains_key(&1)); + assert!(entries.get(&0).unwrap().batch_time.is_some()); + assert!(entries.get(&1).unwrap().batch_time.is_some()); + 
assert_eq!(batch.id, 0); + assert_eq!(batch.size, 2); + + let (entry3, _guard3) = default_entry(); + queue.append(entry3); + + // Not enough requests pending + assert!(queue.next_batch(Some(2), None, 2, 2).await.is_none()); + // Not enough token budget + assert!(queue.next_batch(Some(1), None, 0, 0).await.is_none()); + // Ok + let (entries2, batch2, _) = queue.next_batch(Some(1), None, 2, 2).await.unwrap(); + assert_eq!(entries2.len(), 1); + assert!(entries2.contains_key(&2)); + assert!(entries2.get(&2).unwrap().batch_time.is_some()); + assert_eq!(batch2.id, 1); + assert_eq!(batch2.size, 1); + } + + #[tokio::test] + async fn test_queue_next_batch_max_size() { + let queue = Queue::new(false, 1, false, None, 0, 16); + let (entry1, _guard1) = default_entry(); + let (entry2, _guard2) = default_entry(); + queue.append(entry1); + queue.append(entry2); + + let (entries, batch, _) = queue.next_batch(None, Some(1), 2, 2).await.unwrap(); + assert_eq!(entries.len(), 1); + assert!(entries.contains_key(&0)); + assert!(entries.get(&0).unwrap().batch_time.is_some()); + assert_eq!(batch.id, 0); + assert_eq!(batch.size, 1); + } + + #[tokio::test] + async fn test_queue_next_batch_token_budget() { + let queue = Queue::new(false, 1, false, None, 0, 16); + let (entry1, _guard1) = default_entry(); + let (entry2, _guard2) = default_entry(); + queue.append(entry1); + queue.append(entry2); + + let (entries, batch, _) = queue.next_batch(None, None, 1, 1).await.unwrap(); + assert_eq!(entries.len(), 1); + assert!(entries.contains_key(&0)); + assert_eq!(batch.id, 0); + assert_eq!(batch.size, 1); + + let (entry3, _guard3) = default_entry(); + queue.append(entry3); + + let (entries, batch, _) = queue.next_batch(None, None, 3, 3).await.unwrap(); + assert_eq!(entries.len(), 2); + assert!(entries.contains_key(&1)); + assert!(entries.contains_key(&2)); + assert_eq!(batch.id, 1); + assert_eq!(batch.size, 2); + } + + #[tokio::test] + async fn test_queue_next_batch_token_speculate() { + let queue = 
Queue::new(false, 1, false, None, 2, 16); + let (entry1, _guard1) = default_entry(); + let (entry2, _guard2) = default_entry(); + queue.append(entry1); + queue.append(entry2); + + // Budget of 1 is not enough + assert!(queue.next_batch(None, None, 1, 1).await.is_none()); + + let (entries, batch, _) = queue.next_batch(None, None, 6, 6).await.unwrap(); + assert_eq!(entries.len(), 2); + assert!(entries.contains_key(&0)); + assert!(entries.contains_key(&1)); + assert_eq!(batch.id, 0); + assert_eq!(batch.size, 2); + } + + #[tokio::test] + async fn test_queue_next_batch_dropped_receiver() { + let queue = Queue::new(false, 1, false, None, 0, 16); + let (entry, _) = default_entry(); + queue.append(entry); + + assert!(queue.next_batch(None, None, 1, 1).await.is_none()); + } +} diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs new file mode 100644 index 00000000..8a544891 --- /dev/null +++ b/backends/v3/src/radix.rs @@ -0,0 +1,876 @@ +use crate::block_allocator::{Allocator, BlockAllocation}; +use slotmap::{DefaultKey, SlotMap}; +use std::hash::{Hash, Hasher}; +use std::{ + collections::{BTreeSet, HashMap}, + sync::Arc, +}; + +fn hash(slice: &[u32]) -> u64 { + assert!(!slice.is_empty()); + if slice.len() == 1 { + slice[0] as u64 + } else { + let mut s = std::hash::DefaultHasher::new(); + slice.hash(&mut s); + s.finish() + } +} + +pub struct RadixAllocator { + allocation_id: u64, + + allocations: HashMap, + + cache_blocks: RadixTrie, + + /// Blocks that are immediately available for allocation. + free_blocks: Vec, + + #[allow(dead_code)] + // This isn't used because the prefix needs to match without the windowing + // mechanism. This at worst is overallocating, not necessarily being wrong. 
+ window_size: Option, + + block_size: u32, +} + +impl RadixAllocator { + pub fn new(block_size: u32, n_blocks: u32, window_size: Option) -> Self { + RadixAllocator { + allocation_id: 0, + allocations: HashMap::new(), + cache_blocks: RadixTrie::new(block_size as usize), + + // Block 0 is reserved for health checks. + free_blocks: (1..n_blocks).collect(), + window_size, + block_size, + } + } + + fn alloc_or_reclaim(&mut self, n_blocks_needed: usize) -> Option> { + if self.free_blocks.len() < n_blocks_needed { + // This is a bit annoying, we first extend the free list and then + // split it off again below. This is because we need to put it on + // the free list if we cannot allocate enough blocks. This is only + // temporary, the trie needs to be able to report whether it can + // allocate the requested amount. Just not implemented yet. + tracing::debug!( + "Free blocks {} need {n_blocks_needed}", + self.free_blocks.len() + ); + self.free_blocks.extend( + self.cache_blocks + .evict(n_blocks_needed - self.free_blocks.len()), + ); + } + + if self.free_blocks.len() >= n_blocks_needed { + Some( + self.free_blocks + .split_off(self.free_blocks.len() - n_blocks_needed), + ) + } else { + None + } + } +} + +// Allocator trait +impl Allocator for RadixAllocator { + fn allocate( + &mut self, + tokens: u32, + prefill_tokens: Option>>, + ) -> Option { + let mut blocks = vec![]; + let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() { + let node_id = self + .cache_blocks + .find(prefill_tokens.as_slice(), &mut blocks); + node_id + } else { + self.cache_blocks.root_id() + }; + + // Even if this allocation fails below, we need to increase the + // refcount to ensure that the prefix that was found is not evicted. 
+ self.cache_blocks + .incref(prefix_node) + .expect("Failed to increment refcount"); + + let prefix_len = blocks.len() * self.block_size as usize; + let suffix_len = tokens - prefix_len as u32; + + let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size; + + tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}"); + + match self.alloc_or_reclaim(suffix_blocks as usize) { + Some(suffix_blocks) => blocks.extend(suffix_blocks), + None => { + tracing::debug!("Cannot allocate {:?}", self.cache_blocks); + tracing::debug!("Found {prefix_len} prefix tokens need {suffix_blocks} suffix blocks for {tokens} tokens"); + tracing::debug!("Block size {}", self.block_size); + self.cache_blocks + .decref(prefix_node) + .expect("Failed to decrement refcount"); + return None; + } + } + + // 1:1 mapping of blocks and slots. + let slots = if self.block_size == 1 { + blocks.clone() + } else { + let mut slots = Vec::with_capacity(blocks.len() * self.block_size as usize); + 'slots: for block_id in &blocks { + for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) { + slots.push(s); + if slots.len() as u32 == tokens { + break 'slots; + } + } + } + slots + }; + + let allocation = RadixAllocation { + prefix_node, + cached_prefix_len: prefix_len, + prefill_tokens: prefill_tokens.clone(), + }; + + self.allocation_id += 1; + self.allocations.insert(self.allocation_id, allocation); + + Some(BlockAllocation { + allocation_id: self.allocation_id, + block_allocator: None, + blocks, + slots, + prefix_len: prefix_len as u32, + }) + } + + fn free(&mut self, blocks: Vec, allocation_id: u64) { + let allocation = match self.allocations.remove(&allocation_id) { + Some(allocation) => allocation, + None => unreachable!("Tried to free an unknown allocation."), + }; + + self.cache_blocks + .decref(allocation.prefix_node) + .expect("Failed to decrement refcount"); + + if let Some(prefill_tokens) = allocation.prefill_tokens { + let prefill_tokens = 
prefill_tokens.as_slice(); + + // If there are prefill tokens that did not come from the cache, + // add them to the cache. + if prefill_tokens.len() > allocation.cached_prefix_len { + let aligned = + (prefill_tokens.len() / self.block_size as usize) * self.block_size as usize; + if aligned > 0 { + let prefix_len = self + .cache_blocks + .insert( + &prefill_tokens[..aligned], + &blocks[..aligned / self.block_size as usize], + ) + // Unwrap, failing is a programming error. + .expect("Failed to store prefill tokens"); + // We can have a prefill with the following structure: + // + // |---| From the prefix cache. + // A B C D E F G + //|--------| Found in the trie during insertion. + // + // This means that while processing this request there was a + // partially overlapping request that had A..=E in its + // prefill. In this case we need to free the blocks D E. + if prefix_len > allocation.cached_prefix_len { + self.free_blocks.extend( + &blocks[allocation.cached_prefix_len / self.block_size as usize + ..prefix_len / self.block_size as usize], + ); + } + } + } + + // Free non-prefill blocks. + self.free_blocks + .extend(&blocks[prefill_tokens.len() / self.block_size as usize..]); + } else { + self.free_blocks.extend(blocks); + } + } +} + +struct RadixAllocation { + prefix_node: NodeId, + cached_prefix_len: usize, + prefill_tokens: Option>>, +} + +// Radix trie that is heavily inspired by radix attention from sglang. +// +// The trie is optimized for prefix caching: +// +// - A normal radix trie stores discrete values. In this radix trie, +// inserting *abc* with value *xyz* will also enable lookup for +// *a* (*x*) and *ab* (*xy*). +// - As a result, every value is required to have the same length as +// the key. +// - We store additional information in each node, such as last access +// time and a reference count. 
+ +#[derive(Debug)] +pub enum TrieError { + InvalidNodeId, + RefCountUnderflow, +} + +pub type NodeId = DefaultKey; + +#[derive(Debug)] +pub struct RadixTrie { + /// Identifier of the root node. + root: DefaultKey, + + /// Leaf node identifiers ordered by increasing recency. + leaves: BTreeSet<(u64, NodeId)>, + + /// All trie nodes. + nodes: SlotMap, + + /// Time as a monotonically increasing counter to avoid the system + /// call that a real time lookup would require. + time: u64, + + /// All blocks need to be aligned with this + block_size: usize, +} + +impl RadixTrie { + /// Construct a new radix trie. + pub fn new(block_size: usize) -> Self { + let root = TrieNode::new(vec![], vec![], 0, None); + let mut nodes = SlotMap::new(); + let root = nodes.insert(root); + RadixTrie { + leaves: BTreeSet::new(), + nodes, + root, + time: 0, + block_size, + } + } + + /// Find the prefix of the given tokens. + /// + /// The blocks corresponding to the part of the prefix that could be found + /// are written to `blocks`. The number of blocks is in `0..=tokens.len()`. + /// Returns the identifier of the trie node that contains the longest + /// prefix. The node identifier can be used by callers to e.g. increase its + /// reference count. + /// + /// Using this method will update the access time of the traversed nodes. + pub fn find(&mut self, key: &[u32], blocks: &mut Vec) -> NodeId { + self.time += 1; + self.find_(self.root, key, blocks) + } + + /// Find worker. 
+ fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec) -> NodeId { + let node = &self.nodes[node_id]; + + if key.len() >= self.block_size { + let node_key = hash(&key[..self.block_size]); + if let Some(&child_id) = node.children.get(&node_key) { + self.update_access_time(child_id); + let child = self.nodes.get(child_id).expect("Invalid child identifier"); + let shared_prefix_len = shared_prefix(&child.key, key, self.block_size); + assert_eq!(shared_prefix_len % self.block_size, 0); + blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]); + + let key = &key[shared_prefix_len..]; + if !key.is_empty() { + node_id = self.find_(child_id, key, blocks); + } + } + } + + node_id + } + + /// Decrease the reference count of a node. + pub fn decref(&mut self, node_id: NodeId) -> Result<(), TrieError> { + // We don't care about refcounting for root, since it will never + // be evicted. + if node_id == self.root { + return Ok(()); + } + + let node = self + .nodes + .get_mut(node_id) + .ok_or(TrieError::InvalidNodeId)?; + if node.ref_count == 0 { + return Err(TrieError::RefCountUnderflow); + } + + node.ref_count -= 1; + if node.ref_count == 0 { + assert!( + node.children.is_empty(), + "Nodes with children must have refcount > 0" + ); + + self.leaves.insert((node.last_accessed, node_id)); + } + + Ok(()) + } + + /// Increase the reference count of a node. + pub fn incref(&mut self, node_id: NodeId) -> Result<(), TrieError> { + if node_id == self.root { + return Ok(()); + } + + let node = self + .nodes + .get_mut(node_id) + .ok_or(TrieError::InvalidNodeId)?; + if node.ref_count == 0 { + self.leaves.remove(&(node.last_accessed, node_id)); + } + node.ref_count += 1; + + Ok(()) + } + + /// Evict `n_blocks` from the trie. + /// + /// Returns the evicted blocks. When the length is less than `n_blocks`, + /// not enough blocks could be evicted. + pub fn evict(&mut self, n_blocks: usize) -> Vec { + // NOTE: we don't return Result here. 
If any of the unwrapping fails, + // it's a programming error in the trie implementation, not a user + // error caused by e.g. an invalid argument. + + // TODO: add some bookkeeping in the future to check whether we can + // evict n_blocks and return `None` if we can't. We are now needlessly + // evicting prefixes from the cache in such a case. + let mut evicted = Vec::new(); + tracing::debug!("Evicting in search of {n_blocks}"); + + while let Some((last_access, node_id)) = self.leaves.pop_first() { + let blocks_needed = n_blocks.saturating_sub(evicted.len()); + tracing::debug!("Evicting node {node_id:?} "); + + let node = self.nodes.get(node_id).expect("Leave does not exist"); + assert_eq!( + node.ref_count, 0, + "Leaf must have refcount of 0, got {}", + node.ref_count + ); + + if blocks_needed >= node.blocks.len() { + // We need to evict the whole node if we need more blocks than it has. + let node = self.remove_node(node_id); + evicted.extend(node.blocks); + + if evicted.len() >= n_blocks { + break; + } + } else { + // The node has more blocks than needed, so we'll just remove + // the required number of blocks and leave the remaining blocks + // untouched. + let node = self.nodes.get_mut(node_id).expect("Leave does not exist"); + + let truncate_blocks = node.blocks.len() - blocks_needed; + let truncate_tokens = truncate_blocks * self.block_size; + node.key.truncate(truncate_tokens); + evicted.extend(node.blocks.split_off(truncate_blocks)); + self.leaves.insert((last_access, node_id)); + break; + } + } + + evicted + } + + /// Insert a prefill along with its blocks. + /// + /// This method returns the length of the prefix that was already + /// in the trie. E.g. if the length is 10, this means that for + /// the first 10 elements of the tree **the blocks are not updated**. + pub fn insert(&mut self, tokens: &[u32], blocks: &[u32]) -> Result { + self.time += 1; + let common = self.insert_(self.root, tokens, blocks)?; + Ok(common) + } + + /// Insertion worker. 
+ fn insert_( + &mut self, + node_id: NodeId, + tokens: &[u32], + blocks: &[u32], + ) -> Result { + // TODO: in the future we may want to check that the blocks match for + // the part of the prefix that is already in the trie to detect + // mismatches. + + assert_eq!(tokens.len(), blocks.len() * self.block_size); + + let node_key = hash(&tokens[..self.block_size]); + if let Some(&child_id) = self.nodes[node_id].children.get(&node_key) { + self.update_access_time(child_id); + let child = self + .nodes + .get_mut(child_id) + // Unwrap here, since failure is a bug. + .expect("Child node does not exist"); + let shared_prefix_len = shared_prefix(&child.key, tokens, self.block_size); + + // We are done, the prefix is already in the trie. + if shared_prefix_len == tokens.len() || shared_prefix_len == 0 { + return Ok(shared_prefix_len); + } + + // The node's prefix is a prefix of the insertion prefix. + if shared_prefix_len == child.key.len() { + return Ok(shared_prefix_len + + self.insert_( + child_id, + &tokens[shared_prefix_len..], + &blocks[shared_prefix_len / self.block_size..], + )?); + } + + // The node's prefix and the insertion prefix only match partially, + // split the node to just contain the matching part. Then insert the + // remainder of the prefix into the node again + let child_id = self.split_node(child_id, shared_prefix_len); + let key = &tokens[shared_prefix_len..]; + let blocks = &blocks[shared_prefix_len / self.block_size..]; + Ok(shared_prefix_len + self.insert_(child_id, key, blocks)?) + } else { + self.add_node(node_id, tokens, blocks); + Ok(0) + } + } + + fn split_node(&mut self, node_id: NodeId, prefix_len: usize) -> NodeId { + // We have to make the current node a child to ensure that its + // properties and node id stay the same. + + // This function unwraps; an invalid node_id is a programming error. 
+ + let node = self + .nodes + .get_mut(node_id) + .expect("Node to-be split does not exist"); + let mut parent_key = node.key.split_off(prefix_len); + let prefix_blocks = prefix_len / self.block_size; + let mut parent_blocks = node.blocks.split_off(prefix_blocks); + + // Move first part of the prefix to the parent. We swap to avoid + // an allocation + copy for both splits of the key/blocks. + std::mem::swap(&mut node.key, &mut parent_key); + std::mem::swap(&mut node.blocks, &mut parent_blocks); + + let node_key = hash(&node.key[..self.block_size]); + + let grandparent_id = node.parent.expect("Node does not have a parent"); + let parent_id = self.add_node(grandparent_id, parent_key, parent_blocks); + self.add_node_to_parent(parent_id, node_key, node_id); + + // Reborrow to make the borrow checker happy. + let node = self + .nodes + .get_mut(node_id) + .expect("Node to-be split does not exist"); + node.parent = Some(parent_id); + + parent_id + } + + /// Create a node and add it to the parent. + fn add_node( + &mut self, + parent_id: NodeId, + key: impl Into>, + blocks: impl Into>, + ) -> NodeId { + let key = key.into(); + let blocks = blocks.into(); + let first = hash(&key[..self.block_size]); + + let child = TrieNode::new(key, blocks, self.time, Some(parent_id)); + let child_id = self.nodes.insert(child); + + self.add_node_to_parent(parent_id, first, child_id); + self.leaves.insert((self.time, child_id)); + + child_id + } + + /// Add a node to the parent. + fn add_node_to_parent(&mut self, parent_id: NodeId, hash: u64, child_id: NodeId) { + // Unwrap here, passing in an unknown id is a programming error. + let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node"); + if parent.children.insert(hash, child_id).is_none() { + // Only increase reference count if child does not replace another child. + self.incref(parent_id) + .expect("Failed to increase parent refcount"); + } + } + + /// Remove a node from the trie. 
+ fn remove_node(&mut self, node_id: NodeId) -> TrieNode { + // Unwrap here, passing in an unknown id is a programming error. + let node = self.nodes.remove(node_id).expect("Unknown node"); + assert!( + node.children.is_empty(), + "Tried to remove a node with {} children", + node.children.len() + ); + let parent_id = node.parent.expect("Attempted to remove root node"); + let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node"); + + let node_key = hash(&node.key[..self.block_size]); + parent.children.remove(&node_key); + self.decref(parent_id) + .expect("Failed to decrease parent refcount"); + node + } + + fn update_access_time(&mut self, node_id: NodeId) { + // Unwrap here, passing in an unknown id is a programming error. + let node = self.nodes.get_mut(node_id).expect("Unknown node"); + + // Update the ordered leaves set if the node is a leave. + if self.leaves.remove(&(node.last_accessed, node_id)) { + self.leaves.insert((self.time, node_id)); + } + + node.last_accessed = self.time; + } + + #[allow(dead_code)] + #[doc(hidden)] + /// Print debugging output for the trie. + /// + /// In contrast to `Debug` nicely formatted. + pub fn print_debug(&self) { + self.print_debug_(self.root, 0); + } + + fn print_debug_(&self, node_id: NodeId, indent: usize) { + let node = &self.nodes[node_id]; + eprintln!( + "{}{:?}, key: {:?}, blocks: {:?}, ref_count: {}, last_accessed: {}, parent: {:?}, children: {:?}", + " ".repeat(indent), + node_id, + node.key, + node.blocks, + node.ref_count, + node.last_accessed, + node.parent, + node.children + ); + for child_id in self.nodes[node_id].children.values() { + self.print_debug_(*child_id, indent + 2); + } + } + + pub(crate) fn root_id(&self) -> DefaultKey { + self.root + } +} + +/// Trie node. 
+#[derive(Debug)] +struct TrieNode { + blocks: Vec, + children: HashMap, + key: Vec, + last_accessed: u64, + parent: Option, + ref_count: usize, +} + +impl TrieNode { + fn new(key: Vec, blocks: Vec, last_accessed: u64, parent: Option) -> Self { + TrieNode { + children: HashMap::new(), + key, + blocks, + last_accessed, + parent, + ref_count: 0, + } + } +} + +fn shared_prefix(left: &[u32], right: &[u32], block_size: usize) -> usize { + let full = left.iter().zip(right).take_while(|(a, b)| a == b).count(); + // NOTE: this is the case because the child node was chosen based on + // matching the first character of the key/prefix. + assert!(full > 0, "Prefixes must at least share 1 token"); + (full / block_size) * block_size +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use super::*; + + #[test] + fn allocator_block_size() { + let mut cache = RadixAllocator::new(2, 12, None); + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); + assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]); + assert_eq!(allocation.prefix_len, 0); + cache.free(allocation.blocks.clone(), allocation.allocation_id); + + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); + assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]); + assert_eq!(allocation.prefix_len, 4); + } + + #[test] + fn allocator_block_size_non_aligned() { + let mut cache = RadixAllocator::new(2, 12, None); + let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap(); + assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); + assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]); + assert_eq!(allocation.prefix_len, 0); + cache.free(allocation.blocks.clone(), allocation.allocation_id); + + let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap(); + assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); 
+ assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]); + assert_eq!(allocation.prefix_len, 2); + } + + #[test] + fn allocator_reuses_prefixes() { + let mut cache = RadixAllocator::new(1, 12, None); + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]); + assert_eq!(allocation.blocks, allocation.slots); + assert_eq!(allocation.prefix_len, 0); + cache.free(allocation.blocks.clone(), allocation.allocation_id); + + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]); + assert_eq!(allocation.prefix_len, 4); + } + + #[test] + fn allocator_collects_older_prefixes_first() { + let mut cache = RadixAllocator::new(1, 7, None); + let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation1.blocks, vec![3, 4, 5, 6]); + assert_eq!(allocation1.prefix_len, 0); + + let allocation2 = cache.allocate(2, Some(Arc::new(vec![4, 5]))).unwrap(); + assert_eq!(allocation2.blocks, vec![1, 2]); + assert_eq!(allocation2.prefix_len, 0); + + cache.free(allocation1.blocks.clone(), allocation1.allocation_id); + cache.free(allocation2.blocks.clone(), allocation2.allocation_id); + + // We should get the blocks of the first allocation, since they were least recently used. 
+ let allocation3 = cache.allocate(4, Some(Arc::new(vec![6, 7, 8, 9]))).unwrap(); + assert_eq!(allocation3.blocks, vec![3, 4, 5, 6]); + assert_eq!(allocation3.prefix_len, 0); + } + + #[test] + fn allocator_frees_fully_overlapping_prefills() { + let mut cache = RadixAllocator::new(1, 10, None); + let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + + cache.free(allocation2.blocks.clone(), allocation2.allocation_id); + cache.free(allocation1.blocks.clone(), allocation1.allocation_id); + + let allocation3 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation3.prefix_len, 4); + + // 10 blocks, of which 1 reserved for health checks, 4 for the cached blocks. + assert_eq!(cache.free_blocks.len(), 5); + } + + #[test] + fn allocator_frees_partially_overlapping_prefills() { + let mut cache = RadixAllocator::new(1, 20, None); + let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1]))).unwrap(); + assert_eq!(allocation1.blocks, vec![16, 17, 18, 19]); + assert_eq!(allocation1.prefix_len, 0); + + cache.free(allocation1.blocks.clone(), allocation1.allocation_id); + + let allocation2 = cache + .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 4, 5]))) + .unwrap(); + assert_eq!(allocation2.blocks, vec![16, 17, 12, 13, 14, 15, 18, 19]); + assert_eq!(allocation2.prefix_len, 2); + + let allocation3 = cache + .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 6, 7]))) + .unwrap(); + assert_eq!(allocation3.blocks, vec![16, 17, 6, 7, 8, 9, 10, 11]); + assert_eq!(allocation3.prefix_len, 2); + + cache.free(allocation3.blocks.clone(), allocation3.allocation_id); + cache.free(allocation2.blocks.clone(), allocation2.allocation_id); + + // 20 blocks, of which 1 reserved for health checks, 6 for allocation3, 2 for allocation2. 
+ assert_eq!(cache.free_blocks.len(), 11); + + let allocation4 = cache + .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 4, 5]))) + .unwrap(); + assert_eq!(allocation4.blocks, vec![16, 17, 6, 7, 14, 15]); + assert_eq!(allocation4.prefix_len, 6); + assert_eq!(cache.free_blocks.len(), 11); + + let allocation5 = cache + .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 6, 7]))) + .unwrap(); + assert_eq!(allocation5.blocks, vec![16, 17, 6, 7, 8, 9]); + assert_eq!(allocation5.prefix_len, 6); + assert_eq!(cache.free_blocks.len(), 11); + } + + #[test] + fn trie_insertions_have_correct_prefix_len() { + let mut trie = RadixTrie::new(1); + + assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 0); + + // Already exists. + assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 3); + + // Completely new at root-level + assert_eq!(trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(), 0); + + // Contains full prefix, but longer. + assert_eq!(trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(), 3); + + // Shares partial prefix, we need a split. + assert_eq!( + trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]) + .unwrap(), + 4 + ); + } + + #[test] + fn trie_insertions_block_size() { + let mut trie = RadixTrie::new(2); + + assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 0); + + // Already exists. + // But needs to be block_size aligned + assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 4); + + // Completely new at root-level + assert_eq!(trie.insert(&[1, 2, 3, 4], &[1, 2]).unwrap(), 0); + + // Contains full prefix, but longer. + assert_eq!(trie.insert(&[0, 1, 2, 3, 4, 5], &[0, 1, 2]).unwrap(), 4); + + // Shares partial prefix, we need a split. 
+ assert_eq!( + trie.insert(&[0, 1, 3, 4, 5, 6, 7, 8], &[0, 1, 2, 3]) + .unwrap(), + 2 + ); + } + + #[test] + fn trie_get_returns_correct_blocks() { + let mut trie = RadixTrie::new(1); + trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(); + trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(); + trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(); + trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]) + .unwrap(); + + let mut blocks = Vec::new(); + trie.find(&[0], &mut blocks); + assert_eq!(blocks, vec![0]); + + blocks.clear(); + trie.find(&[0, 1, 2], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2]); + + blocks.clear(); + trie.find(&[1, 2, 3], &mut blocks); + assert_eq!(blocks, vec![1, 2, 3]); + + blocks.clear(); + trie.find(&[0, 1, 2, 3], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3]); + + blocks.clear(); + trie.find(&[0, 1, 2, 3, 4], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3, 4]); + + blocks.clear(); + trie.find(&[0, 1, 2, 3, 5], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3, 5]); + } + + #[test] + fn trie_evict_removes_correct_blocks() { + let mut trie = RadixTrie::new(1); + trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(); + trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]) + .unwrap(); + trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(); + trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(); + + let mut blocks = Vec::new(); + + // Remove less than the leave blocks. + assert_eq!(trie.evict(1), vec![7]); + trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3, 5, 6]); + + // Refresh other leaf. + trie.find(&[0, 1, 2, 3, 4], &mut blocks); + trie.find(&[1, 2, 3], &mut blocks); + + // Remove the leave blocks exactly. + assert_eq!(trie.evict(2), vec![5, 6]); + blocks.clear(); + trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3]); + + trie.find(&[1, 2, 3], &mut blocks); + + // Remove more than the leave blocks. 
+ assert_eq!(trie.evict(3), vec![4, 3, 2]); + blocks.clear(); + trie.find(&[0, 1, 2, 3, 4], &mut blocks); + assert_eq!(blocks, vec![0, 1]); + + // Clear out the whole trie. + assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]); + } +} diff --git a/benchmark/Cargo.toml b/benchmark/Cargo.toml index 756460e0..23885da2 100644 --- a/benchmark/Cargo.toml +++ b/benchmark/Cargo.toml @@ -16,16 +16,15 @@ path = "src/main.rs" [dependencies] average = "0.14" clap = { version = "4.4.5", features = ["derive", "env"] } -crossterm = "0.27" float-ord = "0.3.2" serde = {version = "1.0.188", features = ["derive"]} serde_json = "1.0" tabled = "0.14.0" -text-generation-client = { path = "../router/client" } +text-generation-client = { path = "../backends/client" } thiserror = "1.0.48" tokenizers = { workspace = true } tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] } -tui = {package = "ratatui", version = "0.23", default-features = false, features = ["crossterm"]} +ratatui = "0.28.1" tracing = "0.1.37" tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] } hf-hub = { workspace = true } diff --git a/benchmark/README.md b/benchmark/README.md index 17a02a30..f4e0cb16 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -7,7 +7,7 @@ A lightweight benchmarking tool based inspired by [oha](https://github.com/hatoo/oha) -and powered by [tui](https://github.com/tui-rs-revival/ratatui). +and powered by [Ratatui](https://github.com/ratatui/ratatui). 
## Install diff --git a/benchmark/src/app.rs b/benchmark/src/app.rs index 48ac976a..7e3aeaf9 100644 --- a/benchmark/src/app.rs +++ b/benchmark/src/app.rs @@ -1,16 +1,15 @@ /// Inspired by https://github.com/hatoo/oha/blob/bb989ea3cd77727e7743e7daa60a19894bb5e901/src/monitor.rs use crate::generation::{Decode, Message, Prefill}; -use crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; -use text_generation_client::ClientError; -use tokio::sync::mpsc; -use tui::backend::Backend; -use tui::layout::{Alignment, Constraint, Direction, Layout}; -use tui::style::{Color, Modifier, Style}; -use tui::text::{Line, Span}; -use tui::widgets::{ +use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; +use ratatui::layout::{Alignment, Constraint, Direction, Layout}; +use ratatui::style::{Color, Modifier, Style}; +use ratatui::text::{Line, Span}; +use ratatui::widgets::{ Axis, BarChart, Block, Borders, Chart, Dataset, Gauge, GraphType, Paragraph, Tabs, }; -use tui::{symbols, Frame}; +use ratatui::{symbols, Frame}; +use text_generation_client::ClientError; +use tokio::sync::mpsc; /// TUI powered App pub(crate) struct App { @@ -153,7 +152,7 @@ impl App { } /// Render frame - pub fn render(&mut self, f: &mut Frame<'_, B>) { + pub fn render(&mut self, f: &mut Frame) { let batch_progress = (self.completed_batch as f64 / self.data.batch_size.len() as f64).clamp(0.0, 1.0); let run_progress = @@ -172,7 +171,7 @@ impl App { ] .as_ref(), ) - .split(f.size()); + .split(f.area()); // Top row horizontal layout let top = Layout::default() @@ -239,7 +238,7 @@ impl App { f.render_widget(helper, row5[0]); // Batch tabs - let titles = self + let titles: Vec = self .data .batch_size .iter() @@ -497,7 +496,7 @@ fn statis_spans<'a>(data: &[f64], unit: &'static str) -> Vec> { "Lowest: {:.2} {unit}", data.iter() .min_by(|a, b| a.total_cmp(b)) - .unwrap_or(&std::f64::NAN) + .unwrap_or(&f64::NAN) ), Style::default().fg(Color::Reset), )]), @@ -506,7 +505,7 @@ fn statis_spans<'a>(data: &[f64], unit: 
&'static str) -> Vec> { "Highest: {:.2} {unit}", data.iter() .max_by(|a, b| a.total_cmp(b)) - .unwrap_or(&std::f64::NAN) + .unwrap_or(&f64::NAN) ), Style::default().fg(Color::Reset), )]), @@ -555,17 +554,17 @@ fn latency_throughput_chart<'a>( let min_latency: f64 = *latency_iter .clone() .min_by(|a, b| a.total_cmp(b)) - .unwrap_or(&std::f64::NAN); + .unwrap_or(&f64::NAN); let max_latency: f64 = *latency_iter .max_by(|a, b| a.total_cmp(b)) - .unwrap_or(&std::f64::NAN); + .unwrap_or(&f64::NAN); let min_throughput: f64 = *throughput_iter .clone() .min_by(|a, b| a.total_cmp(b)) - .unwrap_or(&std::f64::NAN); + .unwrap_or(&f64::NAN); let max_throughput: f64 = *throughput_iter .max_by(|a, b| a.total_cmp(b)) - .unwrap_or(&std::f64::NAN); + .unwrap_or(&f64::NAN); // Char min max values let min_x = if zoom { diff --git a/benchmark/src/event.rs b/benchmark/src/event.rs index 07482aed..d3f10fb6 100644 --- a/benchmark/src/event.rs +++ b/benchmark/src/event.rs @@ -1,5 +1,5 @@ /// Inspired by https://github.com/orhun/rust-tui-template/blob/472aa515119d4c94903eac12d9784417281dc7f5/src/event.rs -use crossterm::event; +use ratatui::crossterm::event; use std::time::{Duration, Instant}; use tokio::sync::{broadcast, mpsc}; diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index b2766d0c..789c7b51 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -1,8 +1,9 @@ use std::time::{Duration, Instant}; -use text_generation_client::{ - Batch, CachedBatch, ClientError, NextTokenChooserParameters, Request, ShardedClient, +use text_generation_client::v3::{ + Batch, CachedBatch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters, }; +use text_generation_client::{Chunk, ClientError, Input}; use tokenizers::{Tokenizer, TruncationDirection}; use tokio::sync::{broadcast, mpsc}; @@ -142,8 +143,12 @@ async fn prefill( .map(|id| Request { id: id.into(), prefill_logprobs: false, + input_chunks: Some(Input { + chunks: 
vec![Chunk::Text(sequence.clone()).into()], + }), inputs: sequence.clone(), truncate: sequence_length, + add_special_tokens: true, parameters: Some(parameters.clone()), stopping_parameters: Some(StoppingCriteriaParameters { max_new_tokens: decode_length, @@ -151,6 +156,10 @@ async fn prefill( ignore_eos_token: true, // Will not stop even if a eos token is generated }), top_n_tokens: top_n_tokens.unwrap_or(0), + blocks: vec![], + slots: vec![], + prefix_len: 0, + adapter_id: None, }) .collect(); @@ -159,15 +168,13 @@ async fn prefill( requests, size: batch_size, max_tokens: batch_size * (sequence_length + decode_length), + max_blocks: 0, }; // Run prefill let start_time = Instant::now(); - let (_, decode_batch, _) = client.prefill(batch.clone()).await?; - let (_, decode_batch, _) = client.decode(vec![decode_batch.clone().unwrap()]).await?; - // Get latency let latency = start_time.elapsed(); @@ -183,12 +190,11 @@ async fn prefill( }; Ok((step, decode_batch)) - } /// Run a full decode async fn decode(batch: CachedBatch, client: &mut ShardedClient) -> Result { - let mut decode_length = 1; // 1 decode step was already scheduled in prefill with speculative scheduling + let mut decode_length = 0; let batch_size = batch.size; let start_time = Instant::now(); diff --git a/benchmark/src/lib.rs b/benchmark/src/lib.rs index 638c6514..bb4b6a77 100644 --- a/benchmark/src/lib.rs +++ b/benchmark/src/lib.rs @@ -6,13 +6,13 @@ mod utils; use crate::app::App; use crate::event::Event; -use crossterm::ExecutableCommand; +use ratatui::backend::CrosstermBackend; +use ratatui::crossterm::ExecutableCommand; +use ratatui::Terminal; use std::io; -use text_generation_client::{GrammarType, NextTokenChooserParameters, ShardedClient}; +use text_generation_client::v3::{GrammarType, NextTokenChooserParameters, ShardedClient}; use tokenizers::Tokenizer; use tokio::sync::{broadcast, mpsc}; -use tui::backend::CrosstermBackend; -use tui::Terminal; /// Run benchmarking app 
#[allow(clippy::too_many_arguments)] @@ -50,9 +50,9 @@ pub async fn run( }; // Initialize terminal properties - crossterm::terminal::enable_raw_mode()?; - io::stdout().execute(crossterm::terminal::EnterAlternateScreen)?; - io::stdout().execute(crossterm::cursor::Hide)?; + ratatui::crossterm::terminal::enable_raw_mode()?; + io::stdout().execute(ratatui::crossterm::terminal::EnterAlternateScreen)?; + io::stdout().execute(ratatui::crossterm::cursor::Hide)?; // Initialize terminal let mut terminal = { @@ -128,9 +128,9 @@ pub async fn run( let _ = shutdown_guard_receiver.recv().await; // Revert terminal to original view - io::stdout().execute(crossterm::terminal::LeaveAlternateScreen)?; - crossterm::terminal::disable_raw_mode()?; - io::stdout().execute(crossterm::cursor::Show)?; + io::stdout().execute(ratatui::crossterm::terminal::LeaveAlternateScreen)?; + ratatui::crossterm::terminal::disable_raw_mode()?; + io::stdout().execute(ratatui::crossterm::cursor::Show)?; let parameters_table = table::parameters_table( tokenizer_name, diff --git a/benchmark/src/main.rs b/benchmark/src/main.rs index 935808b6..2ee3d7c5 100644 --- a/benchmark/src/main.rs +++ b/benchmark/src/main.rs @@ -4,7 +4,7 @@ /// and: https://github.com/orhun/rust-tui-template use clap::Parser; use std::path::Path; -use text_generation_client::ShardedClient; +use text_generation_client::v3::ShardedClient; use tokenizers::{FromPretrainedParameters, Tokenizer}; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; @@ -51,7 +51,7 @@ struct Args { runs: usize, /// Number of warmup cycles - #[clap(default_value = "3", short, long, env)] + #[clap(default_value = "1", short, long, env)] warmups: usize, /// The location of the grpc socket. 
This benchmark tool bypasses the router @@ -155,7 +155,7 @@ fn main() -> Result<(), Box> { // We need to download it outside of the Tokio runtime let params = FromPretrainedParameters { revision, - token: auth_token, + auth_token, ..Default::default() }; Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).unwrap() diff --git a/benchmark/src/table.rs b/benchmark/src/table.rs index e18d7310..1585a25f 100644 --- a/benchmark/src/table.rs +++ b/benchmark/src/table.rs @@ -156,17 +156,17 @@ fn avg_min_max(data: &[f64]) -> (f64, f64, f64) { let min = data .iter() .min_by(|a, b| a.total_cmp(b)) - .unwrap_or(&std::f64::NAN); + .unwrap_or(&f64::NAN); let max = data .iter() .max_by(|a, b| a.total_cmp(b)) - .unwrap_or(&std::f64::NAN); + .unwrap_or(&f64::NAN); (average, *min, *max) } fn px(data: &[f64], p: u32) -> f64 { let i = (f64::from(p) / 100.0 * data.len() as f64) as usize; - *data.get(i).unwrap_or(&std::f64::NAN) + *data.get(i).unwrap_or(&f64::NAN) } fn format_value(value: f64, unit: &'static str) -> String { diff --git a/benchmark/src/utils.rs b/benchmark/src/utils.rs index d096d655..20469991 100644 --- a/benchmark/src/utils.rs +++ b/benchmark/src/utils.rs @@ -37,7 +37,7 @@ pub(crate) fn percentiles(values: &[f64], pecents: &[i32]) -> BTreeMap List[DeployedModel]: List[DeployedModel]: list of all currently deployed models """ resp = requests.get( - f"https://api-inference.huggingface.co/framework/text-generation-inference", + "https://api-inference.huggingface.co/framework/text-generation-inference", headers=headers, timeout=5, ) diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index eb872ee6..1085075e 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -1,5 +1,5 @@ from enum import Enum -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, field_validator, ConfigDict from typing import Optional, List, Union, Any from 
text_generation.errors import ValidationError @@ -28,11 +28,17 @@ class ToolCall(BaseModel): function: dict +class Chunk(BaseModel): + type: str + text: Optional[str] = None + image_url: Any = None + + class Message(BaseModel): # Role of the message sender role: str # Content of the message - content: Optional[str] = None + content: Optional[Union[str, List[Chunk]]] = None # Optional name of the message sender name: Optional[str] = None # Tool calls associated with the chat completion @@ -61,7 +67,7 @@ class ChoiceDeltaToolCall(BaseModel): class ChoiceDelta(BaseModel): role: str content: Optional[str] = None - tool_calls: Optional[ChoiceDeltaToolCall] + tool_calls: Optional[ChoiceDeltaToolCall] = None class Choice(BaseModel): @@ -168,7 +174,7 @@ class ChatCompletionComplete(BaseModel): # Log probabilities for the chat completion logprobs: Optional[Any] # Reason for completion - finish_reason: str + finish_reason: Optional[str] # Usage details of the chat completion usage: Optional[Any] = None @@ -191,6 +197,7 @@ class ChatCompletionChunk(BaseModel): model: str system_fingerprint: str choices: List[Choice] + usage: Optional[Any] = None class Parameters(BaseModel): @@ -452,5 +459,9 @@ class StreamResponse(BaseModel): # Inference API currently deployed model class DeployedModel(BaseModel): + # Disable warning for use of `model_` prefix in `model_id`. 
Be mindful about adding members + # with model_ prefixes, since this disables guardrails for colliding fields: + # https://github.com/pydantic/pydantic/issues/9177 + model_config = ConfigDict(protected_namespaces=()) model_id: str sha: str diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 00000000..fb2ff198 --- /dev/null +++ b/docs/README.md @@ -0,0 +1,10 @@ +Documentation available at: https://huggingface.co/docs/text-generation-inference + +## Release + +When making a release, please update the latest version in the documentation with: +``` +export OLD_VERSION="2\.0\.3" +export NEW_VERSION="2\.0\.4" +find . -name '*.md' -exec sed -i -e "s/$OLD_VERSION/$NEW_VERSION/g" {} \; +``` diff --git a/docs/openapi.json b/docs/openapi.json index 79c3b80f..5854bcdd 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -10,7 +10,7 @@ "name": "Apache 2.0", "url": "https://www.apache.org/licenses/LICENSE-2.0" }, - "version": "2.0.1" + "version": "2.3.1" }, "paths": { "/": { @@ -19,7 +19,6 @@ "Text Generation Inference" ], "summary": "Generate tokens if `stream == false` or a stream of token if `stream == true`", - "description": "Generate tokens if `stream == false` or a stream of token if `stream == true`", "operationId": "compat_generate", "requestBody": { "content": { @@ -108,7 +107,6 @@ "Text Generation Inference" ], "summary": "Generate tokens", - "description": "Generate tokens", "operationId": "generate", "requestBody": { "content": { @@ -192,7 +190,6 @@ "Text Generation Inference" ], "summary": "Generate a stream of token using Server-Sent Events", - "description": "Generate a stream of token using Server-Sent Events", "operationId": "generate_stream", "requestBody": { "content": { @@ -276,7 +273,6 @@ "Text Generation Inference" ], "summary": "Health check method", - "description": "Health check method", "operationId": "health", "responses": { "200": { @@ -305,7 +301,6 @@ "Text Generation Inference" ], "summary": "Text Generation Inference 
endpoint info", - "description": "Text Generation Inference endpoint info", "operationId": "get_model_info", "responses": { "200": { @@ -327,7 +322,6 @@ "Text Generation Inference" ], "summary": "Prometheus metrics scrape endpoint", - "description": "Prometheus metrics scrape endpoint", "operationId": "metrics", "responses": { "200": { @@ -349,7 +343,6 @@ "Text Generation Inference" ], "summary": "Tokenize inputs", - "description": "Tokenize inputs", "operationId": "tokenize", "requestBody": { "content": { @@ -394,7 +387,6 @@ "Text Generation Inference" ], "summary": "Generate tokens", - "description": "Generate tokens", "operationId": "chat_completions", "requestBody": { "content": { @@ -483,7 +475,6 @@ "Text Generation Inference" ], "summary": "Generate tokens", - "description": "Generate tokens", "operationId": "completions", "requestBody": { "content": { @@ -501,12 +492,12 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/Completion" + "$ref": "#/components/schemas/CompletionFinal" } }, "text/event-stream": { "schema": { - "$ref": "#/components/schemas/CompletionCompleteChunk" + "$ref": "#/components/schemas/Chunk" } } } @@ -565,6 +556,37 @@ } } } + }, + "/v1/models": { + "get": { + "tags": [ + "Text Generation Inference" + ], + "summary": "Get model info", + "operationId": "openai_get_model_info", + "responses": { + "200": { + "description": "Served model info", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ModelInfo" + } + } + } + }, + "404": { + "description": "Model not found", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + } + } + } + } + } + } } }, "components": { @@ -626,7 +648,6 @@ "type": "object", "required": [ "id", - "object", "created", "model", "system_fingerprint", @@ -653,9 +674,6 @@ "type": "string", "example": "mistralai/Mistral-7B-Instruct-v0.2" }, - "object": { - "type": "string" - }, "system_fingerprint": { "type": 
"string" }, @@ -697,7 +715,6 @@ "type": "object", "required": [ "id", - "object", "created", "model", "system_fingerprint", @@ -723,11 +740,16 @@ "type": "string", "example": "mistralai/Mistral-7B-Instruct-v0.2" }, - "object": { - "type": "string" - }, "system_fingerprint": { "type": "string" + }, + "usage": { + "allOf": [ + { + "$ref": "#/components/schemas/Usage" + } + ], + "nullable": true } } }, @@ -756,34 +778,19 @@ "nullable": true }, "message": { - "$ref": "#/components/schemas/Message" + "$ref": "#/components/schemas/OutputMessage" } } }, "ChatCompletionDelta": { - "type": "object", - "required": [ - "role" - ], - "properties": { - "content": { - "type": "string", - "example": "What is Deep Learning?", - "nullable": true + "oneOf": [ + { + "$ref": "#/components/schemas/TextMessage" }, - "role": { - "type": "string", - "example": "user" - }, - "tool_calls": { - "allOf": [ - { - "$ref": "#/components/schemas/DeltaToolCall" - } - ], - "nullable": true + { + "$ref": "#/components/schemas/ToolCallDelta" } - } + ] }, "ChatCompletionLogprob": { "type": "object", @@ -841,7 +848,6 @@ "ChatRequest": { "type": "object", "required": [ - "model", "messages" ], "properties": { @@ -852,6 +858,13 @@ "example": "1.0", "nullable": true }, + "guideline": { + "type": "string", + "description": "A guideline to be used in the chat_template", + "default": "null", + "example": "null", + "nullable": true + }, "logit_bias": { "type": "array", "items": { @@ -886,7 +899,8 @@ "model": { "type": "string", "description": "[UNUSED] ID of the model to use. 
See the model endpoint compatibility table for details on which models work with the Chat API.", - "example": "mistralai/Mistral-7B-Instruct-v0.2" + "example": "mistralai/Mistral-7B-Instruct-v0.2", + "nullable": true }, "n": { "type": "integer", @@ -903,6 +917,15 @@ "example": 0.1, "nullable": true }, + "response_format": { + "allOf": [ + { + "$ref": "#/components/schemas/GrammarType" + } + ], + "default": "null", + "nullable": true + }, "seed": { "type": "integer", "format": "int64", @@ -922,6 +945,14 @@ "stream": { "type": "boolean" }, + "stream_options": { + "allOf": [ + { + "$ref": "#/components/schemas/StreamOptions" + } + ], + "nullable": true + }, "temperature": { "type": "number", "format": "float", @@ -932,7 +963,7 @@ "tool_choice": { "allOf": [ { - "$ref": "#/components/schemas/ToolType" + "$ref": "#/components/schemas/ToolChoice" } ], "nullable": true @@ -940,7 +971,7 @@ "tool_prompt": { "type": "string", "description": "A prompt to be appended before the tools", - "example": "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\"", + "example": "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. 
Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.", "nullable": true }, "tools": { @@ -969,6 +1000,38 @@ } } }, + "Chunk": { + "type": "object", + "required": [ + "id", + "created", + "choices", + "model", + "system_fingerprint" + ], + "properties": { + "choices": { + "type": "array", + "items": { + "$ref": "#/components/schemas/CompletionComplete" + } + }, + "created": { + "type": "integer", + "format": "int64", + "minimum": 0 + }, + "id": { + "type": "string" + }, + "model": { + "type": "string" + }, + "system_fingerprint": { + "type": "string" + } + } + }, "CompatGenerateRequest": { "type": "object", "required": [ @@ -988,6 +1051,55 @@ } } }, + "Completion": { + "oneOf": [ + { + "allOf": [ + { + "$ref": "#/components/schemas/Chunk" + }, + { + "type": "object", + "required": [ + "object" + ], + "properties": { + "object": { + "type": "string", + "enum": [ + "text_completion" + ] + } + } + } + ] + }, + { + "allOf": [ + { + "$ref": "#/components/schemas/CompletionFinal" + }, + { + "type": "object", + "required": [ + "object" + ], + "properties": { + "object": { + "type": "string", + "enum": [ + "text_completion" + ] + } + } + } + ] + } + ], + "discriminator": { + "propertyName": "object" + } + }, "CompletionComplete": { "type": "object", "required": [ @@ -1017,15 +1129,15 @@ } } }, - "CompletionCompleteChunk": { + "CompletionFinal": { "type": "object", "required": [ "id", - "object", "created", - "choices", "model", - "system_fingerprint" + "system_fingerprint", + "choices", + "usage" ], "properties": { "choices": { @@ -1037,26 +1149,27 @@ "created": { "type": "integer", "format": "int64", + "example": "1706270835", "minimum": 0 }, "id": { "type": "string" }, "model": { - "type": "string" - }, - "object": { - "type": "string" + "type": "string", + "example": "mistralai/Mistral-7B-Instruct-v0.2" }, "system_fingerprint": { "type": "string" + }, + "usage": { + "$ref": "#/components/schemas/Usage" } 
} }, "CompletionRequest": { "type": "object", "required": [ - "model", "prompt" ], "properties": { @@ -1078,15 +1191,11 @@ "model": { "type": "string", "description": "UNUSED\nID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.", - "example": "mistralai/Mistral-7B-Instruct-v0.2" + "example": "mistralai/Mistral-7B-Instruct-v0.2", + "nullable": true }, "prompt": { - "type": "array", - "items": { - "type": "string" - }, - "description": "The prompt to generate completions for.", - "example": "What is Deep Learning?" + "$ref": "#/components/schemas/Prompt" }, "repetition_penalty": { "type": "number", @@ -1100,6 +1209,15 @@ "nullable": true, "minimum": 0 }, + "stop": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Up to 4 sequences where the API will stop generating further tokens.", + "example": "null", + "nullable": true + }, "stream": { "type": "boolean" }, @@ -1121,15 +1239,6 @@ "description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. 
So 0.1 means only the tokens comprising the top 10% probability mass are considered.", "example": 0.95, "nullable": true - }, - "stop": { - "type": "array", - "items": { - "type": "string" - }, - "description": "Up to 4 sequences where the API will stop generating further tokens.", - "example": "null", - "nullable": true } } }, @@ -1269,11 +1378,30 @@ } } }, + "FunctionName": { + "type": "object", + "required": [ + "name" + ], + "properties": { + "name": { + "type": "string" + } + } + }, "GenerateParameters": { "type": "object", "properties": { + "adapter_id": { + "type": "string", + "description": "Lora adapter id", + "default": "null", + "example": "null", + "nullable": true + }, "best_of": { "type": "integer", + "description": "Generate best_of sequences and return the one if the highest token logprobs.", "default": "null", "example": 1, "nullable": true, @@ -1282,20 +1410,24 @@ }, "decoder_input_details": { "type": "boolean", + "description": "Whether to return decoder input token logprobs and ids.", "default": "false" }, "details": { "type": "boolean", + "description": "Whether to return generation details.", "default": "true" }, "do_sample": { "type": "boolean", + "description": "Activate logits sampling.", "default": "false", "example": true }, "frequency_penalty": { "type": "number", "format": "float", + "description": "The parameter for frequency penalty. 1.0 means no penalty\nPenalize new tokens based on their existing frequency in the text so far,\ndecreasing the model's likelihood to repeat the same line verbatim.", "default": "null", "example": 0.1, "nullable": true, @@ -1313,6 +1445,7 @@ "max_new_tokens": { "type": "integer", "format": "int32", + "description": "Maximum number of tokens to generate.", "default": "100", "example": "20", "nullable": true, @@ -1321,6 +1454,7 @@ "repetition_penalty": { "type": "number", "format": "float", + "description": "The parameter for repetition penalty. 
1.0 means no penalty.\nSee [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.", "default": "null", "example": 1.03, "nullable": true, @@ -1328,6 +1462,7 @@ }, "return_full_text": { "type": "boolean", + "description": "Whether to prepend the prompt to the generated text", "default": "null", "example": false, "nullable": true @@ -1335,6 +1470,7 @@ "seed": { "type": "integer", "format": "int64", + "description": "Random sampling seed.", "default": "null", "example": "null", "nullable": true, @@ -1346,6 +1482,7 @@ "items": { "type": "string" }, + "description": "Stop generating tokens if a member of `stop` is generated.", "example": [ "photographer" ], @@ -1354,6 +1491,7 @@ "temperature": { "type": "number", "format": "float", + "description": "The value used to module the logits distribution.", "default": "null", "example": 0.5, "nullable": true, @@ -1362,6 +1500,7 @@ "top_k": { "type": "integer", "format": "int32", + "description": "The number of highest probability vocabulary tokens to keep for top-k-filtering.", "default": "null", "example": 10, "nullable": true, @@ -1370,6 +1509,7 @@ "top_n_tokens": { "type": "integer", "format": "int32", + "description": "The number of highest probability vocabulary tokens to keep for top-n-filtering.", "default": "null", "example": 5, "nullable": true, @@ -1379,6 +1519,7 @@ "top_p": { "type": "number", "format": "float", + "description": "Top-p value for nucleus sampling.", "default": "null", "example": 0.95, "nullable": true, @@ -1387,6 +1528,7 @@ }, "truncate": { "type": "integer", + "description": "Truncate inputs tokens to the given size.", "default": "null", "example": "null", "nullable": true, @@ -1395,6 +1537,7 @@ "typical_p": { "type": "number", "format": "float", + "description": "Typical Decoding mass\nSee [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information.", "default": "null", "example": 0.95, "nullable": true, @@ -1403,6 +1546,7 @@ }, 
"watermark": { "type": "boolean", + "description": "Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226).", "default": "false", "example": true } @@ -1490,18 +1634,14 @@ "type": "object", "required": [ "model_id", - "model_dtype", - "model_device_type", "max_concurrent_requests", "max_best_of", "max_stop_sequences", - "max_input_length", + "max_input_tokens", "max_total_tokens", - "waiting_served_ratio", - "max_batch_total_tokens", - "max_waiting_tokens", "validation_workers", "max_client_batch_size", + "router", "version" ], "properties": { @@ -1510,18 +1650,6 @@ "example": "null", "nullable": true }, - "max_batch_size": { - "type": "integer", - "example": "null", - "nullable": true, - "minimum": 0 - }, - "max_batch_total_tokens": { - "type": "integer", - "format": "int32", - "example": "32000", - "minimum": 0 - }, "max_best_of": { "type": "integer", "example": "2", @@ -1538,7 +1666,7 @@ "example": "128", "minimum": 0 }, - "max_input_length": { + "max_input_tokens": { "type": "integer", "example": "1024", "minimum": 0 @@ -1553,19 +1681,6 @@ "example": "2048", "minimum": 0 }, - "max_waiting_tokens": { - "type": "integer", - "example": "20", - "minimum": 0 - }, - "model_device_type": { - "type": "string", - "example": "cuda" - }, - "model_dtype": { - "type": "string", - "example": "torch.float16" - }, "model_id": { "type": "string", "description": "Model info", @@ -1581,6 +1696,11 @@ "example": "e985a63cdc139290c5f700ff1929f0b5942cced2", "nullable": true }, + "router": { + "type": "string", + "description": "Router Info", + "example": "text-generation-router" + }, "sha": { "type": "string", "example": "null", @@ -1593,26 +1713,19 @@ }, "version": { "type": "string", - "description": "Router Info", "example": "0.5.0" - }, - "waiting_served_ratio": { - "type": "number", - "format": "float", - "example": "1.2" } } }, "Message": { "type": "object", "required": [ - "role" + "role", + "content" ], "properties": { "content": { - 
"type": "string", - "example": "My name is David and I", - "nullable": true + "$ref": "#/components/schemas/MessageContent" }, "name": { "type": "string", @@ -1622,16 +1735,104 @@ "role": { "type": "string", "example": "user" - }, - "tool_calls": { - "type": "array", - "items": { - "$ref": "#/components/schemas/ToolCall" - }, - "nullable": true } } }, + "MessageChunk": { + "oneOf": [ + { + "type": "object", + "required": [ + "text", + "type" + ], + "properties": { + "text": { + "type": "string" + }, + "type": { + "type": "string", + "enum": [ + "text" + ] + } + } + }, + { + "type": "object", + "required": [ + "image_url", + "type" + ], + "properties": { + "image_url": { + "$ref": "#/components/schemas/Url" + }, + "type": { + "type": "string", + "enum": [ + "image_url" + ] + } + } + } + ], + "discriminator": { + "propertyName": "type" + } + }, + "MessageContent": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "array", + "items": { + "$ref": "#/components/schemas/MessageChunk" + } + } + ] + }, + "ModelInfo": { + "type": "object", + "required": [ + "id", + "object", + "created", + "owned_by" + ], + "properties": { + "created": { + "type": "integer", + "format": "int64", + "example": 1686935002, + "minimum": 0 + }, + "id": { + "type": "string", + "example": "gpt2" + }, + "object": { + "type": "string", + "example": "model" + }, + "owned_by": { + "type": "string", + "example": "openai" + } + } + }, + "OutputMessage": { + "oneOf": [ + { + "$ref": "#/components/schemas/TextMessage" + }, + { + "$ref": "#/components/schemas/ToolCallMessage" + } + ] + }, "PrefillToken": { "type": "object", "required": [ @@ -1658,6 +1859,12 @@ } } }, + "Prompt": { + "type": "array", + "items": { + "type": "string" + } + }, "SimpleToken": { "type": "object", "required": [ @@ -1693,7 +1900,8 @@ "type": "object", "required": [ "finish_reason", - "generated_tokens" + "generated_tokens", + "input_length" ], "properties": { "finish_reason": { @@ -1705,6 +1913,12 @@ "example": 1, 
"minimum": 0 }, + "input_length": { + "type": "integer", + "format": "int32", + "example": 1, + "minimum": 0 + }, "seed": { "type": "integer", "format": "int64", @@ -1714,6 +1928,19 @@ } } }, + "StreamOptions": { + "type": "object", + "required": [ + "include_usage" + ], + "properties": { + "include_usage": { + "type": "boolean", + "description": "If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value.", + "example": "true" + } + } + }, "StreamResponse": { "type": "object", "required": [ @@ -1752,6 +1979,23 @@ } } }, + "TextMessage": { + "type": "object", + "required": [ + "role", + "content" + ], + "properties": { + "content": { + "type": "string", + "example": "My name is David and I" + }, + "role": { + "type": "string", + "example": "user" + } + } + }, "Token": { "type": "object", "required": [ @@ -1817,36 +2061,95 @@ "$ref": "#/components/schemas/FunctionDefinition" }, "id": { - "type": "integer", - "format": "int32", - "minimum": 0 + "type": "string" }, "type": { "type": "string" } } }, + "ToolCallDelta": { + "type": "object", + "required": [ + "role", + "tool_calls" + ], + "properties": { + "role": { + "type": "string", + "example": "assistant" + }, + "tool_calls": { + "$ref": "#/components/schemas/DeltaToolCall" + } + } + }, + "ToolCallMessage": { + "type": "object", + "required": [ + "role", + "tool_calls" + ], + "properties": { + "role": { + "type": "string", + "example": "assistant" + }, + "tool_calls": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ToolCall" + } + } + } + }, + "ToolChoice": { + "allOf": [ + { + "$ref": "#/components/schemas/ToolType" + } + ], + "nullable": true + }, "ToolType": { "oneOf": [ + { + "type": "object", + "default": null, + "nullable": true + }, + { + "type": "string" + }, { 
"type": "object", "required": [ - "FunctionName" + "function" ], "properties": { - "FunctionName": { - "type": "string" + "function": { + "$ref": "#/components/schemas/FunctionName" } } }, { - "type": "string", - "enum": [ - "OneOf" - ] + "type": "object", + "default": null, + "nullable": true } ] }, + "Url": { + "type": "object", + "required": [ + "url" + ], + "properties": { + "url": { + "type": "string" + } + } + }, "Usage": { "type": "object", "required": [ diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index a7351a33..b883b36d 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -11,12 +11,16 @@ title: Using TGI with Intel Gaudi - local: installation_inferentia title: Using TGI with AWS Inferentia + - local: installation_intel + title: Using TGI with Intel GPUs - local: installation title: Installation from source - local: supported_models title: Supported Models and Hardware - - local: messages_api - title: Messages API + - local: architecture + title: Internal Architecture + - local: usage_statistics + title: Usage Statistics title: Getting started - sections: - local: basic_tutorials/consuming_tgi @@ -27,8 +31,6 @@ title: Serving Private & Gated Models - local: basic_tutorials/using_cli title: Using TGI CLI - - local: basic_tutorials/launcher - title: All TGI CLI options - local: basic_tutorials/non_core_models title: Non-core Model Serving - local: basic_tutorials/safety @@ -42,6 +44,14 @@ - local: basic_tutorials/train_medusa title: Train Medusa title: Tutorials +- sections: + - local: reference/launcher + title: All TGI CLI options + - local: reference/metrics + title: Exported Metrics + - local: reference/api_reference + title: API Reference + title: Reference - sections: - local: conceptual/streaming title: Streaming @@ -59,5 +69,10 @@ title: Speculation (Medusa, ngram) - local: conceptual/guidance title: How Guidance Works (via outlines) + - local: conceptual/lora + title: LoRA (Low-Rank Adaptation) + - local: 
conceptual/external + title: External Resources + title: Conceptual Guides diff --git a/docs/source/architecture.md b/docs/source/architecture.md new file mode 100644 index 00000000..6660630d --- /dev/null +++ b/docs/source/architecture.md @@ -0,0 +1,232 @@ +# Text Generation Inference Architecture + +This document aims to describe the architecture of Text Generation Inference (TGI), by describing the call flow between the separate components. + +A high-level architecture diagram can be seen here: + +![TGI architecture](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/TGI.png) + +As this diagram shows, there are three separate components: + +- **The router**, also named `webserver`, that receives the client requests, buffers them, creates some batches, and prepares gRPC calls to a model server. +- **The model server**, responsible for receiving the gRPC requests and running inference on the model. If the model is sharded across multiple accelerators (e.g.: multiple GPUs), the model server shards might be synchronized via NCCL or equivalent. +- **The launcher** is a helper that will be able to launch one or several model servers (if the model is sharded), and it launches the router with the compatible arguments. + +The router and the model server can run on two different machines; they do not need to be deployed together. + +## The Router + +This component is a Rust web server binary that accepts HTTP requests using the custom [HTTP API](https://huggingface.github.io/text-generation-inference/), as well as OpenAI's [Messages API](https://huggingface.co/docs/text-generation-inference/messages_api). +The router receives the API calls and handles the "batches" logic (an introduction to batching can be found [here](https://github.com/huggingface/text-generation-inference/blob/main/router/README.md)). +It uses different strategies to reduce latency between requests and responses, especially oriented to decoding latency.
It will use queues, schedulers, and block allocators to achieve that and produce batched requests that will then be sent to the model server. + +### Router's command line + +The router command line will be the way to pass parameters to it (it does not rely on a configuration file): + +``` +Text Generation Webserver + +Usage: text-generation-router [OPTIONS] + +Options: + --max-concurrent-requests + [env: MAX_CONCURRENT_REQUESTS=] [default: 128] + --max-best-of + [env: MAX_BEST_OF=] [default: 2] + --max-stop-sequences + [env: MAX_STOP_SEQUENCES=] [default: 4] + --max-top-n-tokens + [env: MAX_TOP_N_TOKENS=] [default: 5] + --max-input-tokens + [env: MAX_INPUT_TOKENS=] [default: 1024] + --max-total-tokens + [env: MAX_TOTAL_TOKENS=] [default: 2048] + --waiting-served-ratio + [env: WAITING_SERVED_RATIO=] [default: 1.2] + --max-batch-prefill-tokens + [env: MAX_BATCH_PREFILL_TOKENS=] [default: 4096] + --max-batch-total-tokens + [env: MAX_BATCH_TOTAL_TOKENS=] + --max-waiting-tokens + [env: MAX_WAITING_TOKENS=] [default: 20] + --max-batch-size + [env: MAX_BATCH_SIZE=] + --hostname + [env: HOSTNAME=] [default: 0.0.0.0] + -p, --port + [env: PORT=] [default: 3000] + --master-shard-uds-path + [env: MASTER_SHARD_UDS_PATH=] [default: /tmp/text-generation-server-0] + --tokenizer-name + [env: TOKENIZER_NAME=] [default: bigscience/bloom] + --tokenizer-config-path + [env: TOKENIZER_CONFIG_PATH=] + --revision + [env: REVISION=] + --validation-workers + [env: VALIDATION_WORKERS=] [default: 2] + --json-output + [env: JSON_OUTPUT=] + --otlp-endpoint + [env: OTLP_ENDPOINT=] + --otlp-service-name + [env: OTLP_SERVICE_NAME=] + --cors-allow-origin + [env: CORS_ALLOW_ORIGIN=] + --ngrok + [env: NGROK=] + --ngrok-authtoken + [env: NGROK_AUTHTOKEN=] + --ngrok-edge + [env: NGROK_EDGE=] + --messages-api-enabled + [env: MESSAGES_API_ENABLED=] + --disable-grammar-support + [env: DISABLE_GRAMMAR_SUPPORT=] + --max-client-batch-size + [env: MAX_CLIENT_BATCH_SIZE=] [default: 4] + -h, --help + Print help 
+ -V, --version + Print version +``` + +## The Model Server + +The model server is a Python server that starts a gRPC server waiting for requests, loads a given model, performs sharding to provide [tensor parallelism](https://huggingface.co/docs/text-generation-inference/conceptual/tensor_parallelism), and stays alive while waiting for new requests. +The model server supports models instantiated using PyTorch and optimized for inference mainly on CUDA/ROCm. + +### Model Server Variants + +Several variants of the model server exist that are actively supported by Hugging Face: + +- By default, the model server will attempt building [a server optimized for Nvidia GPUs with CUDA](https://huggingface.co/docs/text-generation-inference/installation_nvidia). The code for this version is hosted in the [main TGI repository](https://github.com/huggingface/text-generation-inference). +- A [version optimized for AMD with ROCm](https://huggingface.co/docs/text-generation-inference/installation_amd) is hosted in the main TGI repository. Some model features differ. +- A [version optimized for Intel GPUs](https://huggingface.co/docs/text-generation-inference/installation_intel) is hosted in the main TGI repository. Some model features differ. +- The [version for Intel Gaudi](https://huggingface.co/docs/text-generation-inference/installation_gaudi) is maintained on a [forked repository](https://github.com/huggingface/tgi-gaudi), often resynchronized with the main TGI repository. +- A [version for Neuron (AWS Inferentia2)](https://huggingface.co/docs/text-generation-inference/installation_inferentia) is maintained as part of [Optimum Neuron](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference). +- A version for Google TPUs is maintained as part of [Optimum TPU](https://github.com/huggingface/optimum-tpu/tree/main/text-generation-inference). 
+ +Not all variants provide the same features, as hardware and middleware capabilities do not provide the same optimizations. + +### Command Line Interface + +The official command line interface (CLI) for the server supports three subcommands, `download-weights`, `quantize` and `serve`: + +- `download-weights` will download weights from the hub and, in some variants, it will convert weights to a format that is adapted to the given implementation; +- `quantize` quantizes a model using the `gptq` package. This feature is neither available nor supported on all variants; +- `serve` will start the server that loads a model (or a model shard), receives gRPC calls from the router, performs an inference and provides a formatted response to the given request. + +Serve's command line parameters on the TGI repository are these: + +``` + Usage: cli.py serve [OPTIONS] MODEL_ID + +╭─ Arguments ──────────────────────────────────────────────────────────────────────────────────────────────╮ +│ * model_id TEXT [default: None] [required] │ +╰──────────────────────────────────────────────────────────────────────────────────────────────────────────╯ +╭─ Options ────────────────────────────────────────────────────────────────────────────────────────────────╮ +│ --revision TEXT [default: None] │ +│ --sharded --no-sharded [default: no-sharded] │ +│ --quantize [bitsandbytes|bitsandbytes [default: None] │ +│ -nf4|bitsandbytes-fp4|gptq │ +│ |awq|eetq|exl2|fp8] │ +│ --speculate INTEGER [default: None] │ +│ --dtype [float16|bfloat16] [default: None] │ +│ --trust-remote-code --no-trust-remote-code [default: │ +│ no-trust-remote-code] │ +│ --uds-path PATH [default: │ +│ /tmp/text-generation-serve… │ +│ --logger-level TEXT [default: INFO] │ +│ --json-output --no-json-output [default: no-json-output] │ +│ --otlp-endpoint TEXT [default: None] │ +│ --otlp-service-name TEXT [default: │ +│ text-generation-inference...│ +│ --help Show this message and exit. 
│ +╰──────────────────────────────────────────────────────────────────────────────────────────────────────────╯ +``` + +Note that some variants might support different parameters, and they could possibly accept more options that can be passed on using environment variables. + +## Call Flow + +Once both components are initialized, weights downloaded and model server is up and running, router and model server exchange data and info through the gRPC call. There are currently two supported schemas, [v2](https://github.com/huggingface/text-generation-inference/blob/main/proto/generate.proto) and [v3](https://github.com/huggingface/text-generation-inference/blob/main/proto/v3/generate.proto). These two versions are almost identical, except for: + +- input chunks support, for text and image data, +- paged attention support + +Here's a diagram that displays the exchanges that follow the router and model server startup. + +```mermaid +sequenceDiagram + + Router->>Model Server: service discovery + Model Server-->>Router: urls for other shards + + Router->>Model Server: get model info + Model Server-->>Router: shard info + + Router->>Model Server: health check + Model Server-->>Router: health OK + + Router->>Model Server: warmup(max_input_tokens, max_batch_prefill_tokens, max_total_tokens, max_batch_size) + Model Server-->>Router: warmup result +``` + +After these are done, the router is ready to receive generate calls from multiple clients. Here's an example. 
+ +```mermaid +sequenceDiagram + participant Client 1 + participant Client 2 + participant Client 3 + participant Router + participant Model Server + + Client 1->>Router: generate_stream + Router->>Model Server: prefill(batch1) + Model Server-->>Router: generations, cached_batch1, timings + Router-->>Client 1: token 1 + + Router->>Model Server: decode(cached_batch1) + Model Server-->>Router: generations, cached_batch1, timings + Router-->>Client 1: token 2 + + Router->>Model Server: decode(cached_batch1) + Model Server-->>Router: generations, cached_batch1, timings + Router-->>Client 1: token 3 + + Client 2->>Router: generate_stream + Router->>Model Server: prefill(batch2) + Note right of Model Server: This stops previous batch, that is restarted + Model Server-->>Router: generations, cached_batch2, timings + Router-->>Client 2: token 1' + + Router->>Model Server: decode(cached_batch1, cached_batch2) + Model Server-->>Router: generations, cached_batch1, timings + Router-->>Client 1: token 4 + Router-->>Client 2: token 2' + + Note left of Client 1: Client 1 leaves + Router->>Model Server: filter_batch(cached_batch1, request_ids_to_keep=batch2) + Model Server-->>Router: filtered batch + + Router->>Model Server: decode(cached_batch2) + Model Server-->>Router: generations, cached_batch2, timings + Router-->>Client 2: token 3' + + Client 3->>Router: generate_stream + Note right of Model Server: This stops previous batch, that is restarted + Router->>Model Server: prefill(batch3) + Note left of Client 1: Client 3 leaves without receiving any batch + Router->>Model Server: clear_cache(batch3) + Note right of Model Server: This stops previous batch, that is restarted + + Router->>Model Server: decode(cached_batch3) + Note right of Model Server: Last token (stopping criteria) + Model Server-->>Router: generations, cached_batch3, timings + Router-->>Client 2: token 4' + + +``` diff --git a/docs/source/basic_tutorials/consuming_tgi.md 
b/docs/source/basic_tutorials/consuming_tgi.md index 4829ec7c..b07e7219 100644 --- a/docs/source/basic_tutorials/consuming_tgi.md +++ b/docs/source/basic_tutorials/consuming_tgi.md @@ -1,81 +1,125 @@ # Consuming Text Generation Inference -There are many ways you can consume Text Generation Inference server in your applications. After launching, you can use the `/generate` route and make a `POST` request to get results from the server. You can also use the `/generate_stream` route if you want TGI to return a stream of tokens. You can make the requests using the tool of your preference, such as curl, Python or TypeScrpt. For a final end-to-end experience, we also open-sourced ChatUI, a chat interface for open-source models. +There are many ways to consume Text Generation Inference (TGI) server in your applications. After launching the server, you can use the [Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) `/v1/chat/completions` route and make a `POST` request to get results from the server. You can also pass `"stream": true` to the call if you want TGI to return a stream of tokens. + +For more information on the API, consult the OpenAPI documentation of `text-generation-inference` available [here](https://huggingface.github.io/text-generation-inference). + +You can make the requests using any tool of your preference, such as curl, Python, or TypeScript. For an end-to-end experience, we've open-sourced [ChatUI](https://github.com/huggingface/chat-ui), a chat interface for open-access models. 
## curl -After the launch, you can query the model using either the `/generate` or `/generate_stream` routes: +After a successful server launch, you can query the model using the `v1/chat/completions` route, to get responses that are compliant to the OpenAI Chat Completion spec: + +```bash +curl localhost:8080/v1/chat/completions \ + -X POST \ + -d '{ + "model": "tgi", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "What is deep learning?" + } + ], + "stream": true, + "max_tokens": 20 +}' \ + -H 'Content-Type: application/json' +``` + +For non-chat use-cases, you can also use the `/generate` and `/generate_stream` routes. ```bash curl 127.0.0.1:8080/generate \ -X POST \ - -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \ + -d '{ + "inputs":"What is Deep Learning?", + "parameters":{ + "max_new_tokens":20 + } +}' \ -H 'Content-Type: application/json' ``` +## Python -## Inference Client +### Inference Client -[`huggingface-hub`](https://huggingface.co/docs/huggingface_hub/main/en/index) is a Python library to interact with the Hugging Face Hub, including its endpoints. It provides a nice high-level class, [`~huggingface_hub.InferenceClient`], which makes it easy to make calls to a TGI endpoint. `InferenceClient` also takes care of parameter validation and provides a simple to-use interface. -You can simply install `huggingface-hub` package with pip. +[`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/main/en/index) is a Python library to interact with the Hugging Face Hub, including its endpoints. It provides a high-level class, [`huggingface_hub.InferenceClient`](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient), which makes it easy to make calls to TGI's Messages API. `InferenceClient` also takes care of parameter validation and provides a simple-to-use interface. 
+ +Install `huggingface_hub` package via pip. ```bash -pip install huggingface-hub +pip install huggingface_hub ``` -Once you start the TGI server, instantiate `InferenceClient()` with the URL to the endpoint serving the model. You can then call `text_generation()` to hit the endpoint through Python. +You can now use `InferenceClient` the exact same way you would use `OpenAI` client in Python ```python from huggingface_hub import InferenceClient -client = InferenceClient(model="http://127.0.0.1:8080") -client.text_generation(prompt="Write a code for snake game") +client = InferenceClient( + base_url="http://localhost:8080/v1/", +) + +output = client.chat.completions.create( + model="tgi", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Count to 10"}, + ], + stream=True, + max_tokens=1024, +) + +for chunk in output: + print(chunk.choices[0].delta.content) ``` -You can do streaming with `InferenceClient` by passing `stream=True`. Streaming will return tokens as they are being generated in the server. To use streaming, you can do as follows: +You can check out more details about OpenAI compatibility [here](https://huggingface.co/docs/huggingface_hub/en/guides/inference#openai-compatibility). + +There is also an async version of the client, `AsyncInferenceClient`, based on `asyncio` and `aiohttp`. You can find docs for it [here](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.AsyncInferenceClient) + +### OpenAI Client + +You can directly use the OpenAI [Python](https://github.com/openai/openai-python) or [JS](https://github.com/openai/openai-node) clients to interact with TGI. + +Install the OpenAI Python package via pip. 
+ +```bash +pip install openai +``` ```python -for token in client.text_generation("How do you make cheese?", max_new_tokens=12, stream=True): - print(token) +from openai import OpenAI + +# init the client but point it to TGI +client = OpenAI( + base_url="http://localhost:8080/v1/", + api_key="-" +) + +chat_completion = client.chat.completions.create( + model="tgi", + messages=[ + {"role": "system", "content": "You are a helpful assistant." }, + {"role": "user", "content": "What is deep learning?"} + ], + stream=True +) + +# iterate and print stream +for message in chat_completion: + print(message) ``` -Another parameter you can use with TGI backend is `details`. You can get more details on generation (tokens, probabilities, etc.) by setting `details` to `True`. When it's specified, TGI will return a `TextGenerationResponse` or `TextGenerationStreamResponse` rather than a string or stream. +## UI -```python -output = client.text_generation(prompt="Meaning of life is", details=True) -print(output) - -# TextGenerationResponse(generated_text=' a complex concept that is not always clear to the individual. It is a concept that is not always', details=Details(finish_reason=, generated_tokens=20, seed=None, prefill=[], tokens=[Token(id=267, text=' a', logprob=-2.0723474, special=False), Token(id=11235, text=' complex', logprob=-3.1272552, special=False), Token(id=17908, text=' concept', logprob=-1.3632495, special=False),..)) -``` - -You can see how to stream below. - -```python -output = client.text_generation(prompt="Meaning of life is", stream=True, details=True) -print(next(iter(output))) - -# TextGenerationStreamResponse(token=Token(id=267, text=' a', logprob=-2.0723474, special=False), generated_text=None, details=None) -``` - -You can check out the details of the function [here](https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation). 
There is also an async version of the client, `AsyncInferenceClient`, based on `asyncio` and `aiohttp`. You can find docs for it [here](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.AsyncInferenceClient) - - -## ChatUI - -ChatUI is an open-source interface built for LLM serving. It offers many customization options, such as web search with SERP API and more. ChatUI can automatically consume the TGI server and even provides an option to switch between different TGI endpoints. You can try it out at [Hugging Chat](https://huggingface.co/chat/), or use the [ChatUI Docker Space](https://huggingface.co/new-space?template=huggingchat/chat-ui-template) to deploy your own Hugging Chat to Spaces. - -To serve both ChatUI and TGI in same environment, simply add your own endpoints to the `MODELS` variable in `.env.local` file inside the `chat-ui` repository. Provide the endpoints pointing to where TGI is served. - -``` -{ -// rest of the model config here -"endpoints": [{"url": "https://HOST:PORT/generate_stream"}] -} -``` - -![ChatUI](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/chatui_screen.png) - -## Gradio +### Gradio Gradio is a Python library that helps you build web applications for your machine learning models with a few lines of code. It has a `ChatInterface` wrapper that helps create neat UIs for chatbots. Let's take a look at how to create a chatbot with streaming mode using TGI and Gradio. Let's install Gradio and Hub Python library first. 
@@ -89,19 +133,28 @@ Assume you are serving your model on port 8080, we will query through [Inference import gradio as gr from huggingface_hub import InferenceClient -client = InferenceClient(model="http://127.0.0.1:8080") +client = InferenceClient(base_url="http://127.0.0.1:8080") def inference(message, history): partial_message = "" - for token in client.text_generation(message, max_new_tokens=20, stream=True): - partial_message += token + output = client.chat.completions.create( + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": message}, + ], + stream=True, + max_tokens=1024, + ) + + for chunk in output: + partial_message += chunk.choices[0].delta.content yield partial_message gr.ChatInterface( inference, chatbot=gr.Chatbot(height=300), textbox=gr.Textbox(placeholder="Chat with me!", container=False, scale=7), - description="This is the demo for Gradio UI consuming TGI endpoint with LLaMA 7B-Chat model.", + description="This is the demo for Gradio UI consuming TGI endpoint.", title="Gradio 🤝 TGI", examples=["Are tomatoes vegetables?"], retry_btn="Retry", @@ -110,20 +163,7 @@ gr.ChatInterface( ).queue().launch() ``` -The UI looks like this 👇 - -
- - -
- -You can try the demo directly here 👇 +You can check out the UI and try the demo directly here 👇