mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-25 20:12:07 +00:00

Merge pull request #225 from yuanwu2017/2.3.0
commit 5291f652a1

0  .devcontainer/Dockerfile.trtllm  Normal file
0  .devcontainer/devcontainer.json  Normal file
@@ -2,3 +2,6 @@ aml
target
server/transformers
server/flash-attention
cmake-build-debug/
cmake-build-release/
Dockerfile*

9  .gitignore  vendored
@@ -3,9 +3,12 @@ target
router/tokenizer.json
*__pycache__*

backends/v2/src/client/pb
backends/v3/src/client/pb

# ROCm auto-generated files
*.hip
server/exllamav2_kernels/exllamav2_kernels/hip/
server/exllamav2
server/exllama_kernels/exllama_kernels/hip/
server/exllama_kernels/exllama_kernels/hip_func/
*_hip.cuh
@@ -14,3 +17,7 @@ server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp

data/
load_tests/*.json
server/fbgemmm

.direnv/
.venv/
@@ -5,7 +5,7 @@ repos:
  - id: check-yaml
  - id: end-of-file-fixer
  - id: trailing-whitespace
    exclude: docs/source/basic_tutorials/launcher.md
    exclude: docs/source/reference/launcher.md
  - repo: https://github.com/psf/black
    rev: 24.2.0
    hooks:
@@ -13,6 +13,11 @@ repos:
  - repo: https://github.com/doublify/pre-commit-rust
    rev: v1.0
    hooks:
      - id: fmt
      - id: cargo-check
      - id: fmt
      - id: clippy
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.3.0
    hooks:
      - id: ruff
        args: [--fix, --exit-non-zero-on-fix]
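The updated hook set runs black, rustfmt, clippy and ruff. A minimal sketch of exercising this configuration locally with the stock `pre-commit` CLI (the installation method is an assumption):

```bash
# Assumed installation method; any Python environment works.
pip install pre-commit
pre-commit install            # register the git hook
pre-commit run --all-files    # run black, rustfmt, clippy and ruff once
```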
82  .redocly.lint-ignore.yaml  Normal file
@@ -0,0 +1,82 @@
# This file instructs Redocly's linter to ignore the rules contained for specific parts of your API.
# See https://redoc.ly/docs/cli/ for more information.
docs/openapi.json:
  no-empty-servers:
    - '#/openapi'
  spec:
    - >-
      #/components/schemas/GenerateParameters/properties/best_of/exclusiveMinimum
    - >-
      #/components/schemas/GenerateParameters/properties/frequency_penalty/exclusiveMinimum
    - '#/components/schemas/GenerateParameters/properties/grammar/nullable'
    - >-
      #/components/schemas/GenerateParameters/properties/repetition_penalty/exclusiveMinimum
    - '#/components/schemas/GenerateParameters/properties/seed/exclusiveMinimum'
    - >-
      #/components/schemas/GenerateParameters/properties/temperature/exclusiveMinimum
    - '#/components/schemas/GenerateParameters/properties/top_k/exclusiveMinimum'
    - >-
      #/components/schemas/GenerateParameters/properties/top_n_tokens/exclusiveMinimum
    - '#/components/schemas/GenerateParameters/properties/top_p/exclusiveMinimum'
    - >-
      #/components/schemas/GenerateParameters/properties/typical_p/exclusiveMinimum
    - '#/components/schemas/GenerateResponse/properties/details/nullable'
    - '#/components/schemas/StreamResponse/properties/details/nullable'
    - '#/components/schemas/ChatRequest/properties/response_format/nullable'
    - '#/components/schemas/ChatRequest/properties/stream_options/nullable'
    - '#/components/schemas/ChatRequest/properties/tool_choice/nullable'
    - '#/components/schemas/ToolChoice/nullable'
    - '#/components/schemas/ChatCompletionComplete/properties/logprobs/nullable'
    - '#/components/schemas/ChatCompletionChunk/properties/usage/nullable'
    - '#/components/schemas/ChatCompletionChoice/properties/logprobs/nullable'
  no-invalid-media-type-examples:
    - '#/paths/~1/post/responses/422/content/application~1json/example'
    - '#/paths/~1/post/responses/424/content/application~1json/example'
    - '#/paths/~1/post/responses/429/content/application~1json/example'
    - '#/paths/~1/post/responses/500/content/application~1json/example'
    - '#/paths/~1generate/post/responses/422/content/application~1json/example'
    - '#/paths/~1generate/post/responses/424/content/application~1json/example'
    - '#/paths/~1generate/post/responses/429/content/application~1json/example'
    - '#/paths/~1generate/post/responses/500/content/application~1json/example'
    - >-
      #/paths/~1generate_stream/post/responses/422/content/text~1event-stream/example
    - >-
      #/paths/~1generate_stream/post/responses/424/content/text~1event-stream/example
    - >-
      #/paths/~1generate_stream/post/responses/429/content/text~1event-stream/example
    - >-
      #/paths/~1generate_stream/post/responses/500/content/text~1event-stream/example
    - '#/paths/~1tokenize/post/responses/404/content/application~1json/example'
    - >-
      #/paths/~1v1~1chat~1completions/post/responses/422/content/application~1json/example
    - >-
      #/paths/~1v1~1chat~1completions/post/responses/424/content/application~1json/example
    - >-
      #/paths/~1v1~1chat~1completions/post/responses/429/content/application~1json/example
    - >-
      #/paths/~1v1~1chat~1completions/post/responses/500/content/application~1json/example
    - >-
      #/paths/~1v1~1completions/post/responses/422/content/application~1json/example
    - >-
      #/paths/~1v1~1completions/post/responses/424/content/application~1json/example
    - >-
      #/paths/~1v1~1completions/post/responses/429/content/application~1json/example
    - >-
      #/paths/~1v1~1completions/post/responses/500/content/application~1json/example
  operation-4xx-response:
    - '#/paths/~1health/get/responses'
    - '#/paths/~1info/get/responses'
    - '#/paths/~1metrics/get/responses'
  no-unused-components:
    - '#/components/schemas/Completion'
  security-defined:
    - '#/paths/~1/post'
    - '#/paths/~1generate/post'
    - '#/paths/~1generate_stream/post'
    - '#/paths/~1health/get'
    - '#/paths/~1info/get'
    - '#/paths/~1metrics/get'
    - '#/paths/~1tokenize/post'
    - '#/paths/~1v1~1chat~1completions/post'
    - '#/paths/~1v1~1completions/post'
    - '#/paths/~1v1~1models/get'
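The new ignore file suppresses known findings when Redocly lints `docs/openapi.json`. A hedged sketch of running the linter with the Redocly CLI; the exact invocation and the regenerate flag are assumptions, not taken from this repository:

```bash
# Assumed invocation via the Redocly CLI (requires Node.js).
npx @redocly/cli lint docs/openapi.json
# Regenerating the ignore file after intentional spec changes (assumption):
# npx @redocly/cli lint docs/openapi.json --generate-ignore-file
```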
133  CODE_OF_CONDUCT.md  Normal file
@@ -0,0 +1,133 @@

# Contributor Covenant Code of Conduct

## Our Pledge

We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, caste, color, religion, or sexual
identity and orientation.

We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.

## Our Standards

Examples of behavior that contributes to a positive environment for our
community include:

* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
  and learning from the experience
* Focusing on what is best not just for us as individuals, but for the overall
  community

Examples of unacceptable behavior include:

* The use of sexualized language or imagery, and sexual attention or advances of
  any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email address,
  without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Enforcement Responsibilities

Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.

Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.

## Scope

This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
feedback@huggingface.co.
All complaints will be reviewed and investigated promptly and fairly.

All community leaders are obligated to respect the privacy and security of the
reporter of any incident.

## Enforcement Guidelines

Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:

### 1. Correction

**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.

**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.

### 2. Warning

**Community Impact**: A violation through a single incident or series of
actions.

**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or permanent
ban.

### 3. Temporary Ban

**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.

**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.

### 4. Permanent Ban

**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.

**Consequence**: A permanent ban from any sort of public interaction within the
community.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.1, available at
[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].

Community Impact Guidelines were inspired by
[Mozilla's code of conduct enforcement ladder][Mozilla CoC].

For answers to common questions about this code of conduct, see the FAQ at
[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
[https://www.contributor-covenant.org/translations][translations].

[homepage]: https://www.contributor-covenant.org
[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
[Mozilla CoC]: https://github.com/mozilla/diversity
[FAQ]: https://www.contributor-covenant.org/faq
[translations]: https://www.contributor-covenant.org/translations
120  CONTRIBUTING.md  Normal file
@@ -0,0 +1,120 @@
<!---
Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# Contribute to text-generation-inference

Everyone is welcome to contribute, and we value everybody's contribution. Code
contributions are not the only way to help the community. Answering questions, helping
others, and improving the documentation are also immensely valuable.

It also helps us if you spread the word! Reference the library in blog posts
about the awesome projects it made possible, shout out on Twitter every time it has
helped you, or simply ⭐️ the repository to say thank you.

However you choose to contribute, please be mindful and respect our
[code of conduct](https://github.com/huggingface/text-generation-inference/blob/main/CODE_OF_CONDUCT.md).

**This guide was heavily inspired by the awesome [scikit-learn guide to contributing](https://github.com/scikit-learn/scikit-learn/blob/main/CONTRIBUTING.md).**

## Ways to contribute

There are several ways you can contribute to text-generation-inference.

* Fix outstanding issues with the existing code.
* Submit issues related to bugs or desired new features.
* Contribute to the examples or to the documentation.

> All contributions are equally valuable to the community. 🥰

## Fixing outstanding issues

If you notice an issue with the existing code and have a fix in mind, feel free to [start contributing](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request) and open
a Pull Request!

## Submitting a bug-related issue or feature request

Do your best to follow these guidelines when submitting a bug-related issue or a feature
request. It will make it easier for us to come back to you quickly and with good
feedback.

### Did you find a bug?

The text-generation-inference library is robust and reliable thanks to users who report the problems they encounter.

Before you report an issue, we would really appreciate it if you could **make sure the bug was not
already reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the
library itself, and not your code.

Once you've confirmed the bug hasn't already been reported, please include the following information in your issue so
we can quickly resolve it:

* Your **OS type and version**, as well as your environment versions (versions of rust, python, and dependencies).
* A short, self-contained, code snippet that allows us to reproduce the bug.
* The *full* traceback if an exception is raised.
* Attach any other additional information, like screenshots, you think may help.

To get the OS and software versions automatically, you can re-run the launcher with the `--env` flag:

```bash
text-generation-launcher --env
```

This will precede the launch of the model with the information relative to your environment. We recommend pasting
that in your issue report.

### Do you want a new feature?

If there is a new feature you'd like to see in text-generation-inference, please open an issue and describe:

1. What is the *motivation* behind this feature? Is it related to a problem or frustration with the library? Is it
   a feature related to something you need for a project? Is it something you worked on and think it could benefit
   the community?

   Whatever it is, we'd love to hear about it!

2. Describe your requested feature in as much detail as possible. The more you can tell us about it, the better
   we'll be able to help you.
3. Provide a *code snippet* that demonstrates the feature's usage.
4. If the feature is related to a paper, please include a link.

If your issue is well written we're already 80% of the way there by the time you create it.

We have added [templates](https://github.com/huggingface/text-generation-inference/tree/main/.github/ISSUE_TEMPLATE)
to help you get started with your issue.

## Do you want to implement a new model?

New models are constantly released and if you want to implement a new model, please provide the following information:

* A short description of the model and a link to the paper.
* Link to the implementation if it is open-sourced.
* Link to the model weights if they are available.

If you are willing to contribute the model yourself, let us know so we can help you add it to text-generation-inference!

## Do you want to add documentation?

We're always looking for improvements to the documentation that make it more clear and accurate. Please let us know
how the documentation can be improved such as typos and any content that is missing, unclear or inaccurate. We'll be
happy to make the changes or help you make a contribution if you're interested!

## I want to become a maintainer of the project. How do I get there?

TGI is a project led and managed by Hugging Face as it powers our internal services. However, we are happy to have
motivated individuals from other organizations join us as maintainers with the goal of making TGI the best inference
service.

If you are such an individual (or organization), please reach out to us and let's collaborate.
2443  Cargo.lock  generated
File diff suppressed because it is too large
40  Cargo.toml
@@ -1,27 +1,53 @@
[workspace]
members = [
"benchmark",
"router",
"router/client",
"router/grpc-metadata",
"launcher"
"benchmark",
"backends/v2",
"backends/v3",
"backends/grpc-metadata",
"backends/trtllm",
"launcher",
"router"
]
default-members = [
"benchmark",
"backends/v2",
"backends/v3",
"backends/grpc-metadata",
# "backends/trtllm",
"launcher",
"router"
]
resolver = "2"

[workspace.package]
version = "2.0.4"
version = "2.3.1"
edition = "2021"
authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference"

[workspace.dependencies]
base64 = "0.22.0"
tokenizers = { version = "0.20.0", features = ["http"] }
hf-hub = { version = "0.3.1", features = ["tokio"] }
metrics = { version = "0.23.0" }
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
minijinja = { version = "2.2.0", features = ["json"] }
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
pyo3 = { version = "0.22.2", features = ["auto-initialize"] }

[profile.release]
incremental = true

[profile.release-binary]
inherits = "release"
debug = 1
incremental = true
panic = "abort"

[profile.release-opt]
inherits = "release"
debug = 0
incremental = false
lto = "fat"
opt-level = 3
codegen-units = 1
panic = "abort"
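The workspace gains a `release-opt` profile (fat LTO, a single codegen unit, `panic = "abort"`) alongside the version bump to 2.3.1; the Dockerfiles later in this diff build with it. A minimal sketch of invoking the profile locally, assuming a checkout of the workspace (the package name for the launcher is taken from the binaries installed below and is an assumption here):

```bash
# Build the whole workspace with the optimized profile; artifacts land in
# target/release-opt/ instead of target/release/.
cargo build --profile release-opt
# Build a single member, e.g. the launcher (package name assumed).
cargo build --profile release-opt -p text-generation-launcher
```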
48  Dockerfile
@@ -1,18 +1,24 @@
# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.78 AS chef
FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef
WORKDIR /usr/src

FROM chef as planner
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse

FROM chef AS planner
COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json

FROM chef AS builder

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
    python3.11-dev
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
    curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
    unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
@@ -20,22 +26,29 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
    rm -f $PROTOC_ZIP

COPY --from=planner /usr/src/recipe.json recipe.json
COPY Cargo.lock Cargo.lock
RUN cargo chef cook --release --recipe-path recipe.json
RUN cargo chef cook --profile release-opt --recipe-path recipe.json

ARG GIT_SHA
ARG DOCKER_LABEL

COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo build --release
RUN cargo build --profile release-opt

# Text Generation Inference base image
FROM vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest as base

ENV ATTENTION=default
ENV PREFIX_CACHING=0
ENV PREFILL_CHUNKING=0

# Text Generation Inference base env
ENV HUGGINGFACE_HUB_CACHE=/data \
ENV HF_HOME=/data \
    HF_HUB_ENABLE_HF_TRANSFER=1 \
    PORT=80

@@ -51,6 +64,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
    make \
    curl \
    git \
    python3.11-dev \
    && rm -rf /var/lib/apt/lists/*

# Install server
@@ -59,21 +73,33 @@ COPY server server
COPY server/Makefile server/Makefile
RUN cd server && \
    make gen-server && \
    pip install -r requirements.txt && \
    pip install --no-deps -r requirements.txt && \
    bash ./dill-0.3.8-patch.sh && \
    pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.18.0 && \
    BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
    pip install . --no-cache-dir

# Install benchmarker
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher

# AWS Sagemaker compatible image
FROM base AS sagemaker

COPY sagemaker-entrypoint.sh entrypoint.sh
RUN chmod +x entrypoint.sh

ENTRYPOINT ["./entrypoint.sh"]

# Final image
FROM base

ENTRYPOINT ["text-generation-launcher"]
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh

ENTRYPOINT ["/tgi-entrypoint.sh"]
CMD ["--json-output"]
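The image now ships the `release-opt` binaries and boots through `tgi-entrypoint.sh`. A hedged sketch of building and running it; the image tag, port mapping, volume and model id are placeholder assumptions, and any device-specific (e.g. Habana) runtime flags are omitted:

```bash
# Minimal sketch, not the project's documented workflow; tag, port,
# volume and model id below are placeholders.
docker build -t tgi-gaudi .
docker run --rm -p 8080:80 -v "$PWD/data:/data" tgi-gaudi \
    --model-id "$MODEL_ID" --json-output
```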
24  Dockerfile.nix  Normal file
@@ -0,0 +1,24 @@
# Build the image and get out the docker file:
#
# docker build -t tgi-nix-builder -f Dockerfile.nix
# docker run --log-driver=none tgi-nix-builder | docker load

FROM nixos/nix:2.18.8 AS builder
RUN echo "experimental-features = nix-command flakes" >> /etc/nix/nix.conf
RUN nix profile install nixpkgs#cachix
RUN cachix use text-generation-inference
WORKDIR /root
ADD . .
RUN nix build .
RUN mkdir /tmp/nix-store-closure
RUN cp -R $(nix-store -qR result/) /tmp/nix-store-closure

FROM ubuntu:24.04

WORKDIR /app

# Copy /nix/store
COPY --from=builder /tmp/nix-store-closure /nix/store
COPY --from=builder /root/result /app
RUN ldconfig
CMD ["ldconfig", "/app/bin/text-generation-launcher"]
23  Dockerfile.trtllm  Normal file
@@ -0,0 +1,23 @@
# All the tooling for CUDA
FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04 AS cuda-builder

WORKDIR /usr/src/tgi/backends/trtllm
RUN apt update && apt install -y cmake git git-lfs gcc g++ ninja-build libopenmpi-dev python3-dev python3-pip wget

COPY . /usr/src/tgi
RUN chmod +x scripts/install_tensorrt.sh && scripts/install_tensorrt.sh
RUN cmake -G Ninja -B build -DTRT_LIB_DIR=/usr/local/tensorrt/lib -DTRT_INCLUDE_DIR=/usr/local/tensorrt/include .
RUN cmake --build build --parallel -t tgi_trtllm_backend_impl

# All the tooling for Rust
FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
WORKDIR /usr/src

# Include CUDA related libraries and tools to the Rust based image
COPY --from=cuda-builder /usr/local/cuda /usr/local/cuda
COPY --from=cuda-builder /usr/local/tensorrt /usr/local/tensorrt
COPY --from=cuda-builder /usr/src/tgi/backends/trtllm/build /usr/local/tgi/trtllm/build
ENV PATH=/usr/local/cuda/bin:$PATH
ENV LD_LIBRARY_PATH=/usr/local/tensorrt/lib:$LD_LIBRARY_PATH

RUN apt update && apt install -y cmake git gcc g++ ninja-build libopenmpi3
233  Dockerfile_amd
@@ -1,23 +1,24 @@
|
||||
# Rust builder
|
||||
FROM lukemathwalker/cargo-chef:latest-rust-1.78 AS chef
|
||||
FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef
|
||||
WORKDIR /usr/src
|
||||
|
||||
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
|
||||
|
||||
FROM chef as planner
|
||||
FROM chef AS planner
|
||||
COPY Cargo.lock Cargo.lock
|
||||
COPY Cargo.toml Cargo.toml
|
||||
COPY rust-toolchain.toml rust-toolchain.toml
|
||||
COPY proto proto
|
||||
COPY benchmark benchmark
|
||||
COPY router router
|
||||
COPY backends backends
|
||||
COPY launcher launcher
|
||||
RUN cargo chef prepare --recipe-path recipe.json
|
||||
|
||||
FROM chef AS builder
|
||||
|
||||
ARG GIT_SHA
|
||||
ARG DOCKER_LABEL
|
||||
|
||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
||||
python3.11-dev
|
||||
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
|
||||
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
|
||||
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
|
||||
@ -25,18 +26,22 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
|
||||
rm -f $PROTOC_ZIP
|
||||
|
||||
COPY --from=planner /usr/src/recipe.json recipe.json
|
||||
RUN cargo chef cook --release --recipe-path recipe.json
|
||||
RUN cargo chef cook --profile release-opt --recipe-path recipe.json
|
||||
|
||||
ARG GIT_SHA
|
||||
ARG DOCKER_LABEL
|
||||
|
||||
COPY Cargo.toml Cargo.toml
|
||||
COPY rust-toolchain.toml rust-toolchain.toml
|
||||
COPY proto proto
|
||||
COPY benchmark benchmark
|
||||
COPY router router
|
||||
COPY backends backends
|
||||
COPY launcher launcher
|
||||
RUN cargo build --release
|
||||
RUN cargo build --profile release-opt
|
||||
|
||||
# Text Generation Inference base image for RoCm
|
||||
FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update as base
|
||||
FROM rocm/dev-ubuntu-22.04:6.2 AS base
|
||||
|
||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
@ -45,33 +50,34 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
|
||||
curl \
|
||||
git \
|
||||
make \
|
||||
libmsgpack-dev \
|
||||
libssl-dev \
|
||||
llvm-dev \
|
||||
g++ \
|
||||
# Needed to build VLLM & flash.
|
||||
rocthrust-dev \
|
||||
hipsparse-dev \
|
||||
hipblas-dev \
|
||||
hipblaslt-dev \
|
||||
hipcub-dev \
|
||||
rocblas-dev \
|
||||
hiprand-dev \
|
||||
hipfft-dev \
|
||||
rocrand-dev \
|
||||
miopen-hip-dev \
|
||||
hipfft-dev \
|
||||
hipcub-dev \
|
||||
hipsolver-dev \
|
||||
rccl-dev \
|
||||
cmake \
|
||||
python3-dev && \
|
||||
python3.11-venv && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Keep in sync with `server/pyproject.toml
|
||||
ARG MAMBA_VERSION=23.1.0-1
|
||||
ARG PYTORCH_VERSION='2.3.0'
|
||||
ARG ROCM_VERSION='6.0.2'
|
||||
ARG PYTHON_VERSION='3.10.10'
|
||||
ARG PYTHON_VERSION='3.11.10'
|
||||
# Automatically set by buildx
|
||||
ARG TARGETPLATFORM
|
||||
ENV PATH /opt/conda/bin:$PATH
|
||||
ENV PATH=/opt/conda/bin:$PATH
|
||||
|
||||
ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
|
||||
|
||||
# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
|
||||
# Install mamba
|
||||
@ -86,42 +92,141 @@ RUN chmod +x ~/mambaforge.sh && \
|
||||
mamba init && \
|
||||
rm ~/mambaforge.sh
|
||||
|
||||
# RUN conda install intel::mkl-static intel::mkl-include
|
||||
# Install pytorch
|
||||
# On arm64 we exit with an error code
|
||||
RUN case ${TARGETPLATFORM} in \
|
||||
"linux/arm64") exit 1 ;; \
|
||||
*) /opt/conda/bin/conda update -y conda && \
|
||||
/opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
|
||||
esac && \
|
||||
/opt/conda/bin/conda clean -ya
|
||||
|
||||
# Install flash-attention, torch dependencies
|
||||
RUN pip install numpy einops ninja --no-cache-dir
|
||||
RUN python3 -m pip install --upgrade pip && pip install numpy einops ninja joblib msgpack cmake --no-cache-dir && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN conda install intel::mkl-static intel::mkl-include
|
||||
RUN pip uninstall -y triton && \
|
||||
git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \
|
||||
cd triton/python && \
|
||||
pip install .
|
||||
RUN conda install mkl=2021
|
||||
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/opt/conda/lib/python3.11/site-packages/torch/lib:/opt/conda/lib/
|
||||
|
||||
RUN git clone --depth 1 --recursive --single-branch --branch 2.3-patched https://github.com/fxmarty/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir
|
||||
|
||||
ARG _GLIBCXX_USE_CXX11_ABI="1"
|
||||
ARG CMAKE_PREFIX_PATH="/opt/conda"
|
||||
ARG COMMON_WORKDIR=/
|
||||
WORKDIR ${COMMON_WORKDIR}
|
||||
|
||||
|
||||
# Install HIPBLASLt
|
||||
FROM base AS build_hipblaslt
|
||||
ARG HIPBLASLT_BRANCH="e6da924"
|
||||
RUN git clone https://github.com/ROCm/hipBLASLt.git \
|
||||
&& cd hipBLASLt \
|
||||
&& git checkout ${HIPBLASLT_BRANCH} \
|
||||
&& SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} --legacy_hipblas_direct \
|
||||
&& cd build/release \
|
||||
&& make package
|
||||
|
||||
FROM scratch AS export_hipblaslt
|
||||
ARG COMMON_WORKDIR
|
||||
COPY --from=build_hipblaslt ${COMMON_WORKDIR}/hipBLASLt/build/release/*.deb /
|
||||
|
||||
# RCCL build stages
|
||||
FROM base AS build_rccl
|
||||
ARG RCCL_BRANCH="rocm-6.2.0"
|
||||
RUN git clone https://github.com/ROCm/rccl \
|
||||
&& cd rccl \
|
||||
&& git checkout ${RCCL_BRANCH} \
|
||||
&& ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH}
|
||||
FROM scratch AS export_rccl
|
||||
ARG COMMON_WORKDIR
|
||||
COPY --from=build_rccl ${COMMON_WORKDIR}/rccl/build/release/*.deb /
|
||||
|
||||
# Triton build stages
|
||||
FROM base AS build_triton
|
||||
ARG TRITON_BRANCH="e192dba"
|
||||
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
|
||||
RUN python3 -m pip install ninja cmake wheel pybind11 && git clone ${TRITON_REPO} \
|
||||
&& cd triton \
|
||||
&& git checkout ${TRITON_BRANCH} \
|
||||
&& cd python \
|
||||
&& python3 setup.py bdist_wheel --dist-dir=dist
|
||||
FROM scratch AS export_triton
|
||||
ARG COMMON_WORKDIR
|
||||
COPY --from=build_triton ${COMMON_WORKDIR}/triton/python/dist/*.whl /
|
||||
|
||||
# # AMD-SMI build stages
|
||||
FROM base AS build_amdsmi
|
||||
RUN cd /opt/rocm/share/amd_smi \
|
||||
&& pip wheel . --wheel-dir=dist
|
||||
FROM scratch AS export_amdsmi
|
||||
COPY --from=build_amdsmi /opt/rocm/share/amd_smi/dist/*.whl /
|
||||
|
||||
|
||||
FROM base as build_pytorch
|
||||
|
||||
RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
|
||||
if ls /install/*.deb; then \
|
||||
dpkg -i /install/*.deb \
|
||||
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
|
||||
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
|
||||
fi
|
||||
|
||||
ARG BUILD_ENVIRONMENT=pytorch-linux-jammy-rocm6.2-py3.11
|
||||
ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
|
||||
ARG BUILD_CAFFE2="0" \
|
||||
BUILD_CAFFE2_OPS="0" \
|
||||
USE_CUDA="0" \
|
||||
USE_ROCM="1" \
|
||||
BUILD_TEST="0" \
|
||||
USE_FBGEMM="0" \
|
||||
USE_NNPACK="0" \
|
||||
USE_QNNPACK="0" \
|
||||
USE_XNNPACK="0" \
|
||||
USE_FLASH_ATTENTION="1" \
|
||||
USE_MEM_EFF_ATTENTION="0"
|
||||
|
||||
RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install
|
||||
# A commit to fix the output scaling factor issue in _scaled_mm
|
||||
# Not yet in 2.5.0-rc1
|
||||
ARG PYTORCH_BRANCH="cedc116"
|
||||
ARG PYTORCH_VISION_BRANCH="v0.19.1"
|
||||
ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
|
||||
|
||||
# Set as recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
|
||||
ENV HIP_FORCE_DEV_KERNARG=1
|
||||
RUN git clone ${PYTORCH_REPO} pytorch \
|
||||
&& cd pytorch && git checkout ${PYTORCH_BRANCH} && git submodule update --init --recursive \
|
||||
&& pip install -r requirements.txt --no-cache-dir \
|
||||
&& python tools/amd_build/build_amd.py \
|
||||
&& CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist
|
||||
FROM scratch as export_pytorch
|
||||
ARG COMMON_WORKDIR
|
||||
COPY --from=build_pytorch ${COMMON_WORKDIR}/pytorch/dist/*.whl /
|
||||
|
||||
# On MI250 and MI300, performances for flash with Triton FA are slightly better than CK.
|
||||
# However, Triton requires a tunning for each prompt length, which is prohibitive.
|
||||
ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0
|
||||
FROM base AS install_deps
|
||||
|
||||
FROM base AS kernel-builder
|
||||
ARG COMMON_WORKDIR
|
||||
|
||||
# Install hipblaslt
|
||||
RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
|
||||
if ls /install/*.deb; then \
|
||||
dpkg -i /install/*.deb \
|
||||
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
|
||||
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
|
||||
fi
|
||||
|
||||
RUN --mount=type=bind,from=export_rccl,src=/,target=/install \
|
||||
if ls /install/*.deb; then \
|
||||
dpkg -i /install/*.deb \
|
||||
# RCCL needs to be installed twice
|
||||
&& dpkg -i /install/*.deb \
|
||||
&& sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
|
||||
&& sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status; \
|
||||
fi
|
||||
|
||||
RUN --mount=type=bind,from=export_triton,src=/,target=/install \
|
||||
if ls /install/*.whl; then \
|
||||
# Preemptively uninstall to prevent pip same-version no-installs
|
||||
pip uninstall -y triton \
|
||||
&& pip install /install/*.whl; \
|
||||
fi
|
||||
|
||||
RUN --mount=type=bind,from=export_amdsmi,src=/,target=/install \
|
||||
# Preemptively uninstall to prevent pip same-version no-installs
|
||||
pip uninstall -y amdsmi \
|
||||
&& pip install /install/*.whl;
|
||||
|
||||
RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \
|
||||
if ls /install/*.whl; then \
|
||||
# Preemptively uninstall to prevent pip same-version no-installs
|
||||
pip uninstall -y torch torchvision \
|
||||
&& pip install /install/*.whl; \
|
||||
fi
|
||||
|
||||
FROM install_deps AS kernel-builder
|
||||
|
||||
# # Build vllm kernels
|
||||
FROM kernel-builder AS vllm-builder
|
||||
@ -142,46 +247,46 @@ COPY server/Makefile-flash-att-v2 Makefile
|
||||
RUN make build-flash-attention-v2-rocm
|
||||
|
||||
# Build Transformers CUDA kernels (gpt-neox and bloom)
|
||||
FROM kernel-builder as custom-kernels-builder
|
||||
FROM kernel-builder AS custom-kernels-builder
|
||||
WORKDIR /usr/src
|
||||
COPY server/custom_kernels/ .
|
||||
RUN python setup.py build
|
||||
|
||||
# Build exllama kernels
|
||||
FROM kernel-builder as exllama-kernels-builder
|
||||
FROM kernel-builder AS exllama-kernels-builder
|
||||
WORKDIR /usr/src
|
||||
COPY server/exllama_kernels/ .
|
||||
|
||||
RUN python setup.py build
|
||||
|
||||
# Build exllama v2 kernels
|
||||
FROM kernel-builder as exllamav2-kernels-builder
|
||||
FROM kernel-builder AS exllamav2-kernels-builder
|
||||
WORKDIR /usr/src
|
||||
COPY server/exllamav2_kernels/ .
|
||||
|
||||
RUN python setup.py build
|
||||
|
||||
FROM base as base-copy
|
||||
FROM install_deps AS base-copy
|
||||
|
||||
# Text Generation Inference base env
|
||||
ENV HUGGINGFACE_HUB_CACHE=/data \
|
||||
ENV HF_HOME=/data \
|
||||
HF_HUB_ENABLE_HF_TRANSFER=1 \
|
||||
PORT=80
|
||||
|
||||
# Copy builds artifacts from vllm builder
|
||||
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
||||
|
||||
# Copy build artifacts from flash attention v2 builder
|
||||
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
||||
|
||||
# Copy build artifacts from custom kernels builder
|
||||
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
||||
|
||||
# Copy build artifacts from exllama kernels builder
|
||||
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
||||
|
||||
# Copy build artifacts from exllamav2 kernels builder
|
||||
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
||||
|
||||
# Install server
|
||||
COPY proto proto
|
||||
@ -193,14 +298,15 @@ RUN cd server && \
|
||||
pip install ".[accelerate, peft, outlines]" --no-cache-dir
|
||||
|
||||
# Install benchmarker
|
||||
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
||||
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
||||
# Install router
|
||||
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
|
||||
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
|
||||
# Install launcher
|
||||
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
|
||||
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
|
||||
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
|
||||
|
||||
# AWS Sagemaker compatible image
|
||||
FROM base as sagemaker
|
||||
FROM base AS sagemaker
|
||||
|
||||
COPY sagemaker-entrypoint.sh entrypoint.sh
|
||||
RUN chmod +x entrypoint.sh
|
||||
@ -210,6 +316,19 @@ ENTRYPOINT ["./entrypoint.sh"]
|
||||
# Final image
|
||||
FROM base-copy
|
||||
|
||||
# Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
|
||||
ENV HIP_FORCE_DEV_KERNARG=1
|
||||
|
||||
# On MI250 and MI300, performances for flash with Triton FA are slightly better than CK.
|
||||
# However, Triton requires a tunning for each prompt length, which is prohibitive.
|
||||
ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0
|
||||
ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
|
||||
ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
|
||||
ENV VLLM_MOE_PADDING=0
|
||||
ENV ATTENTION=paged
|
||||
ENV USE_PREFIX_CACHING=0
|
||||
ENV ROCM_USE_SKINNY_GEMM=1
|
||||
|
||||
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
|
||||
RUN chmod +x /tgi-entrypoint.sh
|
||||
|
||||
|
159  Dockerfile_intel
@@ -1,22 +1,25 @@
|
||||
FROM lukemathwalker/cargo-chef:latest-rust-1.78 AS chef
|
||||
ARG PLATFORM=xpu
|
||||
|
||||
FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef
|
||||
WORKDIR /usr/src
|
||||
|
||||
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
|
||||
|
||||
FROM chef as planner
|
||||
FROM chef AS planner
|
||||
COPY Cargo.lock Cargo.lock
|
||||
COPY Cargo.toml Cargo.toml
|
||||
COPY rust-toolchain.toml rust-toolchain.toml
|
||||
COPY proto proto
|
||||
COPY benchmark benchmark
|
||||
COPY router router
|
||||
COPY backends backends
|
||||
COPY launcher launcher
|
||||
RUN cargo chef prepare --recipe-path recipe.json
|
||||
|
||||
FROM chef AS builder
|
||||
|
||||
ARG GIT_SHA
|
||||
ARG DOCKER_LABEL
|
||||
|
||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
||||
python3.11-dev
|
||||
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
|
||||
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
|
||||
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
|
||||
@ -24,21 +27,52 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
|
||||
rm -f $PROTOC_ZIP
|
||||
|
||||
COPY --from=planner /usr/src/recipe.json recipe.json
|
||||
RUN cargo chef cook --release --recipe-path recipe.json
|
||||
RUN cargo chef cook --profile release-opt --recipe-path recipe.json
|
||||
|
||||
ARG GIT_SHA
|
||||
ARG DOCKER_LABEL
|
||||
|
||||
COPY Cargo.toml Cargo.toml
|
||||
COPY rust-toolchain.toml rust-toolchain.toml
|
||||
COPY proto proto
|
||||
COPY benchmark benchmark
|
||||
COPY router router
|
||||
COPY backends backends
|
||||
COPY launcher launcher
|
||||
RUN cargo build --release
|
||||
RUN cargo build --profile release-opt
|
||||
|
||||
|
||||
# Text Generation Inference base image for Intel
|
||||
FROM intel/intel-extension-for-pytorch:2.1.30-xpu as base
|
||||
|
||||
FROM intel/intel-extension-for-pytorch:2.3.110-xpu AS xpu
|
||||
|
||||
USER root
|
||||
|
||||
ARG MAMBA_VERSION=23.1.0-1
|
||||
ARG PYTHON_VERSION='3.11.10'
|
||||
# Automatically set by buildx
|
||||
ARG TARGETPLATFORM
|
||||
ENV PATH /opt/conda/bin:$PATH
|
||||
|
||||
# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
|
||||
# Install mamba
|
||||
# translating Docker's TARGETPLATFORM into mamba arches
|
||||
RUN case ${TARGETPLATFORM} in \
|
||||
"linux/arm64") MAMBA_ARCH=aarch64 ;; \
|
||||
*) MAMBA_ARCH=x86_64 ;; \
|
||||
esac && \
|
||||
curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
|
||||
RUN chmod +x ~/mambaforge.sh && \
|
||||
bash ~/mambaforge.sh -b -p /opt/conda && \
|
||||
rm ~/mambaforge.sh
|
||||
|
||||
RUN case ${TARGETPLATFORM} in \
|
||||
"linux/arm64") exit 1 ;; \
|
||||
*) /opt/conda/bin/conda update -y conda && \
|
||||
/opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
|
||||
esac && \
|
||||
/opt/conda/bin/conda clean -ya
|
||||
|
||||
# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it
|
||||
RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
|
||||
dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb
|
||||
@ -48,17 +82,16 @@ RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dea
|
||||
RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
|
||||
| gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list
|
||||
|
||||
RUN apt-get update && apt install -y intel-basekit xpu-smi
|
||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y intel-basekit xpu-smi cmake ninja-build pciutils
|
||||
|
||||
# Text Generation Inference base env
|
||||
ENV HUGGINGFACE_HUB_CACHE=/data \
|
||||
ENV HF_HOME=/data \
|
||||
HF_HUB_ENABLE_HF_TRANSFER=1 \
|
||||
PORT=80
|
||||
|
||||
|
||||
WORKDIR /usr/src
|
||||
RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.1.30a0-cp310-cp310-linux_x86_64.whl
|
||||
RUN pip install intel_extension_for_pytorch-2.1.30a0-cp310-cp310-linux_x86_64.whl
|
||||
RUN pip install torch==2.3.1+cxx11.abi torchvision==0.18.1+cxx11.abi torchaudio==2.3.1+cxx11.abi intel-extension-for-pytorch==2.3.110+xpu oneccl_bind_pt==2.3.100+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --no-cache-dir
|
||||
|
||||
# Install server
|
||||
COPY proto proto
|
||||
@ -66,26 +99,112 @@ COPY server server
|
||||
COPY server/Makefile server/Makefile
|
||||
RUN cd server && \
|
||||
make gen-server && \
|
||||
pip install -r requirements_cuda.txt && \
|
||||
pip install -r requirements_intel.txt && \
|
||||
pip install ".[accelerate, peft, outlines]" --no-cache-dir
|
||||
|
||||
ENV CCL_ROOT=/opt/intel/oneapi/ccl/latest
|
||||
ENV I_MPI_ROOT=/opt/intel/oneapi/mpi/latest
|
||||
ENV FI_PROVIDER_PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib/prov:/usr/lib/x86_64-linux-gnu/libfabric
|
||||
ENV LIBRARY_PATH=/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mkl/latest/lib/:/opt/intel/oneapi/compiler/latest/lib
|
||||
ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:
|
||||
ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
|
||||
ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:/opt/conda/lib
|
||||
ENV PATH=/opt/conda/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
|
||||
ENV CCL_ZE_IPC_EXCHANGE=sockets
|
||||
ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest
|
||||
ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include
|
||||
|
||||
# Install benchmarker
|
||||
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
||||
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
||||
# Install router
|
||||
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
|
||||
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
|
||||
# Install launcher
|
||||
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
|
||||
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
|
||||
|
||||
# Final image
|
||||
FROM base
|
||||
|
||||
# Text Generation Inference base image for Intel-cpu
|
||||
FROM ubuntu:22.04 AS cpu
|
||||
|
||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
||||
curl \
|
||||
ca-certificates \
|
||||
make \
|
||||
g++ \
|
||||
git \
|
||||
wget \
|
||||
cmake \
|
||||
libnuma-dev
|
||||
|
||||
ENV HUGGINGFACE_HUB_CACHE=/data \
|
||||
HF_HUB_ENABLE_HF_TRANSFER=1 \
|
||||
PORT=80
|
||||
|
||||
ARG MAMBA_VERSION=23.1.0-1
|
||||
ARG PYTHON_VERSION='3.11.10'
|
||||
# Automatically set by buildx
|
||||
ARG TARGETPLATFORM
|
||||
ENV PATH /opt/conda/bin:$PATH
|
||||
|
||||
# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
|
||||
# Install mamba
|
||||
# translating Docker's TARGETPLATFORM into mamba arches
|
||||
RUN case ${TARGETPLATFORM} in \
|
||||
"linux/arm64") MAMBA_ARCH=aarch64 ;; \
|
||||
*) MAMBA_ARCH=x86_64 ;; \
|
||||
esac && \
|
||||
curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
|
||||
RUN chmod +x ~/mambaforge.sh && \
|
||||
bash ~/mambaforge.sh -b -p /opt/conda && \
|
||||
rm ~/mambaforge.sh
|
||||
|
||||
RUN case ${TARGETPLATFORM} in \
|
||||
"linux/arm64") exit 1 ;; \
|
||||
*) /opt/conda/bin/conda update -y conda && \
|
||||
/opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
|
||||
esac && \
|
||||
/opt/conda/bin/conda clean -ya
|
||||
|
||||
RUN conda install -c conda-forge gperftools mkl
|
||||
|
||||
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp311-cp311-linux_x86_64.whl
|
||||
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp311-cp311-linux_x86_64.whl
|
||||
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp311-cp311-linux_x86_64.whl
|
||||
RUN pip install triton py-libnuma
|
||||
|
||||
WORKDIR /usr/src
|
||||
|
||||
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout eda7a7c42df6f9a64e0de9c2b69304ee02f2c32a
|
||||
|
||||
RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout ccl_torch_dev_0131
|
||||
|
||||
RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install
|
||||
|
||||
RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install .
|
||||
|
||||
ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so
|
||||
ENV CCL_ROOT=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch
|
||||
ENV I_MPI_ROOT=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch
|
||||
ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric
|
||||
ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch/lib
|
||||
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
|
||||
|
||||
# Install server
|
||||
COPY proto proto
|
||||
COPY server server
|
||||
COPY server/Makefile server/Makefile
|
||||
RUN cd server && \
|
||||
make gen-server && \
|
||||
pip install -r requirements_intel.txt && \
|
||||
pip install ".[accelerate, peft, outlines]" --no-cache-dir
|
||||
|
||||
# Install benchmarker
|
||||
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
||||
# Install router
|
||||
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
|
||||
# Install launcher
|
||||
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
|
||||
|
||||
FROM ${PLATFORM} AS final
|
||||
ENV ATTENTION=paged
|
||||
ENV USE_PREFIX_CACHING=0
|
||||
ENV CUDA_GRAPHS=0
|
||||
ENTRYPOINT ["text-generation-launcher"]
|
||||
CMD ["--json-output"]
|
||||
|
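`Dockerfile_intel` above declares `ARG PLATFORM=xpu` and finishes with `FROM ${PLATFORM} AS final`, so the same file can produce either the XPU or the CPU variant. A minimal sketch of selecting the stage at build time (the image tags are assumptions):

```bash
# Default (XPU) image.
docker build -f Dockerfile_intel -t tgi-intel-xpu .
# CPU-only image, selecting the "cpu" stage via the PLATFORM build arg.
docker build -f Dockerfile_intel --build-arg PLATFORM=cpu -t tgi-intel-cpu .
```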
23  Makefile
@@ -1,20 +1,22 @@
install-server:
    cd server && make install

install-integration-tests:
    cd integration-tests && pip install -r requirements.txt
    cd clients/python && pip install .
install-server-cpu:
    cd server && make install-server

install-router:
    cd router && cargo install --locked --path .
    cargo install --path backends/v3/

install-launcher:
    cd launcher && cargo install --locked --path .
    cargo install --path launcher/

install-benchmark:
    cd benchmark && cargo install --locked --path .
    cargo install --path benchmark/

install: install-server install-router install-launcher install-custom-kernels
install: install-server install-router install-launcher

install-cpu: install-server-cpu install-router install-launcher

server-dev:
    cd server && make run-dev
@@ -25,6 +27,10 @@ router-dev:
rust-tests: install-router install-launcher
    cargo test

install-integration-tests:
    cd integration-tests && pip install -r requirements.txt
    cd clients/python && pip install .

integration-tests: install-integration-tests
    pytest -s -vv -m "not private" integration-tests

@@ -44,6 +50,3 @@ run-falcon-7b-instruct:

clean:
    rm -rf target aml

debug_image_build:
    docker build --no-cache --progress=plain -t debug_tgi .
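The Makefile now installs the router from `backends/v3` and drops `install-custom-kernels` from the default `install` target. A short usage sketch of the targets shown above, run from the repository root:

```bash
# Run from the repository root.
make install              # install-server + install-router + install-launcher
make rust-tests           # cargo test after installing router and launcher
make integration-tests    # pytest-based integration suite
```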
@@ -6,10 +6,11 @@ authors.workspace = true
homepage.workspace = true

[dependencies]
async-trait = "^0.1"
base64 = { workspace = true }
futures = "^0.3"
grpc-metadata = { path = "../grpc-metadata" }
prost = "^0.12"
rand = "0.8.5"
thiserror = "^1.0"
tokio = { version = "^1.32", features = ["sync"] }
tonic = "^0.10"
35 backends/client/build.rs Normal file
@@ -0,0 +1,35 @@
|
||||
use std::fs;
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
println!("cargo:rerun-if-changed=../../proto/");
|
||||
|
||||
fs::create_dir_all("src/v2/pb").unwrap_or(());
|
||||
let mut config = prost_build::Config::new();
|
||||
config.protoc_arg("--experimental_allow_proto3_optional");
|
||||
|
||||
tonic_build::configure()
|
||||
.build_client(true)
|
||||
.build_server(false)
|
||||
.out_dir("src/v2/pb")
|
||||
.include_file("mod.rs")
|
||||
.compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"])
|
||||
.map_err(|e| match e.kind(){
|
||||
std::io::ErrorKind::NotFound => {panic!("`protoc` not found, install libprotoc")},
|
||||
std::io::ErrorKind::Other => {panic!("`protoc` version unsupported, upgrade protoc: https://github.com/protocolbuffers/protobuf/releases")},
|
||||
e => {e}
|
||||
}).unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
|
||||
|
||||
fs::create_dir_all("src/v3/pb").unwrap_or(());
|
||||
let mut config = prost_build::Config::new();
|
||||
config.protoc_arg("--experimental_allow_proto3_optional");
|
||||
|
||||
tonic_build::configure()
|
||||
.build_client(true)
|
||||
.build_server(false)
|
||||
.out_dir("src/v3/pb")
|
||||
.include_file("mod.rs")
|
||||
.compile_with_config(config, &["../../proto/v3/generate.proto"], &["../../proto"])
|
||||
.unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
|
||||
|
||||
Ok(())
|
||||
}
|
91 backends/client/src/lib.rs Normal file
@@ -0,0 +1,91 @@
|
||||
//! Text Generation gRPC client library
|
||||
|
||||
pub mod v2;
|
||||
pub mod v3;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use base64::{engine::general_purpose::STANDARD, Engine};
|
||||
use thiserror::Error;
|
||||
use tonic::transport;
|
||||
use tonic::Status;
|
||||
|
||||
pub use v3::{Chunk, Image, Input, InputChunk};
|
||||
|
||||
#[async_trait]
|
||||
pub trait Health {
|
||||
/// Check if a generate server is healthy by asking it to allocate a tensor on device
|
||||
async fn device_health(&self) -> Result<()>;
|
||||
|
||||
/// Check if a generate server is healthy by doing a forward pass.
|
||||
/// EXPENSIVE
|
||||
async fn model_health(&self) -> Result<()>;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ShardInfo {
|
||||
pub requires_padding: bool,
|
||||
pub dtype: String,
|
||||
pub device_type: String,
|
||||
pub window_size: Option<u32>,
|
||||
pub speculate: u32,
|
||||
}
|
||||
|
||||
#[derive(Error, Debug, Clone)]
|
||||
pub enum ClientError {
|
||||
#[error("Could not connect to Text Generation server: {0}")]
|
||||
Connection(String),
|
||||
#[error("Server error: {0}")]
|
||||
Generation(String),
|
||||
#[error("Sharded results are empty")]
|
||||
EmptyResults,
|
||||
}
|
||||
|
||||
impl From<Status> for ClientError {
|
||||
fn from(err: Status) -> Self {
|
||||
let err = Self::Generation(err.message().to_string());
|
||||
tracing::error!("{err}");
|
||||
err
|
||||
}
|
||||
}
|
||||
|
||||
impl From<transport::Error> for ClientError {
|
||||
fn from(err: transport::Error) -> Self {
|
||||
let err = Self::Connection(err.to_string());
|
||||
tracing::error!("{err}");
|
||||
err
|
||||
}
|
||||
}
|
||||
|
||||
// Small convenience re-wrapping of `Chunk`.
|
||||
impl From<Chunk> for InputChunk {
|
||||
fn from(chunk: Chunk) -> Self {
|
||||
InputChunk { chunk: Some(chunk) }
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert input chunks to a stringly-typed input for backwards
|
||||
/// compat for backends that haven't implemented chunked inputs.
|
||||
pub trait ChunksToString {
|
||||
/// Convert chunks to string.
|
||||
fn chunks_to_string(&self) -> String;
|
||||
}
|
||||
|
||||
impl ChunksToString for Vec<InputChunk> {
|
||||
fn chunks_to_string(&self) -> String {
|
||||
let mut output = String::new();
|
||||
self.iter().for_each(|c| match &c.chunk {
|
||||
Some(Chunk::Text(text)) => output.push_str(text),
|
||||
Some(Chunk::Image(Image { data, mimetype })) => {
|
||||
let encoded = STANDARD.encode(data);
|
||||
output.push_str(&format!("", mimetype, encoded))
|
||||
}
|
||||
// We don't create empty chunks, so this should be unreachable.
|
||||
None => unreachable!("Chunks should never be empty"),
|
||||
});
|
||||
output
|
||||
}
|
||||
}
|
||||
|
||||
static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
|
||||
|
||||
pub type Result<T> = std::result::Result<T, ClientError>;
|
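As an aside, here is a minimal usage sketch of the `ChunksToString` trait defined above, flattening a mixed text/image input into the legacy string form. It assumes the client crate is importable as `text_generation_client` (the crate name is not shown in this diff) and uses placeholder image bytes.

```rust
// Illustrative sketch only. Assumes the client crate is importable as
// `text_generation_client`; the text and image bytes are placeholders.
use text_generation_client::{Chunk, ChunksToString, Image, InputChunk};

fn main() {
    let chunks: Vec<InputChunk> = vec![
        Chunk::Text("Describe this image: ".to_string()).into(),
        Chunk::Image(Image {
            data: vec![0u8; 16], // placeholder bytes, not a real image
            mimetype: "image/png".to_string(),
        })
        .into(),
    ];

    // Text is appended verbatim; images are rendered as inline base64 markdown.
    let flattened = chunks.chunks_to_string();
    println!("{flattened}");
}
```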
260 backends/client/src/v2/client.rs Normal file
@@ -0,0 +1,260 @@
|
||||
/// Single shard Client
|
||||
use crate::v2::pb;
|
||||
use crate::{ClientError, Result};
|
||||
|
||||
use crate::WARMUP_IMAGE_BASE64;
|
||||
use grpc_metadata::InjectTelemetryContext;
|
||||
use pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;
|
||||
use pb::generate::v2::*;
|
||||
use std::cmp::min;
|
||||
use std::time::Duration;
|
||||
use tonic::transport::{Channel, Uri};
|
||||
use tracing::instrument;
|
||||
|
||||
/// Text Generation Inference gRPC client
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Client {
|
||||
stub: TextGenerationServiceClient<Channel>,
|
||||
}
|
||||
|
||||
impl Client {
|
||||
/// Returns a client connected to the given url
|
||||
pub async fn connect(uri: Uri) -> Result<Self> {
|
||||
let channel = Channel::builder(uri).connect().await?;
|
||||
|
||||
Ok(Self {
|
||||
stub: TextGenerationServiceClient::new(channel),
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns a client connected to the given unix socket
|
||||
pub async fn connect_uds(path: String) -> Result<Self> {
|
||||
let channel = Channel::from_shared("http://[::]:50051".to_string())
|
||||
.unwrap()
|
||||
.connect_with_connector(tower::service_fn(move |_: Uri| {
|
||||
tokio::net::UnixStream::connect(path.clone())
|
||||
}))
|
||||
.await?;
|
||||
|
||||
Ok(Self {
|
||||
stub: TextGenerationServiceClient::new(channel),
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns a list of uris or unix sockets of all shards
|
||||
#[instrument(skip(self))]
|
||||
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
|
||||
let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
|
||||
let response = self.stub.service_discovery(request).await.map_err(|_| {
|
||||
ClientError::Connection("Server does not support v2 interface".to_string())
|
||||
})?;
|
||||
let urls = response
|
||||
.into_inner()
|
||||
.urls
|
||||
.into_iter()
|
||||
// Remove unix socket prefix
|
||||
.map(|url| match url.strip_prefix("unix://") {
|
||||
None => url,
|
||||
Some(stripped_url) => stripped_url.to_string(),
|
||||
})
|
||||
.collect();
|
||||
Ok(urls)
|
||||
}
|
||||
|
||||
/// Get model info
|
||||
#[instrument(skip(self))]
|
||||
pub async fn info(&mut self) -> Result<InfoResponse> {
|
||||
let request = tonic::Request::new(InfoRequest {}).inject_context();
|
||||
let response = self.stub.info(request).await?.into_inner();
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// Get model health
|
||||
#[instrument(skip(self))]
|
||||
pub async fn health(&mut self) -> Result<HealthResponse> {
|
||||
let request = tonic::Request::new(HealthRequest {}).inject_context();
|
||||
let response = self.stub.health(request).await?.into_inner();
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// Clear the past generations cache
|
||||
#[instrument(skip(self))]
|
||||
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
||||
let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
|
||||
self.stub.clear_cache(request).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Filter a cached batch
|
||||
#[instrument(skip(self))]
|
||||
pub async fn filter_batch(
|
||||
&mut self,
|
||||
batch_id: u64,
|
||||
request_ids: Vec<u64>,
|
||||
) -> Result<Option<CachedBatch>> {
|
||||
let request = tonic::Request::new(FilterBatchRequest {
|
||||
batch_id,
|
||||
request_ids,
|
||||
})
|
||||
.inject_context();
|
||||
let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
|
||||
Ok(filtered_batch.batch)
|
||||
}
|
||||
|
||||
/// Warmup on a max size batch
|
||||
///
|
||||
/// Returns the maximum number of tokens supported by the hardware
|
||||
#[instrument(skip_all)]
|
||||
pub async fn warmup(
|
||||
&mut self,
|
||||
max_input_length: u32,
|
||||
max_prefill_tokens: u32,
|
||||
max_total_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_batch_size: Option<usize>,
|
||||
) -> Result<Option<u32>> {
|
||||
let mut n_tokens = 0;
|
||||
let mut requests = Vec::new();
|
||||
// Create requests
|
||||
while n_tokens < max_prefill_tokens {
|
||||
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
|
||||
|
||||
let mut inputs = String::new();
|
||||
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
|
||||
if n_tokens == 0 {
|
||||
// 1 request is enough to test vision heads.
// Sending images on other queries easily interferes with truncation.
inputs.push_str(&format!(
    "![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})",
));
|
||||
}
|
||||
|
||||
requests.push(Request {
|
||||
id: 0,
|
||||
inputs,
|
||||
// We truncate the input on the server side to be sure that it has the correct size
|
||||
truncate,
|
||||
// Set sampling parameters to also take these ops into account in the max memory
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
temperature: 0.9,
|
||||
top_k: 10,
|
||||
top_p: 0.9,
|
||||
typical_p: 0.9,
|
||||
do_sample: false,
|
||||
seed: 0,
|
||||
repetition_penalty: 1.2,
|
||||
frequency_penalty: 0.1,
|
||||
watermark: true,
|
||||
grammar: String::new(),
|
||||
grammar_type: GrammarType::None as i32,
|
||||
}),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||
max_new_tokens: max_total_tokens - truncate,
|
||||
stop_sequences: vec![],
|
||||
ignore_eos_token: true,
|
||||
}),
|
||||
prefill_logprobs: true,
|
||||
top_n_tokens: 20,
|
||||
});
|
||||
n_tokens += max_input_length;
|
||||
|
||||
// Check max_batch_size
|
||||
if Some(requests.len()) == max_batch_size {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let batch = Batch {
|
||||
id: 0,
|
||||
size: requests.len() as u32,
|
||||
requests,
|
||||
max_tokens: 0,
|
||||
};
|
||||
|
||||
let request = tonic::Request::new(WarmupRequest {
|
||||
batch: Some(batch),
|
||||
max_input_length,
|
||||
max_prefill_tokens,
|
||||
max_total_tokens,
|
||||
max_batch_total_tokens,
|
||||
})
|
||||
.inject_context();
|
||||
let response = self.stub.warmup(request).await?.into_inner();
|
||||
Ok(response.max_supported_total_tokens)
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given batch
|
||||
///
|
||||
/// Returns Generation for each request in batch
|
||||
/// and the next cached batch
|
||||
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
|
||||
pub async fn prefill(
|
||||
&mut self,
|
||||
batch: Batch,
|
||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
|
||||
let response = self.stub.prefill(request).await?.into_inner();
|
||||
Ok((
|
||||
response.generations,
|
||||
response.batch,
|
||||
PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
|
||||
))
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given cached batches
|
||||
///
|
||||
/// Returns Generation for each request in batches
|
||||
/// and the next cached batch
|
||||
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
|
||||
pub async fn decode(
|
||||
&mut self,
|
||||
batches: Vec<CachedBatch>,
|
||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
|
||||
let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
|
||||
let response = self.stub.decode(request).await?.into_inner();
|
||||
Ok((
|
||||
response.generations,
|
||||
response.batch,
|
||||
DecodeTimings::new(
|
||||
response.concat_ns,
|
||||
response.forward_ns,
|
||||
response.decode_ns,
|
||||
response.total_ns,
|
||||
),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PrefillTimings {
|
||||
pub forward: Duration,
|
||||
pub decode: Duration,
|
||||
pub total: Duration,
|
||||
}
|
||||
|
||||
impl PrefillTimings {
|
||||
fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
|
||||
Self {
|
||||
forward: Duration::from_nanos(forward_ns),
|
||||
decode: Duration::from_nanos(decode_ns),
|
||||
total: Duration::from_nanos(total_ns),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DecodeTimings {
|
||||
pub concat: Option<Duration>,
|
||||
pub forward: Duration,
|
||||
pub decode: Duration,
|
||||
pub total: Duration,
|
||||
}
|
||||
|
||||
impl DecodeTimings {
|
||||
fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
|
||||
Self {
|
||||
concat: concat_ns.map(Duration::from_nanos),
|
||||
forward: Duration::from_nanos(forward_ns),
|
||||
decode: Duration::from_nanos(decode_ns),
|
||||
total: Duration::from_nanos(total_ns),
|
||||
}
|
||||
}
|
||||
}
|
13 backends/client/src/v2/mod.rs Normal file
@@ -0,0 +1,13 @@
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
mod pb;
|
||||
|
||||
mod client;
|
||||
mod sharded_client;
|
||||
|
||||
pub use client::Client;
|
||||
pub use pb::generate::v2::HealthResponse;
|
||||
pub use pb::generate::v2::{
|
||||
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, InfoResponse,
|
||||
NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
|
||||
};
|
||||
pub use sharded_client::ShardedClient;
|
@@ -1,12 +1,17 @@
|
||||
/// Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
|
||||
|
||||
use crate::client::{DecodeTimings, PrefillTimings};
|
||||
/// Multi shard Client
|
||||
use crate::{Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo};
|
||||
use crate::{v2, Health, ShardInfo};
|
||||
use crate::{ClientError, Result};
|
||||
|
||||
use crate::v2::InfoResponse;
|
||||
use async_trait::async_trait;
|
||||
use futures::future::join_all;
|
||||
use tonic::transport::Uri;
|
||||
use tracing::instrument;
|
||||
use v2::client::{DecodeTimings, PrefillTimings};
|
||||
use v2::{
|
||||
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
|
||||
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// Text Generation Inference gRPC multi client
|
||||
@@ -49,7 +54,7 @@ impl ShardedClient {
|
||||
.iter_mut()
|
||||
.map(|client| client.info())
|
||||
.collect();
|
||||
join_all(futures).await.pop().unwrap()
|
||||
join_all(futures).await.pop().unwrap().map(ShardInfo::from)
|
||||
}
|
||||
|
||||
/// GRPC health check
|
||||
@@ -99,8 +104,8 @@ impl ShardedClient {
|
||||
max_input_length: u32,
|
||||
max_prefill_tokens: u32,
|
||||
max_total_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_batch_size: Option<usize>,
|
||||
model_id: &str,
|
||||
) -> Result<Option<u32>> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
@@ -110,8 +115,8 @@ impl ShardedClient {
|
||||
max_input_length,
|
||||
max_prefill_tokens,
|
||||
max_total_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_batch_size,
|
||||
model_id
|
||||
))
|
||||
})
|
||||
.collect();
|
||||
@@ -189,3 +194,60 @@ impl ShardedClient {
|
||||
Ok((generations, next_batch, timings))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<InfoResponse> for ShardInfo {
|
||||
fn from(value: InfoResponse) -> Self {
|
||||
Self {
|
||||
requires_padding: value.requires_padding,
|
||||
dtype: value.dtype,
|
||||
device_type: value.device_type,
|
||||
window_size: value.window_size,
|
||||
speculate: value.speculate,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Health for ShardedClient {
|
||||
async fn device_health(&self) -> Result<()> {
|
||||
self.clone().health().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn model_health(&self) -> Result<()> {
|
||||
// Dummy batch of 1 token and 1 generated token
|
||||
let liveness_request = Request {
|
||||
id: u64::MAX,
|
||||
inputs: "liveness".to_string(),
|
||||
truncate: 10,
|
||||
prefill_logprobs: false,
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
temperature: 1.0,
|
||||
top_k: 0,
|
||||
top_p: 1.0,
|
||||
typical_p: 1.0,
|
||||
do_sample: false,
|
||||
seed: 0,
|
||||
repetition_penalty: 1.0,
|
||||
frequency_penalty: 0.0,
|
||||
watermark: false,
|
||||
grammar: String::new(),
|
||||
grammar_type: GrammarType::None as i32,
|
||||
}),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||
max_new_tokens: 1,
|
||||
stop_sequences: vec![],
|
||||
ignore_eos_token: false,
|
||||
}),
|
||||
top_n_tokens: 0,
|
||||
};
|
||||
let batch = Batch {
|
||||
id: u64::MAX,
|
||||
requests: vec![liveness_request],
|
||||
size: 1,
|
||||
max_tokens: 2,
|
||||
};
|
||||
self.clone().prefill(batch).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
288 backends/client/src/v3/client.rs Normal file
@@ -0,0 +1,288 @@
|
||||
use crate::v3::{pb, Chunk};
|
||||
use crate::{ClientError, Result, WARMUP_IMAGE_BASE64};
|
||||
/// Single shard Client
|
||||
use base64::engine::general_purpose::STANDARD;
|
||||
use base64::Engine;
|
||||
use grpc_metadata::InjectTelemetryContext;
|
||||
use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
|
||||
use pb::generate::v3::*;
|
||||
use std::cmp::min;
|
||||
use std::time::Duration;
|
||||
use tonic::transport::{Channel, Uri};
|
||||
use tracing::instrument;
|
||||
|
||||
/// Text Generation Inference gRPC client
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Client {
|
||||
stub: TextGenerationServiceClient<Channel>,
|
||||
}
|
||||
|
||||
impl Client {
|
||||
/// Returns a client connected to the given url
|
||||
pub async fn connect(uri: Uri) -> Result<Self> {
|
||||
let channel = Channel::builder(uri).connect().await?;
|
||||
|
||||
Ok(Self {
|
||||
stub: TextGenerationServiceClient::new(channel),
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns a client connected to the given unix socket
|
||||
pub async fn connect_uds(path: String) -> Result<Self> {
|
||||
let channel = Channel::from_shared("http://[::]:50051".to_string())
|
||||
.unwrap()
|
||||
.connect_with_connector(tower::service_fn(move |_: Uri| {
|
||||
tokio::net::UnixStream::connect(path.clone())
|
||||
}))
|
||||
.await?;
|
||||
|
||||
Ok(Self {
|
||||
stub: TextGenerationServiceClient::new(channel),
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns a list of uris or unix sockets of all shards
|
||||
#[instrument(skip(self))]
|
||||
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
|
||||
let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
|
||||
let response = self.stub.service_discovery(request).await.map_err(|_| {
|
||||
ClientError::Connection("Server does not support v3 interface".to_string())
|
||||
})?;
|
||||
let urls = response
|
||||
.into_inner()
|
||||
.urls
|
||||
.into_iter()
|
||||
// Remove unix socket prefix
|
||||
.map(|url| match url.strip_prefix("unix://") {
|
||||
None => url,
|
||||
Some(stripped_url) => stripped_url.to_string(),
|
||||
})
|
||||
.collect();
|
||||
Ok(urls)
|
||||
}
|
||||
|
||||
/// Get model info
|
||||
#[instrument(skip(self))]
|
||||
pub async fn info(&mut self) -> Result<InfoResponse> {
|
||||
let request = tonic::Request::new(InfoRequest {}).inject_context();
|
||||
let response = self.stub.info(request).await?.into_inner();
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// Get model health
|
||||
#[instrument(skip(self))]
|
||||
pub async fn health(&mut self) -> Result<HealthResponse> {
|
||||
let request = tonic::Request::new(HealthRequest {}).inject_context();
|
||||
let response = self.stub.health(request).await?.into_inner();
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// Clear the past generations cache
|
||||
#[instrument(skip(self))]
|
||||
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
||||
let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
|
||||
self.stub.clear_cache(request).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Filter a cached batch
|
||||
#[instrument(skip(self))]
|
||||
pub async fn filter_batch(
|
||||
&mut self,
|
||||
batch_id: u64,
|
||||
request_ids: Vec<u64>,
|
||||
) -> Result<Option<CachedBatch>> {
|
||||
let request = tonic::Request::new(FilterBatchRequest {
|
||||
batch_id,
|
||||
request_ids,
|
||||
})
|
||||
.inject_context();
|
||||
let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
|
||||
Ok(filtered_batch.batch)
|
||||
}
|
||||
|
||||
/// Warmup on a max size batch
|
||||
///
|
||||
/// Returns the maximum number of tokens supported by the hardware
|
||||
#[instrument(skip_all)]
|
||||
pub async fn warmup(
|
||||
&mut self,
|
||||
max_input_length: u32,
|
||||
max_prefill_tokens: u32,
|
||||
max_total_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_batch_size: Option<usize>,
|
||||
) -> Result<Option<u32>> {
|
||||
let mut n_tokens = 0;
|
||||
let mut requests = Vec::new();
|
||||
// Create requests
|
||||
while n_tokens < max_prefill_tokens {
|
||||
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
|
||||
|
||||
let mut input_chunks = Vec::new();
|
||||
input_chunks
|
||||
.push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
|
||||
if n_tokens == 0 {
|
||||
input_chunks.push(
|
||||
Chunk::Image(Image {
|
||||
// Safe unwrap, because we control the data.
|
||||
data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(),
|
||||
mimetype: "image/jpeg;base64".to_string(),
|
||||
})
|
||||
.into(),
|
||||
);
|
||||
}
|
||||
|
||||
// Send stringly-typed inputs for compatibility with backends that haven't
|
||||
// been updated to support chunks.
|
||||
|
||||
let mut inputs = String::new();
|
||||
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
|
||||
if n_tokens == 0 {
|
||||
// 1 request is enough to test vision heads.
// Sending images on other queries easily interferes with truncation.
inputs.push_str(&format!(
    "![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})",
));
|
||||
}
|
||||
|
||||
requests.push(Request {
|
||||
id: 0,
|
||||
inputs,
|
||||
input_chunks: Some(Input {
|
||||
chunks: input_chunks,
|
||||
}),
|
||||
// We truncate the input on the server side to be sure that it has the correct size
|
||||
truncate,
|
||||
// Most requests will have this set
|
||||
add_special_tokens: true,
|
||||
// Blocks and slots will be set on the server side if we use paged attention
|
||||
blocks: vec![],
|
||||
slots: vec![],
|
||||
prefix_len: 0,
|
||||
// Set sampling parameters to also take these ops into account in the max memory
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
temperature: 0.9,
|
||||
top_k: 10,
|
||||
top_p: 0.9,
|
||||
typical_p: 0.9,
|
||||
do_sample: false,
|
||||
seed: 0,
|
||||
repetition_penalty: 1.2,
|
||||
frequency_penalty: 0.1,
|
||||
watermark: true,
|
||||
grammar: String::new(),
|
||||
grammar_type: GrammarType::None as i32,
|
||||
}),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||
max_new_tokens: max_total_tokens - truncate,
|
||||
stop_sequences: vec![],
|
||||
ignore_eos_token: true,
|
||||
}),
|
||||
prefill_logprobs: true,
|
||||
top_n_tokens: 20,
|
||||
adapter_id: None,
|
||||
});
|
||||
n_tokens += max_input_length;
|
||||
|
||||
// Check max_batch_size
|
||||
if Some(requests.len()) == max_batch_size {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let batch = Batch {
|
||||
id: 0,
|
||||
size: requests.len() as u32,
|
||||
requests,
|
||||
max_tokens: max_input_length,
|
||||
max_blocks: 0,
|
||||
};
|
||||
|
||||
let request = tonic::Request::new(WarmupRequest {
|
||||
batch: Some(batch),
|
||||
max_input_length,
|
||||
max_prefill_tokens,
|
||||
max_total_tokens,
|
||||
max_batch_total_tokens,
|
||||
})
|
||||
.inject_context();
|
||||
let response = self.stub.warmup(request).await?.into_inner();
|
||||
Ok(response.max_supported_total_tokens)
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given batch
|
||||
///
|
||||
/// Returns Generation for each request in batch
|
||||
/// and the next cached batch
|
||||
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
|
||||
pub async fn prefill(
|
||||
&mut self,
|
||||
batch: Batch,
|
||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
|
||||
let response = self.stub.prefill(request).await?.into_inner();
|
||||
Ok((
|
||||
response.generations,
|
||||
response.batch,
|
||||
PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
|
||||
))
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given cached batches
|
||||
///
|
||||
/// Returns Generation for each request in batches
|
||||
/// and the next cached batch
|
||||
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
|
||||
pub async fn decode(
|
||||
&mut self,
|
||||
batches: Vec<CachedBatch>,
|
||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
|
||||
let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
|
||||
let response = self.stub.decode(request).await?.into_inner();
|
||||
Ok((
|
||||
response.generations,
|
||||
response.batch,
|
||||
DecodeTimings::new(
|
||||
response.concat_ns,
|
||||
response.forward_ns,
|
||||
response.decode_ns,
|
||||
response.total_ns,
|
||||
),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PrefillTimings {
|
||||
pub forward: Duration,
|
||||
pub decode: Duration,
|
||||
pub total: Duration,
|
||||
}
|
||||
|
||||
impl PrefillTimings {
|
||||
fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
|
||||
Self {
|
||||
forward: Duration::from_nanos(forward_ns),
|
||||
decode: Duration::from_nanos(decode_ns),
|
||||
total: Duration::from_nanos(total_ns),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DecodeTimings {
|
||||
pub concat: Option<Duration>,
|
||||
pub forward: Duration,
|
||||
pub decode: Duration,
|
||||
pub total: Duration,
|
||||
}
|
||||
|
||||
impl DecodeTimings {
|
||||
fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
|
||||
Self {
|
||||
concat: concat_ns.map(Duration::from_nanos),
|
||||
forward: Duration::from_nanos(forward_ns),
|
||||
decode: Duration::from_nanos(decode_ns),
|
||||
total: Duration::from_nanos(total_ns),
|
||||
}
|
||||
}
|
||||
}
|
13 backends/client/src/v3/mod.rs Normal file
@@ -0,0 +1,13 @@
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
mod pb;
|
||||
|
||||
mod client;
|
||||
mod sharded_client;
|
||||
|
||||
pub use client::Client;
|
||||
pub use pb::generate::v3::{
|
||||
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
|
||||
HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
|
||||
StoppingCriteriaParameters, Tokens,
|
||||
};
|
||||
pub use sharded_client::ShardedClient;
|
263 backends/client/src/v3/sharded_client.rs Normal file
@@ -0,0 +1,263 @@
|
||||
/// Multi shard Client
|
||||
use crate::{v3, Health, ShardInfo};
|
||||
use crate::{ClientError, Result};
|
||||
|
||||
use crate::v3::{Chunk, InfoResponse, Input};
|
||||
use async_trait::async_trait;
|
||||
use futures::future::join_all;
|
||||
use tonic::transport::Uri;
|
||||
use tracing::instrument;
|
||||
use v3::client::{DecodeTimings, PrefillTimings};
|
||||
use v3::{
|
||||
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
|
||||
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// Text Generation Inference gRPC multi client
|
||||
pub struct ShardedClient {
|
||||
clients: Vec<Client>,
|
||||
}
|
||||
|
||||
impl ShardedClient {
|
||||
fn new(clients: Vec<Client>) -> Self {
|
||||
Self { clients }
|
||||
}
|
||||
|
||||
/// Create a new ShardedClient from a master client. The master client will communicate with
|
||||
/// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
|
||||
async fn from_master_client(mut master_client: Client) -> Result<Self> {
|
||||
// Get all uris/unix sockets from the master client
|
||||
let uris = master_client.service_discovery().await?;
|
||||
let futures = uris.into_iter().map(Client::connect_uds);
|
||||
let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
|
||||
Ok(Self::new(clients?))
|
||||
}
|
||||
|
||||
/// Returns a client connected to the given uri
|
||||
pub async fn connect(uri: Uri) -> Result<Self> {
|
||||
let master_client = Client::connect(uri).await?;
|
||||
Self::from_master_client(master_client).await
|
||||
}
|
||||
|
||||
/// Returns a client connected to the given unix socket
|
||||
pub async fn connect_uds(path: String) -> Result<Self> {
|
||||
let master_client = Client::connect_uds(path).await?;
|
||||
Self::from_master_client(master_client).await
|
||||
}
|
||||
|
||||
/// Get the model info
|
||||
#[instrument(skip(self))]
|
||||
pub async fn info(&mut self) -> Result<ShardInfo> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| client.info())
|
||||
.collect();
|
||||
join_all(futures).await.pop().unwrap().map(ShardInfo::from)
|
||||
}
|
||||
|
||||
/// GRPC health check
|
||||
#[instrument(skip(self))]
|
||||
pub async fn health(&mut self) -> Result<HealthResponse> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| client.health())
|
||||
.collect();
|
||||
join_all(futures).await.pop().unwrap()
|
||||
}
|
||||
|
||||
/// Clear the past generations cache
|
||||
#[instrument(skip(self))]
|
||||
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| client.clear_cache(batch_id))
|
||||
.collect();
|
||||
join_all(futures).await.into_iter().collect()
|
||||
}
|
||||
|
||||
/// Filter a cached batch
|
||||
#[instrument(skip(self))]
|
||||
pub async fn filter_batch(
|
||||
&mut self,
|
||||
batch_id: u64,
|
||||
request_ids: Vec<u64>,
|
||||
) -> Result<Option<CachedBatch>> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
|
||||
.collect();
|
||||
// all shards return the same message
|
||||
join_all(futures).await.pop().unwrap()
|
||||
}
|
||||
|
||||
/// Warmup on a max size batch
|
||||
///
|
||||
/// Returns the maximum number of tokens supported by the hardware
|
||||
#[instrument(skip(self))]
|
||||
pub async fn warmup(
|
||||
&mut self,
|
||||
max_input_length: u32,
|
||||
max_prefill_tokens: u32,
|
||||
max_total_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_batch_size: Option<usize>,
|
||||
) -> Result<Option<u32>> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| {
|
||||
Box::pin(client.warmup(
|
||||
max_input_length,
|
||||
max_prefill_tokens,
|
||||
max_total_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_batch_size,
|
||||
))
|
||||
})
|
||||
.collect();
|
||||
// Take the minimum value
|
||||
let results = join_all(futures)
|
||||
.await
|
||||
.into_iter()
|
||||
.collect::<Result<Vec<Option<u32>>>>()?;
|
||||
Ok(results.into_iter().flatten().min())
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given batch
|
||||
///
|
||||
/// Returns Generation for each request in batch
|
||||
/// and the next cached batch
|
||||
#[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]
|
||||
pub async fn prefill(
|
||||
&mut self,
|
||||
batch: Batch,
|
||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| Box::pin(client.prefill(batch.clone())))
|
||||
.collect();
|
||||
#[allow(clippy::type_complexity)]
|
||||
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
|
||||
join_all(futures).await.into_iter().collect();
|
||||
let mut results = results?;
|
||||
|
||||
let (mut generations, next_batch, mut timings) =
|
||||
results.pop().ok_or(ClientError::EmptyResults)?;
|
||||
|
||||
// Merge generations from different model shards
|
||||
for (mut shard_generations, _, shard_timings) in results.into_iter() {
|
||||
generations.append(&mut shard_generations);
|
||||
// Return the timings of the slowest shard
|
||||
if shard_timings.total > timings.total {
|
||||
timings = shard_timings;
|
||||
}
|
||||
}
|
||||
Ok((generations, next_batch, timings))
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given cached batches
|
||||
///
|
||||
/// Returns Generation for each request in batches
|
||||
/// and the next cached batch
|
||||
#[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
|
||||
pub async fn decode(
|
||||
&mut self,
|
||||
batches: Vec<CachedBatch>,
|
||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| Box::pin(client.decode(batches.clone())))
|
||||
.collect();
|
||||
#[allow(clippy::type_complexity)]
|
||||
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
|
||||
join_all(futures).await.into_iter().collect();
|
||||
let mut results = results?;
|
||||
|
||||
let (mut generations, next_batch, mut timings) =
|
||||
results.pop().ok_or(ClientError::EmptyResults)?;
|
||||
|
||||
// Merge generations from different model shards
|
||||
for (mut shard_generations, _, shard_timings) in results.into_iter() {
|
||||
generations.append(&mut shard_generations);
|
||||
// Return the timings of the slowest shard
|
||||
if shard_timings.total > timings.total {
|
||||
timings = shard_timings;
|
||||
}
|
||||
}
|
||||
Ok((generations, next_batch, timings))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<InfoResponse> for ShardInfo {
|
||||
fn from(value: InfoResponse) -> Self {
|
||||
Self {
|
||||
requires_padding: value.requires_padding,
|
||||
dtype: value.dtype,
|
||||
device_type: value.device_type,
|
||||
window_size: value.window_size,
|
||||
speculate: value.speculate,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Health for ShardedClient {
|
||||
async fn device_health(&self) -> Result<()> {
|
||||
self.clone().health().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn model_health(&self) -> Result<()> {
|
||||
// Dummy batch of 1 token and 1 generated token
|
||||
let liveness_request = Request {
|
||||
id: u64::MAX,
|
||||
inputs: "liveness".to_string(),
|
||||
input_chunks: Some(Input {
|
||||
chunks: vec![Chunk::Text("liveness".into()).into()],
|
||||
}),
|
||||
truncate: 10,
|
||||
add_special_tokens: true,
|
||||
prefill_logprobs: false,
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
temperature: 1.0,
|
||||
top_k: 0,
|
||||
top_p: 1.0,
|
||||
typical_p: 1.0,
|
||||
do_sample: false,
|
||||
seed: 0,
|
||||
repetition_penalty: 1.0,
|
||||
frequency_penalty: 0.0,
|
||||
watermark: false,
|
||||
grammar: String::new(),
|
||||
grammar_type: GrammarType::None as i32,
|
||||
}),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||
max_new_tokens: 1,
|
||||
stop_sequences: vec![],
|
||||
ignore_eos_token: false,
|
||||
}),
|
||||
top_n_tokens: 0,
|
||||
// Block 0 is reserved for health checks
|
||||
blocks: vec![0],
|
||||
slots: (0..16).collect(),
|
||||
prefix_len: 0,
|
||||
adapter_id: None,
|
||||
};
|
||||
let batch = Batch {
|
||||
id: u64::MAX,
|
||||
requests: vec![liveness_request],
|
||||
size: 1,
|
||||
max_tokens: 2,
|
||||
max_blocks: 1,
|
||||
};
|
||||
self.clone().prefill(batch).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
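A minimal sketch of how the v3 `ShardedClient` above might be driven, assuming the crate is importable as `text_generation_client`, tokio is available with the `macros` and `rt-multi-thread` features, and a shard is already listening on the (made-up) unix socket path below; the warmup limits are illustrative only.

```rust
// Illustrative sketch only: connect over the shard's unix socket, clear any
// cached state, then run a warmup pass to probe the supported token budget.
use text_generation_client::v3::ShardedClient;
use text_generation_client::Result;

#[tokio::main]
async fn main() -> Result<()> {
    // Hypothetical socket path; shard 0 acts as the master for service discovery.
    let mut client =
        ShardedClient::connect_uds("/tmp/text-generation-server-0".to_string()).await?;

    // Drop any cached batches left over from a previous run.
    client.clear_cache(None).await?;

    // Every shard answers the warmup; the minimum supported value wins.
    let max_supported = client.warmup(1024, 4096, 2048, 8192, None).await?;
    println!("max supported batch total tokens: {max_supported:?}");

    Ok(())
}
```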
63 backends/trtllm/CMakeLists.txt Normal file
@@ -0,0 +1,63 @@
|
||||
cmake_minimum_required(VERSION 3.20)
|
||||
|
||||
project(tgi-trtllm-backend VERSION 1.0.0)
|
||||
set(CMAKE_CXX_STANDARD 20)
|
||||
|
||||
include(FetchContent)
|
||||
include(ExternalProject)
|
||||
|
||||
option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF)
|
||||
option(TGI_TRTLLM_BACKEND_BUILD_EXAMPLES "Enable building the examples suite" OFF)
|
||||
set(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST "89-real" CACHE STRING "List of CUDA architectures to support")
|
||||
set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING "Path where TensorRT libraries and headers are located")
|
||||
set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include" CACHE STRING "Path where TensorRT headers are located")
|
||||
set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE STRING "Path where TensorRT libraries are located")
|
||||
|
||||
# We use nvidia-ml to query device information at runtime and enable some architecture-specific features
|
||||
find_package(CUDAToolkit 12.5 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml)
|
||||
|
||||
#### External dependencies ####
|
||||
include(cmake/fmt.cmake)
|
||||
include(cmake/json.cmake)
|
||||
include(cmake/spdlog.cmake)
|
||||
include(cmake/trtllm.cmake)
|
||||
|
||||
# Let's build TRTLLM as part of CMake
|
||||
add_subdirectory("${trtllm_SOURCE_DIR}/cpp" "${trtllm_SOURCE_DIR}/..")
|
||||
|
||||
# Tell CMake not to try to override the RPATH for executorWorker, as it has no information on how to do so
|
||||
set_target_properties(executorWorker PROPERTIES SKIP_BUILD_RPATH TRUE)
|
||||
|
||||
# TGI TRTLLM Backend definition
|
||||
add_library(tgi_trtllm_backend_impl STATIC include/backend.h lib/backend.cpp include/hardware.h)
|
||||
include_directories(${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR})
|
||||
target_include_directories(tgi_trtllm_backend_impl PRIVATE
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
|
||||
$<INSTALL_INTERFACE:include>
|
||||
)
|
||||
target_include_directories(tgi_trtllm_backend_impl PUBLIC "${trtllm_SOURCE_DIR}/cpp/include")
|
||||
target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper CUDA::cudart CUDA::nvml)
|
||||
target_link_libraries(tgi_trtllm_backend_impl PUBLIC nlohmann_json::nlohmann_json spdlog::spdlog fmt::fmt)
|
||||
|
||||
# This installs all the artifacts in CMAKE_INSTALL_PREFIX under include/ lib/ bin/ so they are easy to link against / find later
|
||||
install(TARGETS tgi_trtllm_backend_impl tensorrt_llm nvinfer_plugin_tensorrt_llm decoder_attention executorWorker)
|
||||
install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB)
|
||||
|
||||
#### Unit Tests ####
|
||||
if (${TGI_TRTLLM_BACKEND_BUILD_TESTS})
|
||||
message(STATUS "Building tests")
|
||||
FetchContent_Declare(
|
||||
Catch2
|
||||
GIT_REPOSITORY https://github.com/catchorg/Catch2
|
||||
GIT_TAG v3.6.0
|
||||
)
|
||||
FetchContent_MakeAvailable(Catch2)
|
||||
|
||||
# add_executable(tgi_trtllm_backend_tests tests/infer_test.cpp)
|
||||
# target_link_libraries(tgi_trtllm_backend_tests PRIVATE tgi_trtllm_backend_impl Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog fmt::fmt CUDA::cudart CUDA::nvml)
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras)
|
||||
include(CTest)
|
||||
include(Catch)
|
||||
# catch_discover_tests(tgi_trtllm_backend_tests)
|
||||
endif ()
|
27 backends/trtllm/Cargo.toml Normal file
@@ -0,0 +1,27 @@
|
||||
[package]
|
||||
name = "text-generation-backends-trtllm"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
homepage.workspace = true
|
||||
|
||||
[dependencies]
|
||||
async-trait = "0.1"
|
||||
async-stream = "0.3"
|
||||
clap = { version = "4.5", features = ["derive"] }
|
||||
cxx = "1.0"
|
||||
log = { version = "0.4", features = [] }
|
||||
text-generation-router = { path = "../../router" }
|
||||
tokenizers = { version = "0.19", features = ["hf-hub"] }
|
||||
tokio = { version = "1.38", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
||||
tokio-stream = "0.1.15"
|
||||
thiserror = "1.0.62"
|
||||
tracing = "0.1"
|
||||
tracing-opentelemetry = "0.24"
|
||||
tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
|
||||
parking_lot = "0.12"
|
||||
|
||||
[build-dependencies]
|
||||
cmake = "0.1"
|
||||
cxx-build = { version = "1.0", features = ["parallel"] }
|
||||
pkg-config = "0.3"
|
101 backends/trtllm/Dockerfile Normal file
@@ -0,0 +1,101 @@
|
||||
ARG CUDA_ARCH_LIST="75-real;80-real;86-real;89-real;90-real"
|
||||
ARG OMPI_VERSION="4.1.6"
|
||||
|
||||
# Build dependencies resolver stage
|
||||
FROM lukemathwalker/cargo-chef:latest AS chef
|
||||
WORKDIR /usr/src/text-generation-inference/backends/trtllm
|
||||
|
||||
FROM chef AS planner
|
||||
COPY . .
|
||||
RUN cargo chef prepare --recipe-path recipe.json
|
||||
|
||||
# CUDA dependent dependencies resolver stage
|
||||
FROM nvidia/cuda:12.5.1-cudnn-devel-ubuntu22.04 AS cuda-builder
|
||||
|
||||
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||
--mount=type=cache,target=/var/lib/apt,sharing=locked \
|
||||
apt update && apt install -y \
|
||||
build-essential \
|
||||
cmake \
|
||||
curl \
|
||||
gcc \
|
||||
g++ \
|
||||
git \
|
||||
git-lfs \
|
||||
libssl-dev \
|
||||
ninja-build \
|
||||
pkg-config \
|
||||
python3 \
|
||||
python3-setuptools \
|
||||
tar \
|
||||
wget
|
||||
|
||||
ENV TGI_INSTALL_PREFIX=/usr/local/tgi
|
||||
ENV TENSORRT_INSTALL_PREFIX=/usr/local/tensorrt
|
||||
|
||||
# Install OpenMPI
|
||||
FROM cuda-builder AS mpi-builder
|
||||
ARG OMPI_VERSION
|
||||
|
||||
ENV OMPI_TARBALL_FILENAME="openmpi-$OMPI_VERSION.tar.bz2"
|
||||
RUN wget "https://download.open-mpi.org/release/open-mpi/v4.1/$OMPI_TARBALL_FILENAME" -P /opt/src && \
|
||||
mkdir /usr/src/mpi && \
|
||||
tar -xf "/opt/src/$OMPI_TARBALL_FILENAME" -C /usr/src/mpi --strip-components=1 && \
|
||||
cd /usr/src/mpi && \
|
||||
./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda && \
|
||||
make -j all && \
|
||||
make install && \
|
||||
rm -rf "/opt/src/$OMPI_TARBALL_FILENAME"
|
||||
|
||||
# Install TensorRT
|
||||
FROM cuda-builder AS trt-builder
|
||||
COPY backends/trtllm/scripts/install_tensorrt.sh /opt/install_tensorrt.sh
|
||||
RUN chmod +x /opt/install_tensorrt.sh && \
|
||||
/opt/install_tensorrt.sh
|
||||
|
||||
# Build Backend
|
||||
FROM cuda-builder AS tgi-builder
|
||||
WORKDIR /usr/src/text-generation-inference
|
||||
|
||||
# Install Rust
|
||||
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y && \
|
||||
chmod -R a+w /root/.rustup && \
|
||||
chmod -R a+w /root/.cargo
|
||||
|
||||
ENV PATH="/root/.cargo/bin:$PATH"
|
||||
RUN cargo install cargo-chef
|
||||
|
||||
# Cache dependencies
|
||||
COPY --from=planner /usr/src/text-generation-inference/backends/trtllm/recipe.json .
|
||||
RUN cargo chef cook --release --recipe-path recipe.json
|
||||
|
||||
# Build actual TGI
|
||||
ARG CUDA_ARCH_LIST
|
||||
ENV CMAKE_PREFIX_PATH="/usr/local/mpi:/usr/local/tensorrt:$CMAKE_PREFIX_PATH"
|
||||
ENV LD_LIBRARY_PATH="/usr/local/mpi/lib:$LD_LIBRARY_PATH"
|
||||
ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig:$PKG_CONFIG_PATH"
|
||||
|
||||
COPY . .
|
||||
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
|
||||
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
|
||||
RUN mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$TGI_INSTALL_PREFIX/lib" && \
|
||||
cd backends/trtllm && \
|
||||
CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release
|
||||
|
||||
FROM nvidia/cuda:12.5.1-cudnn-runtime-ubuntu22.04 AS runtime
|
||||
WORKDIR /usr/local/tgi/bin
|
||||
|
||||
ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
|
||||
|
||||
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
|
||||
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
|
||||
COPY --from=tgi-builder /usr/local/tgi /usr/local/tgi
|
||||
COPY --from=tgi-builder /usr/src/text-generation-inference/target/release/text-generation-backends-trtllm /usr/local/tgi/bin/text-generation-launcher
|
||||
|
||||
FROM runtime
|
||||
|
||||
LABEL co.huggingface.vendor="Hugging Face Inc."
|
||||
LABEL org.opencontainers.image.authors="hardware@hf.co"
|
||||
|
||||
ENTRYPOINT ["./text-generation-launcher"]
|
||||
CMD ["--executor-worker", "/usr/local/tgi/bin/executorWorker"]
|
46 backends/trtllm/README.md Normal file
@@ -0,0 +1,46 @@
|
||||
# Text Generation Inference - TensorRT-LLM Backend Implementation
|
||||
|
||||
## Description
|
||||
|
||||
This folder contains the sources of the TensorRT-LLM backend implementation, powered by the new TensorRT-LLM Executor API.
|
||||
|
||||
## Simplified Request Sequence
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
actor User
|
||||
participant TextGenerationInference.HttpServer
|
||||
participant TextGenerationInference.TensorRtLlmBackend
|
||||
participant TextGenerationInference.TensorRtLlmWorkerThread
|
||||
participant TensorRtLlm.Executor
|
||||
participant Nvidia.Gpu
|
||||
User ->> TextGenerationInference.HttpServer: POST /generate
|
||||
TextGenerationInference.HttpServer ->> TextGenerationInference.TensorRtLlmBackend: Validate and forward inputs & parameters
|
||||
TextGenerationInference.TensorRtLlmBackend ->> TextGenerationInference.TensorRtLlmWorkerThread: Allocate a new context and spawn a new thread to handle the request
|
||||
TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Submit the request to the In-Flight Batcher
|
||||
activate Nvidia.Gpu
|
||||
TensorRtLlm.Executor ->> Nvidia.Gpu: Add the request to the pool for execution
TensorRtLlm.Executor -->> TextGenerationInference.TensorRtLlmWorkerThread: Respond with a unique request identifier
|
||||
rect rgb(10, 92, 54)
|
||||
loop every 100us
|
||||
rect rgb(15, 81, 50)
|
||||
alt Acquire lock to query executor
|
||||
TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Poll request number of new token(s) generated
|
||||
else There are new generated tokens
|
||||
TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Retrieve newly generated tokens
|
||||
TensorRtLlm.Executor -->> TextGenerationInference.TensorRtLlmWorkerThread: Return decoded token information and potential error (omitted)
|
||||
rect rgb(11, 110, 79)
|
||||
alt Generated token is final
|
||||
TensorRtLlm.Executor ->> Nvidia.Gpu: Remove request from the scheduler and from the GPU
|
||||
TextGenerationInference.TensorRtLlmWorkerThread -->> User: Stream the remaining decoded tokens and flush the connection
|
||||
else Generated token is not final
|
||||
TextGenerationInference.TensorRtLlmWorkerThread -->> User: Stream token back to the user as they get decoded
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
deactivate Nvidia.Gpu
|
||||
end
|
||||
end
|
||||
|
||||
```
|
150 backends/trtllm/build.rs Normal file
@@ -0,0 +1,150 @@
|
||||
use cxx_build::CFG;
|
||||
use pkg_config;
|
||||
use std::env;
|
||||
use std::env::consts::ARCH;
|
||||
use std::path::{absolute, PathBuf};
|
||||
|
||||
const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"];
|
||||
const CUDA_ARCH_LIST: Option<&str> = option_env!("CUDA_ARCH_LIST");
|
||||
const CUDA_REQUIRED_VERSION: &str = "12.5";
|
||||
const MPI_REQUIRED_VERSION: &str = "4.1";
|
||||
const INSTALL_PREFIX: Option<&str> = option_env!("CMAKE_INSTALL_PREFIX");
|
||||
const TENSORRT_ROOT_DIR: Option<&str> = option_env!("TENSORRT_ROOT_DIR");
|
||||
const NCCL_ROOT_DIR: Option<&str> = option_env!("NCCL_ROOT_DIR");
|
||||
|
||||
// Dependencies
|
||||
const BACKEND_DEPS: [&str; 2] = ["tgi_trtllm_backend_impl", "tgi_trtllm_backend"];
|
||||
const CUDA_TRANSITIVE_DEPS: [&str; 4] = ["cuda", "cudart", "cublas", "nvidia-ml"];
|
||||
const TENSORRT_LLM_TRANSITIVE_DEPS: [(&str, &str); 5] = [
|
||||
("dylib", "tensorrt_llm"),
|
||||
("static", "tensorrt_llm_executor_static"),
|
||||
("dylib", "tensorrt_llm_nvrtc_wrapper"),
|
||||
("dylib", "nvinfer_plugin_tensorrt_llm"),
|
||||
("dylib", "decoder_attention"),
|
||||
];
|
||||
|
||||
macro_rules! probe {
|
||||
($name: expr, $version: expr) => {
|
||||
if let Err(_) = pkg_config::probe_library($name) {
|
||||
pkg_config::probe_library(&format!("{}-{}", $name, $version))
|
||||
.expect(&format!("Failed to locate {}", $name));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf, PathBuf) {
|
||||
// Build the backend implementation through CMake
|
||||
let install_path = INSTALL_PREFIX.unwrap_or("/usr/local/tgi");
|
||||
let tensorrt_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt");
|
||||
let cuda_arch_list = CUDA_ARCH_LIST.unwrap_or("90-real"); // Hopper by default
|
||||
|
||||
let mut install_path = PathBuf::from(install_path);
|
||||
if !install_path.is_absolute() {
|
||||
install_path = absolute(out_dir).expect("cannot happen").join(install_path);
|
||||
}
|
||||
|
||||
let _ = cmake::Config::new(".")
|
||||
.uses_cxx11()
|
||||
.generator("Ninja")
|
||||
.profile(match is_debug {
|
||||
true => "Debug",
|
||||
false => "Release",
|
||||
})
|
||||
.env("OPT_LEVEL", opt_level)
|
||||
.define("CMAKE_INSTALL_PREFIX", &install_path)
|
||||
.define("CMAKE_CUDA_COMPILER", "/usr/local/cuda/bin/nvcc")
|
||||
.define("TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST", cuda_arch_list)
|
||||
.define("TGI_TRTLLM_BACKEND_TRT_ROOT", tensorrt_path)
|
||||
.build();
|
||||
|
||||
// Additional transitive CMake dependencies
|
||||
let deps_folder = out_dir.join("build").join("_deps");
|
||||
for dependency in ADDITIONAL_BACKEND_LINK_LIBRARIES {
|
||||
let dep_name = match is_debug {
|
||||
true => format!("{}d", dependency),
|
||||
false => String::from(dependency),
|
||||
};
|
||||
let dep_path = deps_folder.join(format!("{}-build", dependency));
|
||||
println!("cargo:rustc-link-search={}", dep_path.display());
|
||||
println!("cargo:rustc-link-lib=static={}", dep_name);
|
||||
}
|
||||
|
||||
// Emit linkage information from the artifacts we just built
|
||||
let install_lib_path = install_path.join("lib");
|
||||
|
||||
println!(
|
||||
r"cargo:warning=Adding link search path: {}",
|
||||
install_lib_path.display()
|
||||
);
|
||||
println!(r"cargo:rustc-link-search={}", install_lib_path.display());
|
||||
|
||||
(PathBuf::from(install_path), deps_folder)
|
||||
}
|
||||
|
||||
fn build_ffi_layer(deps_folder: &PathBuf) {
|
||||
CFG.include_prefix = "backends/trtllm";
|
||||
cxx_build::bridge("src/lib.rs")
|
||||
.static_flag(true)
|
||||
.include(deps_folder.join("fmt-src").join("include"))
|
||||
.include(deps_folder.join("spdlog-src").join("include"))
|
||||
.include(deps_folder.join("json-src").join("include"))
|
||||
.include(deps_folder.join("trtllm-src").join("cpp").join("include"))
|
||||
.include("/usr/local/cuda/include")
|
||||
.include("/usr/local/tensorrt/include")
|
||||
.file("src/ffi.cpp")
|
||||
.std("c++20")
|
||||
.compile("tgi_trtllm_backend");
|
||||
|
||||
println!("cargo:rerun-if-changed=CMakeLists.txt");
|
||||
println!("cargo:rerun-if-changed=include/backend.h");
|
||||
println!("cargo:rerun-if-changed=lib/backend.cpp");
|
||||
println!("cargo:rerun-if-changed=include/ffi.h");
|
||||
println!("cargo:rerun-if-changed=src/ffi.cpp");
|
||||
}
|
||||
|
||||
fn main() {
|
||||
// Misc variables
|
||||
let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
|
||||
let build_profile = env::var("PROFILE").unwrap();
|
||||
let (is_debug, opt_level) = match build_profile.as_ref() {
|
||||
"debug" => (true, "0"),
|
||||
_ => (false, "3"),
|
||||
};
|
||||
|
||||
// Build the backend
|
||||
let (_backend_path, deps_folder) = build_backend(is_debug, opt_level, &out_dir);
|
||||
|
||||
// Build the FFI layer calling the backend above
|
||||
build_ffi_layer(&deps_folder);
|
||||
|
||||
// Emit linkage search path
|
||||
probe!("ompi", MPI_REQUIRED_VERSION);
|
||||
|
||||
// Probe CUDA & co. with pkg-config
|
||||
CUDA_TRANSITIVE_DEPS.iter().for_each(|name| {
|
||||
probe!(name, CUDA_REQUIRED_VERSION);
|
||||
});
|
||||
|
||||
// NCCL is slightly trickier because it might not have a pkgconfig installed
|
||||
let nccl_library_path_default = format!("/usr/local/{}-linux-gnu", ARCH);
|
||||
let nccl_library_path = NCCL_ROOT_DIR.unwrap_or(&nccl_library_path_default);
|
||||
println!(r"cargo:rustc-link-search=native={}", nccl_library_path);
|
||||
println!("cargo:rustc-link-lib=dylib=nccl");
|
||||
|
||||
// TensorRT
|
||||
let tensort_library_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt/lib");
|
||||
println!(r"cargo:rustc-link-search=native={}", tensort_library_path);
|
||||
println!("cargo:rustc-link-lib=dylib=nvinfer");
|
||||
|
||||
// TensorRT-LLM
|
||||
TENSORRT_LLM_TRANSITIVE_DEPS
|
||||
.iter()
|
||||
.for_each(|(link_type, name)| {
|
||||
println!("cargo:rustc-link-lib={}={}", link_type, name);
|
||||
});
|
||||
|
||||
// Backend
|
||||
BACKEND_DEPS.iter().for_each(|name| {
|
||||
println!("cargo:rustc-link-lib=static={}", name);
|
||||
});
|
||||
}
|
6 backends/trtllm/cmake/fmt.cmake Normal file
@@ -0,0 +1,6 @@
|
||||
FetchContent_Declare(
|
||||
fmt
|
||||
GIT_REPOSITORY https://github.com/fmtlib/fmt
|
||||
GIT_TAG 11.0.1
|
||||
)
|
||||
FetchContent_MakeAvailable(fmt)
|
5 backends/trtllm/cmake/json.cmake Normal file
@@ -0,0 +1,5 @@
|
||||
fetchcontent_declare(
|
||||
json
|
||||
URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz
|
||||
)
|
||||
fetchcontent_makeavailable(json)
|
17 backends/trtllm/cmake/spdlog.cmake Normal file
@@ -0,0 +1,17 @@
|
||||
set(SPDLOG_USE_FMT ON)
|
||||
set(SPDLOG_BUILD_SHARED OFF)
|
||||
set(SPDLOG_FMT_EXTERNAL ON)
|
||||
|
||||
# Define the level at which SPDLOG_ compilation level is defined
|
||||
if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
|
||||
add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG)
|
||||
else ()
|
||||
add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_INFO)
|
||||
endif ()
|
||||
|
||||
fetchcontent_declare(
|
||||
spdlog
|
||||
GIT_REPOSITORY https://github.com/gabime/spdlog.git
|
||||
GIT_TAG v1.14.1
|
||||
)
|
||||
fetchcontent_makeavailable(spdlog)
|
42 backends/trtllm/cmake/trtllm.cmake Normal file
@@ -0,0 +1,42 @@
|
||||
set(TRT_INCLUDE_DIR ${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR})
|
||||
set(TRT_LIB_DIR ${TGI_TRTLLM_BACKEND_TRT_LIB_DIR})
|
||||
|
||||
set(USE_CXX11_ABI ON)
|
||||
set(BUILD_PYT OFF)
|
||||
set(BUILD_PYBIND OFF)
|
||||
set(BUILD_MICRO_BENCHMARKS OFF)
|
||||
set(BUILD_BENCHMARKS OFF)
|
||||
set(BUILD_TESTS OFF)
|
||||
set(CMAKE_CUDA_ARCHITECTURES ${TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST})
|
||||
|
||||
message(STATUS "Building for CUDA Architectures: ${CMAKE_CUDA_ARCHITECTURES}")
|
||||
|
||||
if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
|
||||
set(FAST_BUILD ON)
|
||||
set(NVTX_DISABLE OFF)
|
||||
else ()
|
||||
set(FAST_BUILD OFF)
|
||||
set(FAST_MATH ON)
|
||||
set(NVTX_DISABLE ON)
|
||||
endif ()
|
||||
|
||||
fetchcontent_declare(
|
||||
trtllm
|
||||
GIT_REPOSITORY https://github.com/NVIDIA/TensorRT-LLM.git
|
||||
GIT_TAG a681853d3803ee5893307e812530b5e7004bb6e1
|
||||
GIT_SHALLOW FALSE
|
||||
)
|
||||
fetchcontent_makeavailable(trtllm)
|
||||
|
||||
message(STATUS "Found TensorRT-LLM: ${trtllm_SOURCE_DIR}")
|
||||
execute_process(COMMAND git lfs install WORKING_DIRECTORY "${trtllm_SOURCE_DIR}/")
|
||||
execute_process(COMMAND git lfs pull WORKING_DIRECTORY "${trtllm_SOURCE_DIR}/")
|
||||
|
||||
# TRTLLM use a JIT based *precompiled* library to generate some specific kernels, we are generating the path to this one here
|
||||
set(TRTLLM_NVRTC_LIBRARY_NAME "${CMAKE_SHARED_LIBRARY_PREFIX}tensorrt_llm_nvrtc_wrapper${CMAKE_SHARED_LIBRARY_SUFFIX}" CACHE INTERNAL "nvrtc wrapper library name")
|
||||
set(TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH "${trtllm_SOURCE_DIR}/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/${CMAKE_LIBRARY_ARCHITECTURE}/${TRTLLM_NVRTC_LIBRARY_NAME}"
|
||||
CACHE INTERNAL "nvrtc wrapper library path")
|
||||
|
||||
# The same Executor Static library
|
||||
set(TRTLLM_EXECUTOR_STATIC_LIBRARY_NAME "${CMAKE_SHARED_LIBRARY_PREFIX}tensorrt_llm_executor_static${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE INTERNAL "executor_static library name")
|
||||
set(TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH "${trtllm_SOURCE_DIR}/cpp/tensorrt_llm/executor/${CMAKE_LIBRARY_ARCHITECTURE}/${TRTLLM_EXECUTOR_STATIC_LIBRARY_NAME}" CACHE INTERNAL "executor_static library path")
|
0
backends/trtllm/cmake/utils/detect_cuda_arch.cu
Normal file
121
backends/trtllm/include/backend.h
Normal file
@ -0,0 +1,121 @@
|
||||
//
|
||||
// Created by Morgan Funtowicz on 6/30/24.
|
||||
//
|
||||
|
||||
#ifndef TGI_TRTLLM_BACKEND_H
|
||||
#define TGI_TRTLLM_BACKEND_H
|
||||
|
||||
#include <cmath>
|
||||
#include <filesystem>
|
||||
#include <span>
|
||||
#include <vector>
|
||||
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include <tensorrt_llm/runtime/common.h>
|
||||
#include <tensorrt_llm/executor/executor.h>
|
||||
#include <tensorrt_llm/plugins/api/tllmPlugin.h>
|
||||
|
||||
using json = nlohmann::json;
|
||||
namespace tle = tensorrt_llm::executor;
|
||||
|
||||
namespace huggingface::tgi::backends {
|
||||
using RequestId = tle::IdType;
|
||||
using TokenId = tle::TokenIdType;
|
||||
|
||||
/**
|
||||
* Initialize all the components required by TRTLLM.
|
||||
* It is required to call this function before attempting to load any engine
|
||||
*/
|
||||
void InitializeBackend();
|
||||
|
||||
/**
|
||||
*
|
||||
* @param config TensorRT-LLM configuration object
|
||||
* @param workerPath Path to the "executorWorker" provided by TensorRT-LLM when using orchestrator mode
|
||||
* @return
|
||||
*/
|
||||
tle::ExecutorConfig GetExecutorConfig(const json &config, const std::string &workerPath);
|
||||
|
||||
/**
|
||||
* Get the sampling configuration from the parameters provided by TGI
|
||||
* @param topK
|
||||
* @param topP
|
||||
* @param temperature
|
||||
* @param repetition_penalty
|
||||
* @param frequency_penalty
|
||||
* @param seed
|
||||
* @return
|
||||
*/
|
||||
tle::SamplingConfig GetSamplingConfig(
|
||||
uint32_t topK,
|
||||
float_t topP,
|
||||
float_t temperature,
|
||||
float_t repetition_penalty,
|
||||
float_t frequency_penalty,
|
||||
uint64_t seed
|
||||
);
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
class TensorRtLlmBackend {
|
||||
private:
|
||||
const json config;
|
||||
tle::Executor executor;
|
||||
|
||||
public:
|
||||
explicit TensorRtLlmBackend(
|
||||
const std::filesystem::path &engineFolder,
|
||||
const std::filesystem::path &executorWorker
|
||||
);
|
||||
|
||||
/**
|
||||
* Indicate if the backend is ready to accept incoming requests
|
||||
* @return true if ready, false otherwise
|
||||
*/
|
||||
[[nodiscard]] bool IsReady() const;
|
||||
|
||||
/**
|
||||
* Query the executor for the number of tokens available for pulling
|
||||
* @return
|
||||
*/
|
||||
[[nodiscard]] size_t NumResponsesReady() const;
|
||||
|
||||
/**
|
||||
* Submit a new generation task to the executor
|
||||
* @param tokens
|
||||
* @param topK
|
||||
* @param topP
|
||||
* @param temperature
|
||||
* @param repetition_penalty
|
||||
* @param frequency_penalty
|
||||
* @param seed
|
||||
* @return Request id related to this generation for reference
|
||||
*/
|
||||
[[nodiscard]] RequestId Submit(
|
||||
const std::vector<TokenId> &tokens,
|
||||
int32_t topK,
|
||||
float_t topP,
|
||||
float_t temperature,
|
||||
float_t repetition_penalty,
|
||||
float_t frequency_penalty,
|
||||
uint64_t seed
|
||||
);
|
||||
|
||||
/**
|
||||
*
|
||||
* @param requestId The request id to poll the generation results
|
||||
* @return
|
||||
*/
|
||||
std::vector<tle::Response> Poll(RequestId requestId);
|
||||
|
||||
/**
|
||||
* Stop the underlying executor
|
||||
*/
|
||||
void Shutdown();
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
#endif //TGI_TRTLLM_BACKEND_H
|
75
backends/trtllm/include/ffi.h
Normal file
@ -0,0 +1,75 @@
|
||||
//
|
||||
// Created by mfuntowicz on 7/11/24.
|
||||
//
|
||||
|
||||
#ifndef TGI_TRTLLM_BACKEND_FFI_H
|
||||
#define TGI_TRTLLM_BACKEND_FFI_H
|
||||
|
||||
#include <cstddef>
|
||||
#include "backend.h"
|
||||
|
||||
namespace huggingface::tgi::backends {
|
||||
class TensorRtLlmBackendImpl;
|
||||
}
|
||||
|
||||
#include "backends/trtllm/src/lib.rs.h"
|
||||
|
||||
|
||||
namespace huggingface::tgi::backends {
|
||||
|
||||
// struct GenerationContext;
|
||||
|
||||
class TensorRtLlmBackendImpl : public TensorRtLlmBackend {
|
||||
public:
|
||||
/***
|
||||
*
|
||||
* @param engineFolder
|
||||
* @param executorWorker
|
||||
*/
|
||||
TensorRtLlmBackendImpl(const std::string_view &engineFolder, const std::string_view &executorWorker);
|
||||
|
||||
/***
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
bool IsReady() const;
|
||||
|
||||
/***
|
||||
*
|
||||
* @param tokens
|
||||
* @param topK
|
||||
* @param topP
|
||||
* @param temperature
|
||||
* @param repetition_penalty
|
||||
* @param frequency_penalty
|
||||
* @param seed
|
||||
* @return
|
||||
*/
|
||||
[[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
|
||||
uint64_t
|
||||
Submit(rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature,
|
||||
float_t repetition_penalty, float_t frequency_penalty, uint64_t seed);
|
||||
|
||||
/***
|
||||
*
|
||||
* @param requestId
|
||||
* @param ctx
|
||||
* @param callback
|
||||
* @return
|
||||
*/
|
||||
size_t StreamTokens(
|
||||
const RequestId requestId,
|
||||
huggingface::tgi::backends::GenerationContext *ctx,
|
||||
rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
|
||||
huggingface::tgi::backends::GenerationStep)> callback);
|
||||
};
|
||||
|
||||
/***
|
||||
*
|
||||
* @param engineFolder
|
||||
* @return
|
||||
*/
|
||||
std::unique_ptr<TensorRtLlmBackendImpl> CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker);
|
||||
}
|
||||
|
||||
#endif //TGI_TRTLLM_BACKEND_FFI_H
|
59
backends/trtllm/include/hardware.h
Normal file
@ -0,0 +1,59 @@
|
||||
//
|
||||
// Created by mfuntowicz on 7/23/24.
|
||||
//
|
||||
|
||||
#ifndef TGI_TRTLLM_BACKEND_HARDWARE_H
|
||||
#define TGI_TRTLLM_BACKEND_HARDWARE_H
|
||||
|
||||
#include <cstdint>
|
||||
#include <limits>
#include <optional>
|
||||
#include <fmt/base.h>
|
||||
#include <spdlog/spdlog.h>
|
||||
#include <nvml.h>
|
||||
|
||||
namespace huggingface::hardware::cuda {
|
||||
|
||||
#define AMPERE_SM_MAJOR 8
|
||||
#define HOPPER_SM_MAJOR 9
|
||||
|
||||
/**
|
||||
* Store information about the version of the CUDA Compute Capabilities detected on the device
|
||||
*/
|
||||
struct CudaComputeCapabilities {
|
||||
int32_t major;
|
||||
int32_t minor;
|
||||
|
||||
[[nodiscard]] constexpr bool isPostAmpere() const { return major >= AMPERE_SM_MAJOR; }
|
||||
|
||||
[[nodiscard]] constexpr bool isPostHopper() const { return major >= HOPPER_SM_MAJOR; }
|
||||
};
|
||||
|
||||
CudaComputeCapabilities GetCudaComputeCapabilities() {
|
||||
// Get the compute capabilities of the current hardware
|
||||
nvmlDevice_t device;
|
||||
CudaComputeCapabilities capabilities{0, 0};
|
||||
if (nvmlDeviceGetHandleByIndex_v2(0, &device) == NVML_SUCCESS) {
|
||||
SPDLOG_DEBUG("Successfully acquired nvmlDevice_t = 0");
|
||||
if (nvmlDeviceGetCudaComputeCapability(device, &capabilities.major, &capabilities.minor) == NVML_SUCCESS) {
|
||||
SPDLOG_INFO("Detected sm_{:d}{:d} compute capabilities", capabilities.major, capabilities.minor);
|
||||
}
|
||||
}
|
||||
|
||||
return capabilities;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the number of GPUs detected. If the device count cannot be queried, return std::nullopt
|
||||
* @return
|
||||
*/
|
||||
std::optional<size_t> GetNumDevices() {
|
||||
uint32_t numGpus = 0;
|
||||
if (nvmlDeviceGetCount_v2(&numGpus) == NVML_SUCCESS) {
|
||||
return std::optional(numGpus);
|
||||
} else {
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif //TGI_TRTLLM_BACKEND_HARDWARE_H
|
146
backends/trtllm/lib/backend.cpp
Normal file
@ -0,0 +1,146 @@
|
||||
#include <fstream>
|
||||
|
||||
#include <fmt/ranges.h>
|
||||
#include <spdlog/spdlog.h>
|
||||
#include <nvml.h>
|
||||
|
||||
#include "backend.h"
|
||||
#include "hardware.h"
|
||||
|
||||
void huggingface::tgi::backends::InitializeBackend() {
|
||||
SPDLOG_INFO("Initializing Backend...");
|
||||
nvmlInit_v2();
|
||||
initTrtLlmPlugins();
|
||||
|
||||
const auto numGpus = huggingface::hardware::cuda::GetNumDevices();
|
||||
if (numGpus.has_value()) {
|
||||
SPDLOG_INFO("Detected {:d} Nvidia GPU(s)", numGpus.value());
|
||||
} else {
|
||||
SPDLOG_WARN("Failed to detect Nvidia GPU(s) on the system");
|
||||
}
|
||||
}
|
||||
|
||||
[[nodiscard]]
|
||||
tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &config, const std::string &workerPath) {
|
||||
tle::ExecutorConfig execConfig(1);
|
||||
|
||||
// Retrieve the compute capabilities to enable some options at runtime
|
||||
const auto computeCapabilities = huggingface::hardware::cuda::GetCudaComputeCapabilities();
|
||||
|
||||
// Single engine (TP = PP = 1) -> using leader mode (no MPI involved)
|
||||
if (config["/pretrained_config/mapping/world_size"_json_pointer].get<uint8_t>() == 1) {
|
||||
SPDLOG_INFO("Detected single engine deployment, using leader mode");
|
||||
execConfig.setParallelConfig(tle::ParallelConfig(
|
||||
tle::CommunicationType::kMPI,
|
||||
tle::CommunicationMode::kLEADER,
|
||||
std::nullopt,
|
||||
std::nullopt,
|
||||
std::nullopt
|
||||
));
|
||||
} else { // Multiple engines -> using orchestrator mode (MPI involved)
|
||||
SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode");
|
||||
execConfig.setParallelConfig(tle::ParallelConfig(
|
||||
tle::CommunicationType::kMPI,
|
||||
tle::CommunicationMode::kORCHESTRATOR,
|
||||
std::nullopt,
|
||||
std::nullopt,
|
||||
tle::OrchestratorConfig(true, workerPath, nullptr, true)
|
||||
));
|
||||
}
|
||||
|
||||
// Define some configuration variables
|
||||
execConfig.setKvCacheConfig(tle::KvCacheConfig(true));
|
||||
execConfig.setEnableChunkedContext(computeCapabilities.isPostAmpere());
|
||||
return execConfig;
|
||||
}
|
||||
|
||||
tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
|
||||
uint32_t topK,
|
||||
float_t topP,
|
||||
float_t temperature,
|
||||
float_t repetition_penalty,
|
||||
float_t frequency_penalty,
|
||||
uint64_t seed) {
|
||||
return tle::SamplingConfig(
|
||||
1, // TGI only uses a single beam
|
||||
topK,
|
||||
topP,
|
||||
std::nullopt,
|
||||
std::nullopt,
|
||||
std::nullopt,
|
||||
seed,
|
||||
temperature,
|
||||
temperature,
|
||||
std::nullopt,
|
||||
repetition_penalty,
|
||||
std::nullopt,
|
||||
frequency_penalty
|
||||
);
|
||||
}
|
||||
|
||||
huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
|
||||
const std::filesystem::path &enginesFolder,
|
||||
const std::filesystem::path &executorWorker
|
||||
) :
|
||||
config(json::parse(std::ifstream(enginesFolder / "config.json"))),
|
||||
executor(
|
||||
enginesFolder,
|
||||
tensorrt_llm::executor::ModelType::kDECODER_ONLY,
|
||||
GetExecutorConfig(config, executorWorker.string()
|
||||
)) {
|
||||
SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string &>());
|
||||
}
|
||||
|
||||
bool huggingface::tgi::backends::TensorRtLlmBackend::IsReady() const {
|
||||
return executor.canEnqueueRequests();
|
||||
}
|
||||
|
||||
[[nodiscard("Returned number of requests needs to be consumed")]]
|
||||
size_t huggingface::tgi::backends::TensorRtLlmBackend::NumResponsesReady() const {
|
||||
return executor.getNumResponsesReady();
|
||||
}
|
||||
|
||||
[[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
|
||||
tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
|
||||
const std::vector<tle::TokenIdType> &tokens,
|
||||
const int32_t topK,
|
||||
const float_t topP,
|
||||
const float_t temperature,
|
||||
const float_t repetition_penalty,
|
||||
const float_t frequency_penalty,
|
||||
const uint64_t seed
|
||||
) {
|
||||
#ifdef NDEBUG
|
||||
SPDLOG_DEBUG(
|
||||
FMT_STRING("Submitting inference over {:d} tokens to the executor ({:d} already in-flight)"),
|
||||
tokens.size(),
|
||||
executor.getLatestIterationStats().back().numActiveRequests
|
||||
);
|
||||
#else
|
||||
SPDLOG_DEBUG(
|
||||
FMT_STRING("Submitting inference [{}] to the executor ({:d} already in-flight)"),
|
||||
fmt::join(tokens, ", "),
|
||||
executor.getLatestIterationStats().front().numActiveRequests
|
||||
);
|
||||
#endif
|
||||
|
||||
const auto maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<size_t>();
|
||||
const auto maxNewTokens = static_cast<int32_t>(std::max(1ul, maxNumTokens - tokens.size()));
|
||||
|
||||
const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
|
||||
const auto output = tle::OutputConfig(true, false, false, true, false);
|
||||
return executor.enqueueRequest(
|
||||
tle::Request{tokens, maxNewTokens, true, sampling, output});
|
||||
}
|
||||
|
||||
[[nodiscard("Generated tokens result must be used")]]
|
||||
std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::Poll(const tle::IdType requestId) {
|
||||
SPDLOG_DEBUG(FMT_STRING("Polling status for request {:d}"), requestId);
|
||||
return executor.awaitResponses(requestId);
|
||||
}
|
||||
|
||||
|
||||
void huggingface::tgi::backends::TensorRtLlmBackend::Shutdown() {
|
||||
SPDLOG_INFO("Shutting down executor");
|
||||
executor.shutdown();
|
||||
}
|
111
backends/trtllm/scripts/install_tensorrt.sh
Executable file
@ -0,0 +1,111 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -ex
|
||||
|
||||
TRT_VER="10.2.0.19"
|
||||
CUDA_VER="12.5"
|
||||
CUDNN_VER="9.2.1.18-1"
|
||||
NCCL_VER="2.22.3-1+cuda12.5"
|
||||
CUBLAS_VER="12.5.3.2-1"
|
||||
NVRTC_VER="12.5.82-1"
|
||||
|
||||
for i in "$@"; do
|
||||
case $i in
|
||||
--TRT_VER=?*) TRT_VER="${i#*=}";;
|
||||
--CUDA_VER=?*) CUDA_VER="${i#*=}";;
|
||||
--CUDNN_VER=?*) CUDNN_VER="${i#*=}";;
|
||||
--NCCL_VER=?*) NCCL_VER="${i#*=}";;
|
||||
--CUBLAS_VER=?*) CUBLAS_VER="${i#*=}";;
|
||||
*) ;;
|
||||
esac
|
||||
shift
|
||||
done
|
||||
|
||||
NVCC_VERSION_OUTPUT=$(nvcc --version)
|
||||
if [[ $(echo $NVCC_VERSION_OUTPUT | grep -oP "\d+\.\d+" | head -n 1) != ${CUDA_VER} ]]; then
|
||||
echo "The version of pre-installed CUDA is not equal to ${CUDA_VER}."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
install_ubuntu_requirements() {
|
||||
apt-get update && apt-get install -y --no-install-recommends gnupg2 curl ca-certificates
|
||||
ARCH=$(uname -m)
|
||||
if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
|
||||
if [ "$ARCH" = "aarch64" ];then ARCH="sbsa";fi
|
||||
curl -fsSLO https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/${ARCH}/cuda-keyring_1.0-1_all.deb
|
||||
dpkg -i cuda-keyring_1.0-1_all.deb
|
||||
|
||||
apt-get update
|
||||
if [[ $(apt list --installed | grep libcudnn9) ]]; then
|
||||
apt-get remove --purge -y --allow-change-held-packages libcudnn9*
|
||||
fi
|
||||
if [[ $(apt list --installed | grep libnccl) ]]; then
|
||||
apt-get remove --purge -y --allow-change-held-packages libnccl*
|
||||
fi
|
||||
if [[ $(apt list --installed | grep libcublas) ]]; then
|
||||
apt-get remove --purge -y --allow-change-held-packages libcublas*
|
||||
fi
|
||||
if [[ $(apt list --installed | grep cuda-nvrtc-dev) ]]; then
|
||||
apt-get remove --purge -y --allow-change-held-packages cuda-nvrtc-dev*
|
||||
fi
|
||||
CUBLAS_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g')
|
||||
apt-get install -y --no-install-recommends libcudnn9-cuda-12=${CUDNN_VER} libcudnn9-dev-cuda-12=${CUDNN_VER}
|
||||
apt-get install -y --no-install-recommends libnccl2=${NCCL_VER} libnccl-dev=${NCCL_VER}
|
||||
apt-get install -y --no-install-recommends libcublas-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER} libcublas-dev-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER}
|
||||
# NVRTC static library doesn't exist in NGC PyTorch container.
|
||||
NVRTC_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g')
|
||||
apt-get install -y --no-install-recommends cuda-nvrtc-dev-${NVRTC_CUDA_VERSION}=${NVRTC_VER}
|
||||
apt-get clean
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
}
|
||||
|
||||
install_centos_requirements() {
|
||||
CUBLAS_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g')
|
||||
yum -y update
|
||||
yum -y install epel-release
|
||||
yum remove -y libnccl* && yum -y install libnccl-${NCCL_VER} libnccl-devel-${NCCL_VER}
|
||||
yum remove -y libcublas* && yum -y install libcublas-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER} libcublas-devel-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}
|
||||
yum clean all
|
||||
}
|
||||
|
||||
install_tensorrt() {
|
||||
#PY_VERSION=$(python3 -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))')
|
||||
#PARSED_PY_VERSION=$(echo "${PY_VERSION//./}")
|
||||
TRT_CUDA_VERSION="12.5"
|
||||
|
||||
if [ -z "$RELEASE_URL_TRT" ];then
|
||||
ARCH=${TRT_TARGETARCH}
|
||||
if [ -z "$ARCH" ];then ARCH=$(uname -m);fi
|
||||
if [ "$ARCH" = "arm64" ];then ARCH="aarch64";fi
|
||||
if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
|
||||
if [ "$ARCH" = "x86_64" ];then DIR_NAME="x64-agnostic"; else DIR_NAME=${ARCH};fi
|
||||
if [ "$ARCH" = "aarch64" ];then OS1="Ubuntu22_04" && OS2="Ubuntu-22.04" && OS="ubuntu-22.04"; else OS1="Linux" && OS2="Linux" && OS="linux";fi
|
||||
RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.2.0/tars/TensorRT-${TRT_VER}.${OS2}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz
|
||||
fi
|
||||
wget --no-verbose ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar
|
||||
tar -xf /tmp/TensorRT.tar -C /usr/local/
|
||||
mv /usr/local/TensorRT-${TRT_VER} /usr/local/tensorrt
|
||||
# pip3 install /usr/local/tensorrt/python/tensorrt-*-cp${PARSED_PY_VERSION}-*.whl
|
||||
rm -rf /tmp/TensorRT.tar
|
||||
}
|
||||
|
||||
# Install base packages depending on the base OS
|
||||
ID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"')
|
||||
case "$ID" in
|
||||
debian)
|
||||
install_ubuntu_requirements
|
||||
install_tensorrt
|
||||
;;
|
||||
ubuntu)
|
||||
install_ubuntu_requirements
|
||||
install_tensorrt
|
||||
;;
|
||||
centos)
|
||||
install_centos_requirements
|
||||
install_tensorrt
|
||||
;;
|
||||
*)
|
||||
echo "Unable to determine OS..."
|
||||
exit 1
|
||||
;;
|
||||
esac
|
330
backends/trtllm/src/backend.rs
Normal file
@ -0,0 +1,330 @@
|
||||
use std::future::Future;
|
||||
use std::path::Path;
|
||||
use std::pin::{pin, Pin};
|
||||
use std::str::FromStr;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use cxx::UniquePtr;
|
||||
use log::{error, warn};
|
||||
use tokenizers::Tokenizer;
|
||||
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
|
||||
use tokio::time::{sleep, Instant};
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tokio_stream::{Stream, StreamExt};
|
||||
use tracing::{instrument, span, Level};
|
||||
|
||||
// use tokio::sync::RwLock;
|
||||
use parking_lot::RwLock;
|
||||
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
||||
use text_generation_router::validation::ValidationError::UnsupportedModality;
|
||||
use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidationError};
|
||||
use text_generation_router::{FinishReason, Token};
|
||||
|
||||
use crate::errors::TensorRtLlmBackendError;
|
||||
use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
|
||||
|
||||
// Value used to poll the state of the generation stream
|
||||
static POLLING_INTERVAL_US: OnceLock<u64> = OnceLock::new();
|
||||
|
||||
type InferResult<T> = Result<T, InferError>;
|
||||
|
||||
pub(crate) struct Generation {
|
||||
executor: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
|
||||
done: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
/// Holds the user provided input to be executed along with a channel allowing
|
||||
/// to bubble up all the generated tokens for that request to the end of the stream.
|
||||
pub struct GenerationContext {
|
||||
sender: UnboundedSender<InferResult<InferStreamResponse>>,
|
||||
tokenizer: Arc<Tokenizer>,
|
||||
tokens: Vec<u32>,
|
||||
done: Arc<AtomicBool>,
|
||||
queued: Instant,
|
||||
start: Option<Instant>,
|
||||
}
|
||||
|
||||
impl Stream for Generation {
|
||||
type Item = usize;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let interval = POLLING_INTERVAL_US.get_or_init(|| {
|
||||
u64::from_str(option_env!("TRTLLM_BACKEND_POLLING_INTERVAL_US").unwrap_or("100"))
|
||||
.expect("Invalid value provided for envvar POLLING_INTERVAL_US")
|
||||
});
|
||||
|
||||
if !self.done.load(Ordering::Relaxed) {
|
||||
let backend = pin!(self.executor.read());
|
||||
let status = match backend.poll(ctx) {
|
||||
Poll::Ready(executor_r) => {
|
||||
let ready = executor_r.num_responses_ready();
|
||||
if ready == 0 {
|
||||
Poll::Pending
|
||||
} else {
|
||||
Poll::Ready(Some(ready))
|
||||
}
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
};
|
||||
|
||||
let waker = ctx.waker().clone();
|
||||
tokio::spawn(async {
|
||||
sleep(Duration::from_micros(*interval)).await;
|
||||
waker.wake();
|
||||
});
|
||||
|
||||
status
|
||||
} else {
|
||||
Poll::Ready(None) // end of stream
|
||||
}
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> (usize, Option<usize>) {
|
||||
(1, None)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl Send for TensorRtLlmBackendImpl {}
|
||||
unsafe impl Sync for TensorRtLlmBackendImpl {}
|
||||
|
||||
/// Implements the logic to execute generation with TensorRT-LLM executor API in background
|
||||
pub struct TensorRtLlmBackend {
|
||||
tokenizer: Arc<Tokenizer>,
|
||||
|
||||
// Backing the backend behind a RwLock to allow concurrent read access to retrieve
|
||||
// the number of available tokens (read only) in the Generation stream
|
||||
backend: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
|
||||
}
|
||||
|
||||
impl TensorRtLlmBackend {
|
||||
pub fn new<P: AsRef<Path> + Send + 'static, PP: AsRef<Path> + Send + 'static>(
|
||||
tokenizer: Tokenizer,
|
||||
engine_folder: P,
|
||||
executor_worker_path: PP,
|
||||
) -> Result<Self, TensorRtLlmBackendError> {
|
||||
Ok(TensorRtLlmBackend {
|
||||
tokenizer: Arc::new(tokenizer),
|
||||
backend: Arc::new(RwLock::new(create_tensorrt_llm_backend(
|
||||
engine_folder.as_ref().to_str().unwrap(),
|
||||
executor_worker_path.as_ref().to_str().unwrap(),
|
||||
))),
|
||||
})
|
||||
}
|
||||
|
||||
fn validate(request: &ValidGenerateRequest) -> InferResult<&String> {
|
||||
if request.top_n_tokens > 1 {
|
||||
return Err(InferError::ValidationError(
|
||||
ValidationError::TopNTokensDisabled,
|
||||
));
|
||||
}
|
||||
|
||||
// TODO: Is it really needed? How can it be validated before?
|
||||
if request.parameters.grammar.is_some() {
|
||||
return Err(InferError::ValidationError(ValidationError::Grammar));
|
||||
}
|
||||
|
||||
match request.inputs.len() {
|
||||
0 => Err(InferError::ValidationError(ValidationError::EmptyInput)),
|
||||
2.. => Err(InferError::GenerationError(
"TensorRT-LLM backend doesn't support multi-chunk".into(),
|
||||
)),
|
||||
1 => match request.inputs.first().expect("Single item-chunk") {
|
||||
Chunk::Text(text) => Ok(text),
|
||||
Chunk::Image(_) => Err(InferError::ValidationError(UnsupportedModality("image"))),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn generate(
|
||||
&self,
|
||||
sender: UnboundedSender<InferResult<InferStreamResponse>>,
|
||||
tokens: Vec<u32>,
|
||||
top_k: u32,
|
||||
top_p: f32,
|
||||
temperature: f32,
|
||||
repetition_penalty: f32,
|
||||
frequency_penalty: f32,
|
||||
seed: u64,
|
||||
) {
|
||||
let tokenizer = Arc::clone(&self.tokenizer);
|
||||
let executor = Arc::clone(&self.backend);
|
||||
|
||||
// Let's push this in async context
|
||||
tokio::spawn(async move {
|
||||
// Define the generation state
|
||||
let mut generation = Generation {
|
||||
executor: executor.clone(),
|
||||
done: Arc::new(AtomicBool::new(false)),
|
||||
};
|
||||
|
||||
// Define the context over the generation
|
||||
// TODO(asap): Do we really need so many shared-ownership?
|
||||
let ctx = Box::new(GenerationContext {
|
||||
sender: sender.clone(),
|
||||
tokenizer,
|
||||
tokens: vec![],
|
||||
done: Arc::clone(&generation.done),
|
||||
start: None,
|
||||
queued: Instant::now(),
|
||||
});
|
||||
|
||||
// We are leaking the context on-purpose to avoid the box being dropped while there are
|
||||
// still computation ongoing
|
||||
// TODO(asap): Can we achieve the same with an Arc<Box<T>> without the need to go unsafe?
|
||||
let ctx_ = Box::leak(ctx);
|
||||
|
||||
// Submit the request to the batcher
|
||||
let request_id = span!(Level::DEBUG, "submit")
|
||||
.in_scope(|| async {
|
||||
let mut handle = executor.write().await;
|
||||
let request_id = handle.pin_mut().submit(
|
||||
&tokens,
|
||||
top_k as i32,
|
||||
top_p,
|
||||
temperature,
|
||||
repetition_penalty,
|
||||
frequency_penalty,
|
||||
seed,
|
||||
);
|
||||
|
||||
request_id
|
||||
})
|
||||
.await;
|
||||
|
||||
while let Some(_) = generation.next().await {
|
||||
let mut executor_w = executor.write().await;
|
||||
let executor = executor_w.pin_mut();
|
||||
|
||||
span!(Level::DEBUG, "decode")
|
||||
.in_scope(|| async {
|
||||
unsafe {
|
||||
executor.stream_tokens(
|
||||
request_id,
|
||||
ctx_,
|
||||
|ctx: *mut GenerationContext, step: GenerationStep| {
|
||||
let inner_ctx = &mut *ctx;
|
||||
|
||||
// Update the timestamp at which the request started effectively
|
||||
// Can be a bit off, would need to be before the callback, let's see
|
||||
inner_ctx.start.get_or_insert(Instant::now());
|
||||
inner_ctx.done.store(step.is_final, Ordering::Relaxed);
|
||||
|
||||
// Ensure we are not running into errors
|
||||
let parcel = if !step.has_error {
|
||||
// Insert the latest generated token to the tracker
|
||||
inner_ctx.tokens.push(step.token_id);
|
||||
|
||||
// Decode the token
|
||||
let text = inner_ctx
|
||||
.tokenizer
|
||||
.decode(&[step.token_id], true)
|
||||
.expect("Failed to decode token");
|
||||
|
||||
let special = inner_ctx
|
||||
.tokenizer
|
||||
.get_added_vocabulary()
|
||||
.is_special_token(&text);
|
||||
|
||||
// Create the structure holding the token
|
||||
let token = Token {
|
||||
id: step.token_id,
|
||||
text,
|
||||
logprob: step.log_prob,
|
||||
special,
|
||||
};
|
||||
|
||||
if step.is_final {
|
||||
let generated_text = inner_ctx
|
||||
.tokenizer
|
||||
.decode(&inner_ctx.tokens, true)
|
||||
.expect("Failed to decode generated_tokens");
|
||||
|
||||
Ok(InferStreamResponse::End {
|
||||
token,
|
||||
top_tokens: vec![],
|
||||
generated_text: GeneratedText {
|
||||
text: generated_text,
|
||||
generated_tokens: inner_ctx.tokens.len() as u32,
|
||||
finish_reason: FinishReason::EndOfSequenceToken,
|
||||
seed: None,
|
||||
},
|
||||
start: inner_ctx.start.unwrap_or(Instant::now()),
|
||||
queued: inner_ctx.queued,
|
||||
})
|
||||
} else {
|
||||
Ok(InferStreamResponse::Intermediate {
|
||||
token,
|
||||
top_tokens: vec![],
|
||||
})
|
||||
}
|
||||
} else {
|
||||
error!("Error caught while decoding: {}", &step.error_msg);
|
||||
Err(InferError::GenerationError(step.error_msg))
|
||||
};
|
||||
|
||||
// Send the parcel to the client
|
||||
inner_ctx
|
||||
.sender
|
||||
.send(parcel)
|
||||
.expect("Failed to send msg through the channel");
|
||||
},
|
||||
);
|
||||
}
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
// "Properly" free the shared context...
|
||||
// TODO: clean that piece of sh** asap
|
||||
unsafe {
|
||||
let _ = Box::from_raw(ctx_);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Backend for TensorRtLlmBackend {
|
||||
#[instrument(skip_all)]
|
||||
fn schedule(
|
||||
&self,
|
||||
request: ValidGenerateRequest,
|
||||
) -> InferResult<UnboundedReceiverStream<InferResult<InferStreamResponse>>> {
|
||||
// Let's add a few more validations
|
||||
let input = TensorRtLlmBackend::validate(&request)?;
|
||||
|
||||
// Channel to stream the generated token as they come from the worker thread back to the transport layer
|
||||
let (sender, receiver) = unbounded_channel();
|
||||
|
||||
// Unpack parameters
|
||||
let params = &request.parameters;
|
||||
|
||||
// Preprocess the inputs to send to TRTLLM backend
|
||||
let encoding = self
|
||||
.tokenizer
|
||||
.encode(input.as_str(), true)
|
||||
.map_err(|e| InferError::GenerationError(e.to_string()))?;
|
||||
|
||||
// Generate the response
|
||||
self.generate(
|
||||
sender,
|
||||
Vec::from(encoding.get_ids()),
|
||||
params.top_k,
|
||||
params.top_p,
|
||||
params.temperature,
|
||||
params.repetition_penalty,
|
||||
params.frequency_penalty,
|
||||
params.seed,
|
||||
);
|
||||
|
||||
Ok(UnboundedReceiverStream::new(receiver))
|
||||
}
|
||||
|
||||
async fn health(&self, _current_health: bool) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
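For orientation, here is a minimal, illustrative sketch of how a caller could drain the stream returned by `schedule` above. The helper name `collect_generation` is hypothetical; the `InferStreamResponse` and `GeneratedText` shapes are the ones used in this file, and any other variants are simply ignored.

use text_generation_router::infer::{InferError, InferStreamResponse};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;

// Hypothetical consumer of the stream produced by `Backend::schedule`.
async fn collect_generation(
    mut stream: UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
) -> Result<String, InferError> {
    let mut text = String::new();
    while let Some(event) = stream.next().await {
        match event? {
            // Tokens are pushed one by one while the request is in flight
            InferStreamResponse::Intermediate { token, .. } => text.push_str(&token.text),
            // The final event carries the fully decoded text
            InferStreamResponse::End { generated_text, .. } => {
                text = generated_text.text;
                break;
            }
            // Other variants (e.g. prefill information) are not needed here
            _ => {}
        }
    }
    Ok(text)
}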
15
backends/trtllm/src/errors.rs
Normal file
@ -0,0 +1,15 @@
|
||||
use thiserror::Error;
|
||||
|
||||
use text_generation_router::server;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum TensorRtLlmBackendError {
|
||||
#[error("Tokenizer error: {0}")]
|
||||
Tokenizer(String),
|
||||
#[error("Argument validation error: {0}")]
|
||||
ArgumentValidation(String),
|
||||
#[error("WebServer error: {0}")]
|
||||
WebServer(#[from] server::WebServerError),
|
||||
#[error("Tokio runtime failed to start: {0}")]
|
||||
Tokio(#[from] std::io::Error),
|
||||
}
|
84
backends/trtllm/src/ffi.cpp
Normal file
@ -0,0 +1,84 @@
|
||||
//
|
||||
// Created by mfuntowicz on 6/30/24.
|
||||
//
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
#include <exception>
|
||||
#include <filesystem>
|
||||
#include <limits>
|
||||
#include <iterator>
|
||||
#include <vector>
|
||||
|
||||
#include <spdlog/spdlog.h>
|
||||
#include "backends/trtllm/include/ffi.h"
|
||||
|
||||
|
||||
huggingface::tgi::backends::TensorRtLlmBackendImpl::TensorRtLlmBackendImpl(
|
||||
const std::string_view &engineFolder,
|
||||
const std::string_view &executorWorker
|
||||
) : TensorRtLlmBackend(engineFolder, executorWorker) {}
|
||||
|
||||
|
||||
bool huggingface::tgi::backends::TensorRtLlmBackendImpl::IsReady() const {
|
||||
return TensorRtLlmBackend::IsReady();
|
||||
}
|
||||
|
||||
uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
|
||||
rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, float_t repetition_penalty,
|
||||
float_t frequency_penalty, uint64_t seed) {
|
||||
|
||||
// This will copy all the items from the initial slice
|
||||
std::vector<int32_t> tokens_(std::make_move_iterator(tokens.begin()), std::make_move_iterator(tokens.end()));
|
||||
return TensorRtLlmBackend::Submit(
|
||||
std::move(tokens_), topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
|
||||
}
|
||||
|
||||
size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
|
||||
const uint64_t requestId,
|
||||
huggingface::tgi::backends::GenerationContext *ctx,
|
||||
rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
|
||||
huggingface::tgi::backends::GenerationStep)> callback) {
|
||||
|
||||
size_t numTokens = 0;
|
||||
for (const auto &item: Poll(requestId)) {
|
||||
GenerationStep step;
|
||||
if (!item.hasError()) {
|
||||
SPDLOG_DEBUG("\tStreamTokens -> Decoding token...");
|
||||
const auto decoded = item.getResult();
|
||||
|
||||
const auto token = decoded.outputTokenIds[0][0];
|
||||
const auto isFinal = decoded.isFinal;
|
||||
const auto logProb = decoded.logProbs.value()[0][0];
|
||||
|
||||
++numTokens;
|
||||
|
||||
SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
|
||||
step = huggingface::tgi::backends::GenerationStep{
|
||||
static_cast<uint32_t>(token), logProb, isFinal, false, std::move(std::string())
|
||||
};
|
||||
SPDLOG_DEBUG("\tStreamTokens -> Post callback");
|
||||
} else {
|
||||
// TODO : Return rest::Result with error
|
||||
const auto what = item.getErrorMsg();
|
||||
SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", what);
|
||||
step = huggingface::tgi::backends::GenerationStep{
|
||||
std::numeric_limits<uint32_t>::max(), 0.0, true, true, std::move(what)
|
||||
};
|
||||
}
|
||||
|
||||
callback(std::move(ctx), std::move(step));
|
||||
}
|
||||
|
||||
return numTokens;
|
||||
}
|
||||
|
||||
std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
|
||||
huggingface::tgi::backends::CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker) {
|
||||
// Unconditionally call this to initialize and discover TRTLLM plugins
|
||||
InitializeBackend();
|
||||
|
||||
const auto enginePath = std::string_view(engineFolder.begin(), engineFolder.end());
|
||||
const auto executorPath = std::string_view(executorWorker.begin(), executorWorker.end());
|
||||
return std::make_unique<TensorRtLlmBackendImpl>(std::move(enginePath), std::move(executorPath));
|
||||
}
|
78
backends/trtllm/src/lib.rs
Normal file
@ -0,0 +1,78 @@
|
||||
pub use backend::{GenerationContext, TensorRtLlmBackend};
|
||||
|
||||
mod backend;
|
||||
pub mod errors;
|
||||
|
||||
#[cxx::bridge(namespace = "huggingface::tgi::backends")]
|
||||
mod ffi {
|
||||
|
||||
/// Struct used as shared type between rust and C++ to represent the result
|
||||
/// of a single decoding iteration
|
||||
pub struct GenerationStep {
|
||||
token_id: u32,
|
||||
log_prob: f32,
|
||||
is_final: bool,
|
||||
has_error: bool,
|
||||
error_msg: String,
|
||||
}
|
||||
|
||||
extern "Rust" {
|
||||
type GenerationContext;
|
||||
}
|
||||
|
||||
unsafe extern "C++" {
|
||||
include!("backends/trtllm/src/ffi.cpp");
|
||||
|
||||
/// Represent an instance of the underlying TensorRT-LLM backend
|
||||
type TensorRtLlmBackendImpl;
|
||||
|
||||
/// Create an instance backed by a std::unique_ptr to manage the lifespan of the backend
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `engine_folder`: Path to the folder containing all the TRTLLM engines
|
||||
/// * `executor_worker`: Path to the TRTLLM executor worker
|
||||
///
|
||||
/// returns: <unknown>
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
/// ```
|
||||
#[rust_name = "create_tensorrt_llm_backend"]
|
||||
fn CreateTensorRtLlmBackend(
|
||||
engine_folder: &str,
|
||||
executor_worker: &str,
|
||||
) -> UniquePtr<TensorRtLlmBackendImpl>;
|
||||
|
||||
// #[rust_name = "is_ready"]
|
||||
// fn IsReady(self: &TensorRtLlmBackendImpl) -> bool;
|
||||
|
||||
#[rust_name = "num_responses_ready"]
|
||||
fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize;
|
||||
|
||||
#[rust_name = "submit"]
|
||||
fn Submit(
|
||||
self: Pin<&mut TensorRtLlmBackendImpl>,
|
||||
tokens: &[u32],
|
||||
top_k: i32,
|
||||
top_p: f32,
|
||||
temperature: f32,
|
||||
repetition_penalty: f32,
|
||||
frequency_penalty: f32,
|
||||
seed: u64,
|
||||
) -> u64;
|
||||
|
||||
#[rust_name = "stream_tokens"]
|
||||
unsafe fn StreamTokens(
|
||||
self: Pin<&mut TensorRtLlmBackendImpl>,
|
||||
request_id: u64,
|
||||
ctx: *mut GenerationContext,
|
||||
cb: unsafe fn(*mut GenerationContext, GenerationStep),
|
||||
) -> usize;
|
||||
|
||||
// #[rust_name = "shutdown"]
|
||||
// fn Shutdown(self: Pin<&mut TensorRtLlmBackendImpl>);
|
||||
}
|
||||
}
|
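As a rough illustration of the bridge declared above, the sketch below drives the generated bindings directly; in the actual code path this is done through `TensorRtLlmBackend` in `backend.rs`. The engine folder, executor worker path, token ids, and sampling values are placeholders, not values from this change.

// Illustrative only: exercise the cxx bridge declared in `mod ffi`.
fn probe_backend() {
    // Placeholder paths: point these at real TRTLLM engines and the executorWorker binary.
    let mut backend = crate::ffi::create_tensorrt_llm_backend(
        "/data/engines/some-model",
        "/usr/local/bin/executorWorker",
    );

    // Enqueue a tiny request (token ids and sampling parameters are placeholders)
    let request_id = backend.pin_mut().submit(
        &[1u32, 2, 3],
        1,   // top_k
        1.0, // top_p
        1.0, // temperature
        1.0, // repetition_penalty
        0.0, // frequency_penalty
        42,  // seed
    );

    // Busy-wait until the executor has produced at least one response
    while backend.num_responses_ready() == 0 {
        std::thread::sleep(std::time::Duration::from_micros(100));
    }
    println!("request {request_id} has responses ready");
}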
166
backends/trtllm/src/main.rs
Normal file
@ -0,0 +1,166 @@
|
||||
use clap::Parser;
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
|
||||
use text_generation_backends_trtllm::TensorRtLlmBackend;
|
||||
use text_generation_router::server;
|
||||
use tokenizers::{FromPretrainedParameters, Tokenizer};
|
||||
|
||||
/// App Configuration
|
||||
#[derive(Parser, Debug)]
|
||||
#[clap(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
#[clap(default_value = "128", long, env)]
|
||||
max_concurrent_requests: usize,
|
||||
#[clap(default_value = "2", long, env)]
|
||||
max_best_of: usize,
|
||||
#[clap(default_value = "4", long, env)]
|
||||
max_stop_sequences: usize,
|
||||
#[clap(default_value = "5", long, env)]
|
||||
max_top_n_tokens: u32,
|
||||
#[clap(default_value = "1024", long, env)]
|
||||
max_input_tokens: usize,
|
||||
#[clap(default_value = "2048", long, env)]
|
||||
max_total_tokens: usize,
|
||||
#[clap(default_value = "4096", long, env)]
|
||||
max_batch_prefill_tokens: u32,
|
||||
#[clap(long, env)]
|
||||
max_batch_total_tokens: Option<u32>,
|
||||
#[clap(default_value = "0.0.0.0", long, env)]
|
||||
hostname: String,
|
||||
#[clap(default_value = "3000", long, short, env)]
|
||||
port: u16,
|
||||
#[clap(long, env, required = true)]
|
||||
tokenizer_name: String,
|
||||
#[clap(long, env)]
|
||||
tokenizer_config_path: Option<String>,
|
||||
#[clap(long, env)]
|
||||
revision: Option<String>,
|
||||
#[clap(long, env)]
|
||||
model_id: String,
|
||||
#[clap(default_value = "2", long, env)]
|
||||
validation_workers: usize,
|
||||
#[clap(long, env)]
|
||||
json_output: bool,
|
||||
#[clap(long, env)]
|
||||
otlp_endpoint: Option<String>,
|
||||
#[clap(default_value = "text-generation-inference.router", long, env)]
|
||||
otlp_service_name: String,
|
||||
#[clap(long, env)]
|
||||
cors_allow_origin: Option<Vec<String>>,
|
||||
#[clap(long, env, default_value_t = false)]
|
||||
messages_api_enabled: bool,
|
||||
#[clap(default_value = "4", long, env)]
|
||||
max_client_batch_size: usize,
|
||||
#[clap(long, env)]
|
||||
auth_token: Option<String>,
|
||||
#[clap(long, env, help = "Path to the TensorRT-LLM Orchestrator worker")]
|
||||
executor_worker: PathBuf,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), TensorRtLlmBackendError> {
|
||||
// Get args
|
||||
let args = Args::parse();
|
||||
// Pattern match configuration
|
||||
let Args {
|
||||
max_concurrent_requests,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
hostname,
|
||||
port,
|
||||
tokenizer_name,
|
||||
tokenizer_config_path,
|
||||
revision,
|
||||
model_id,
|
||||
validation_workers,
|
||||
json_output,
|
||||
otlp_endpoint,
|
||||
otlp_service_name,
|
||||
cors_allow_origin,
|
||||
messages_api_enabled,
|
||||
max_client_batch_size,
|
||||
auth_token,
|
||||
executor_worker,
|
||||
} = args;
|
||||
|
||||
// Launch Tokio runtime
|
||||
text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
|
||||
|
||||
// Validate args
|
||||
if max_input_tokens >= max_total_tokens {
|
||||
return Err(TensorRtLlmBackendError::ArgumentValidation(
|
||||
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
||||
));
|
||||
}
|
||||
if max_input_tokens as u32 > max_batch_prefill_tokens {
|
||||
return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
|
||||
}
|
||||
|
||||
if validation_workers == 0 {
|
||||
return Err(TensorRtLlmBackendError::ArgumentValidation(
|
||||
"`validation_workers` must be > 0".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
|
||||
if max_batch_prefill_tokens > *max_batch_total_tokens {
|
||||
return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
|
||||
}
|
||||
if max_total_tokens as u32 > *max_batch_total_tokens {
|
||||
return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
|
||||
}
|
||||
}
|
||||
|
||||
if !executor_worker.exists() {
|
||||
return Err(TensorRtLlmBackendError::ArgumentValidation(format!(
"`executor_worker` specified path doesn't exist: {}",
|
||||
executor_worker.display()
|
||||
)));
|
||||
}
|
||||
|
||||
// Run server
|
||||
let tokenizer = Tokenizer::from_pretrained(
|
||||
tokenizer_name.clone(),
|
||||
Some(FromPretrainedParameters {
|
||||
revision: revision.clone().unwrap_or(String::from("main")),
|
||||
user_agent: HashMap::new(),
|
||||
auth_token,
|
||||
}),
|
||||
)
|
||||
.map_err(|e| TensorRtLlmBackendError::Tokenizer(e.to_string()))?;
|
||||
|
||||
let backend = TensorRtLlmBackend::new(tokenizer, model_id, executor_worker)?;
|
||||
server::run(
|
||||
backend,
|
||||
max_concurrent_requests,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
validation_workers,
|
||||
None,
|
||||
tokenizer_name,
|
||||
tokenizer_config_path,
|
||||
revision,
|
||||
hostname,
|
||||
port,
|
||||
cors_allow_origin,
|
||||
false,
|
||||
None,
|
||||
None,
|
||||
messages_api_enabled,
|
||||
true,
|
||||
max_client_batch_size,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
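The argument checks in `main` above boil down to a few ordering invariants between the token budgets. Below is a compact, illustrative restatement that mirrors those rules and is easy to unit test; the function name is hypothetical and not part of this change.

// Hypothetical helper restating the CLI invariants enforced in `main` above.
fn check_token_budgets(
    max_input_tokens: usize,
    max_total_tokens: usize,
    max_batch_prefill_tokens: u32,
    max_batch_total_tokens: Option<u32>,
) -> Result<(), String> {
    if max_input_tokens >= max_total_tokens {
        return Err("`max_input_tokens` must be < `max_total_tokens`".into());
    }
    if max_input_tokens as u32 > max_batch_prefill_tokens {
        return Err("`max_batch_prefill_tokens` must be >= `max_input_tokens`".into());
    }
    if let Some(max_batch_total_tokens) = max_batch_total_tokens {
        if max_batch_prefill_tokens > max_batch_total_tokens {
            return Err("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`".into());
        }
        if max_total_tokens as u32 > max_batch_total_tokens {
            return Err("`max_total_tokens` must be <= `max_batch_total_tokens`".into());
        }
    }
    Ok(())
}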
14
backends/trtllm/tests/infer_test.cpp
Normal file
@ -0,0 +1,14 @@
|
||||
//
|
||||
// Created by mfuntowicz on 7/2/24.
|
||||
//
|
||||
#include <catch2/catch_all.hpp>
|
||||
#include <spdlog/spdlog.h>
|
||||
#include "../include/backend.h"
|
||||
|
||||
TEST_CASE("Load TRTLLM Engine on the TGI Backend", "[trtllm][engine][load]") {
|
||||
const auto engines = std::filesystem::path("/home/mfuntowicz/.cache/huggingface/assets/trtllm/0.11.0.dev2024062500/meta-llama--Meta-Llama-3-8B-Instruct/4090/engines/");
|
||||
const auto executor = std::filesystem::path("/home/mfuntowicz/Workspace/text-generation-inference/backends/trtllm/cmake-build-debug/cmake-build-debug/_deps/trtllm-src/cpp/tensorrt_llm/executor_worker/executorWorker");
|
||||
|
||||
spdlog::info("Loading config from: {}", absolute(engines).string());
|
||||
huggingface::tgi::backends::TensorRtLlmBackend backend(engines, executor);
|
||||
}
|
75
backends/v2/Cargo.toml
Normal file
@ -0,0 +1,75 @@
|
||||
[package]
|
||||
name = "text-generation-router-v2"
|
||||
description = "Text Generation Webserver"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
homepage.workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/lib.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "text-generation-router-v2"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
async-trait = "0.1.74"
|
||||
async-stream = "0.3.5"
|
||||
axum = { version = "0.7", features = ["json"] }
|
||||
axum-tracing-opentelemetry = "0.16"
|
||||
text-generation-router = { path = "../../router" }
|
||||
clap = { version = "4.4.5", features = ["derive", "env"] }
|
||||
grpc-metadata = { path = "../grpc-metadata" }
|
||||
futures = "0.3.28"
|
||||
hf-hub = { workspace = true }
|
||||
jsonschema = { version = "0.17.1", features = ["draft202012"] }
|
||||
metrics = { workspace = true }
|
||||
metrics-exporter-prometheus = { workspace = true }
|
||||
nohash-hasher = "0.2.0"
|
||||
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
|
||||
opentelemetry-otlp = "0.13.0"
|
||||
rand = "0.8.5"
|
||||
reqwest = { version = "0.11.20", features = [] }
|
||||
serde = "1.0.188"
|
||||
serde_json = "1.0.107"
|
||||
slotmap = "1.0.7"
|
||||
thiserror = "1.0.48"
|
||||
tokenizers = { workspace = true }
|
||||
tokio = { version = "1.32.0", features = [
|
||||
"rt",
|
||||
"rt-multi-thread",
|
||||
"parking_lot",
|
||||
"signal",
|
||||
"sync",
|
||||
] }
|
||||
tokio-stream = "0.1.14"
|
||||
tower-http = { version = "0.5.1", features = ["cors"] }
|
||||
tracing = "0.1.37"
|
||||
tracing-opentelemetry = "0.21.0"
|
||||
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|
||||
utoipa = { version = "4.2.0", features = ["axum_extras"] }
|
||||
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
|
||||
init-tracing-opentelemetry = { version = "0.14.1", features = [
|
||||
"opentelemetry-otlp",
|
||||
] }
|
||||
minijinja = { workspace = true }
|
||||
minijinja-contrib = { workspace = true }
|
||||
futures-util = "0.3.30"
|
||||
regex = "1.10.3"
|
||||
once_cell = "1.19.0"
|
||||
image = "0.25.1"
|
||||
base64 = { workspace = true }
|
||||
prost = "^0.12"
|
||||
tonic = "^0.10"
|
||||
tower = "^0.4"
|
||||
|
||||
[build-dependencies]
|
||||
tonic-build = "0.10.1"
|
||||
prost-build = "0.12.1"
|
||||
|
||||
[features]
|
||||
default = ["ngrok"]
|
||||
ngrok = ["text-generation-router/ngrok"]
|
||||
google = ["text-generation-router/google"]
|
||||
kserve = ["text-generation-router/kserve"]
|
@ -1,16 +1,16 @@
|
||||
use std::fs;
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
println!("cargo:rerun-if-changed=../../proto/generate.proto");
|
||||
fs::create_dir("src/pb").unwrap_or(());
|
||||
println!("cargo:rerun-if-changed=../../proto/");
|
||||
|
||||
fs::create_dir_all("src/client/pb").unwrap_or(());
|
||||
let mut config = prost_build::Config::new();
|
||||
config.protoc_arg("--experimental_allow_proto3_optional");
|
||||
|
||||
tonic_build::configure()
|
||||
.build_client(true)
|
||||
.build_server(false)
|
||||
.out_dir("src/pb")
|
||||
.out_dir("src/client/pb")
|
||||
.include_file("mod.rs")
|
||||
.compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"])
|
||||
.unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
|
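The build script change above moves the generated protobuf/gRPC code from `src/pb` to `src/client/pb`. Below is a hedged sketch of how that output might then be included on the client side; the module layout and the commented re-export are assumptions for illustration, not part of this diff.

// Hypothetical `src/client/mod.rs` wiring for the generated code.
// `include_file("mod.rs")` above makes tonic-build emit a `mod.rs` that
// declares the generated packages, so a single include! pulls everything in.
#[allow(clippy::all)]
mod pb {
    include!("pb/mod.rs");
}

// Re-export whatever generated client the rest of the crate needs, e.g.
// (the exact path depends on the proto package and is an assumption here):
// pub use pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;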
517
backends/v2/src/backend.rs
Normal file
@ -0,0 +1,517 @@
|
||||
use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};
|
||||
/// Batching and inference logic
|
||||
use crate::queue::{Entry, Queue};
|
||||
use async_trait::async_trait;
|
||||
use nohash_hasher::IntMap;
|
||||
use std::sync::Arc;
|
||||
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
||||
use text_generation_router::validation::ValidGenerateRequest;
|
||||
use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
|
||||
use tokio::sync::mpsc::error::SendError;
|
||||
use tokio::sync::{mpsc, Notify};
|
||||
use tokio::time::Instant;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::{info_span, instrument, Instrument, Span};
|
||||
|
||||
pub struct BackendV2 {
|
||||
/// Request queue
|
||||
queue: Queue,
|
||||
/// Notify batcher on queue appends
|
||||
batching_task_notifier: Arc<Notify>,
|
||||
/// Client clone, used for health checks to skip the queue
|
||||
client: ShardedClient,
|
||||
}
|
||||
|
||||
impl BackendV2 {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn new(
|
||||
client: ShardedClient,
|
||||
waiting_served_ratio: f32,
|
||||
max_input_tokens: u32,
|
||||
max_total_tokens: u32,
|
||||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
requires_padding: bool,
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
) -> Self {
|
||||
// Infer shared state
|
||||
let attention = if let Ok(attention) = std::env::var("ATTENTION") {
|
||||
attention
|
||||
.parse()
|
||||
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
|
||||
} else {
|
||||
Attention::Paged
|
||||
};
|
||||
let block_size = if attention == Attention::FlashDecoding {
|
||||
256
|
||||
} else {
|
||||
16
|
||||
};
|
||||
|
||||
let queue = Queue::new(
|
||||
requires_padding,
|
||||
block_size,
|
||||
window_size,
|
||||
speculate,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
);
|
||||
|
||||
let batching_task_notifier = Arc::new(Notify::new());
|
||||
|
||||
// Spawn batching background task that contains all the inference logic
|
||||
tokio::spawn(batching_task(
|
||||
client.clone(),
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
queue.clone(),
|
||||
batching_task_notifier.clone(),
|
||||
));
|
||||
|
||||
Self {
|
||||
queue,
|
||||
batching_task_notifier,
|
||||
client,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Backend for BackendV2 {
|
||||
#[instrument(skip_all)]
|
||||
fn schedule(
|
||||
&self,
|
||||
request: ValidGenerateRequest,
|
||||
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
|
||||
// MPSC channel to communicate with the background batching task
|
||||
let (response_tx, response_rx) = mpsc::unbounded_channel();
|
||||
|
||||
// Append the request to the queue
|
||||
self.queue.append(Entry {
|
||||
request,
|
||||
response_tx,
|
||||
span: Span::current(),
|
||||
temp_span: None,
|
||||
queue_time: Instant::now(),
|
||||
batch_time: None,
|
||||
});
|
||||
|
||||
// Notify the background task that we have a new entry in the queue that needs
|
||||
// to be batched
|
||||
self.batching_task_notifier.notify_one();
|
||||
|
||||
// Return stream
|
||||
Ok(UnboundedReceiverStream::new(response_rx))
|
||||
}
|
||||
|
||||
async fn health(&self, current_health: bool) -> bool {
|
||||
if current_health {
|
||||
// Generation is healthy, we only check that the shards can allocate on device
|
||||
self.client.device_health().await
|
||||
} else {
|
||||
self.client.model_health().await
|
||||
}
|
||||
.is_ok()
|
||||
}
|
||||
}
|
||||
|
||||
/// Batching logic
|
||||
/// Will be launched in a background Tokio task
|
||||
///
|
||||
/// Batches requests and sends them to the inference server
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn batching_task(
|
||||
mut client: ShardedClient,
|
||||
waiting_served_ratio: f32,
|
||||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
queue: Queue,
|
||||
notifier: Arc<Notify>,
|
||||
) {
|
||||
// Infinite loop
|
||||
loop {
|
||||
// Wait for a notification from the Infer struct
|
||||
notifier.notified().await;
|
||||
|
||||
// Get the next batch from the queue
|
||||
// This batch might be smaller than the maximum batch size if there are not enough requests
|
||||
// waiting in the queue
|
||||
while let Some((mut entries, batch, span)) = queue
|
||||
.next_batch(
|
||||
None,
|
||||
max_batch_size,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
)
|
||||
.await
|
||||
{
|
||||
let mut cached_batch = prefill(&mut client, batch, &mut entries)
|
||||
.instrument(span)
|
||||
.await;
|
||||
let mut waiting_tokens = 1;
|
||||
|
||||
// We loop until we do not receive any cached batch from the inference server (== until
|
||||
// all requests have met their stopping criteria)
|
||||
while let Some(batch) = cached_batch {
|
||||
// Get current batch info
|
||||
let batch_size = batch.size;
|
||||
let batch_max_tokens = batch.max_tokens;
|
||||
let mut batches = vec![batch];
|
||||
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
|
||||
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
|
||||
|
||||
let min_size = if waiting_tokens >= max_waiting_tokens {
|
||||
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
|
||||
// to add a new batch even though its size might be small
|
||||
None
|
||||
} else {
|
||||
// Minimum batch size
|
||||
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
|
||||
};
|
||||
|
||||
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
||||
let max_size =
|
||||
max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
|
||||
// Try to get a new batch
|
||||
if let Some((mut new_entries, new_batch, span)) = queue
|
||||
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
|
||||
.await
|
||||
{
|
||||
// Tracking metrics
|
||||
if min_size.is_some() {
|
||||
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
|
||||
.increment(1);
|
||||
} else {
|
||||
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
|
||||
.increment(1);
|
||||
}
|
||||
|
||||
entries.iter_mut().for_each(|(_, entry)| {
|
||||
// Create a new span to add the info that this entry is waiting
|
||||
// because a new batch is being computed
|
||||
let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
|
||||
// Add relationships
|
||||
span.follows_from(&entry_waiting_span);
|
||||
entry_waiting_span.follows_from(&span);
|
||||
// Update entry
|
||||
entry.temp_span = Some(entry_waiting_span);
|
||||
});
|
||||
|
||||
// Generate one token for this new batch to have the attention past in cache
|
||||
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
|
||||
.instrument(span)
|
||||
.await;
|
||||
// Reset waiting counter
|
||||
waiting_tokens = 1;
|
||||
// Extend current batch with the new batch
|
||||
if let Some(new_cached_batch) = new_cached_batch {
|
||||
entries.extend(new_entries);
|
||||
batches.push(new_cached_batch);
|
||||
}
|
||||
}
|
||||
|
||||
// Create span for this batch to add context to inference calls
|
||||
let next_batch_size = entries.len();
|
||||
let next_batch_span =
|
||||
info_span!(parent: None, "batch", batch_size = next_batch_size);
|
||||
entries.iter_mut().for_each(|(_, entry)| {
|
||||
// Create a new span to link the batch back to this entry
|
||||
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
||||
// Add relationships
|
||||
next_batch_span.follows_from(&entry_batch_span);
|
||||
entry_batch_span.follows_from(&next_batch_span);
|
||||
// Update entry
|
||||
entry.temp_span = Some(entry_batch_span);
|
||||
});
|
||||
|
||||
cached_batch = decode(&mut client, batches, &mut entries)
|
||||
.instrument(next_batch_span)
|
||||
.await;
|
||||
waiting_tokens += 1;
|
||||
}
|
||||
metrics::gauge!("tgi_batch_current_size").set(0.0);
|
||||
metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn prefill(
|
||||
client: &mut ShardedClient,
|
||||
batch: Batch,
|
||||
entries: &mut IntMap<u64, Entry>,
|
||||
) -> Option<CachedBatch> {
|
||||
let start_time = Instant::now();
|
||||
let batch_id = batch.id;
|
||||
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
|
||||
|
||||
match client.prefill(batch).await {
|
||||
Ok((generations, next_batch, timings)) => {
|
||||
let start_filtering_time = Instant::now();
|
||||
// Send generated tokens and filter stopped entries
|
||||
filter_send_generations(generations, entries);
|
||||
|
||||
// Filter next batch and remove requests that were stopped
|
||||
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||
|
||||
metrics::histogram!("tgi_batch_forward_duration","method" => "prefill")
|
||||
.record(timings.forward.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
|
||||
.record(timings.decode.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
|
||||
.record(start_filtering_time.elapsed().as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_inference_duration","method" => "prefill")
|
||||
.record(start_time.elapsed().as_secs_f64());
|
||||
metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
|
||||
next_batch
|
||||
}
|
||||
// If we have an error, we discard the whole batch
|
||||
Err(err) => {
|
||||
let _ = client.clear_cache(Some(batch_id)).await;
|
||||
send_errors(err, entries);
|
||||
metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn decode(
|
||||
client: &mut ShardedClient,
|
||||
batches: Vec<CachedBatch>,
|
||||
entries: &mut IntMap<u64, Entry>,
|
||||
) -> Option<CachedBatch> {
|
||||
let start_time = Instant::now();
|
||||
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
|
||||
metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
|
||||
|
||||
match client.decode(batches).await {
|
||||
Ok((generations, next_batch, timings)) => {
|
||||
let start_filtering_time = Instant::now();
|
||||
// Send generated tokens and filter stopped entries
|
||||
filter_send_generations(generations, entries);
|
||||
|
||||
// Filter next batch and remove requests that were stopped
|
||||
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||
|
||||
if let Some(concat_duration) = timings.concat {
|
||||
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
|
||||
.record(concat_duration.as_secs_f64());
|
||||
}
|
||||
metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
|
||||
.record(timings.forward.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
|
||||
.record(timings.decode.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
|
||||
.record(start_filtering_time.elapsed().as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
|
||||
.record(start_time.elapsed().as_secs_f64());
|
||||
metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
|
||||
next_batch
|
||||
}
|
||||
// If we have an error, we discard the whole batch
|
||||
Err(err) => {
|
||||
for id in batch_ids {
|
||||
let _ = client.clear_cache(Some(id)).await;
|
||||
}
|
||||
send_errors(err, entries);
|
||||
metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Filter a `batch` and remove all requests not present in `entries`
|
||||
#[instrument(skip_all)]
|
||||
async fn filter_batch(
|
||||
client: &mut ShardedClient,
|
||||
next_batch: Option<CachedBatch>,
|
||||
entries: &IntMap<u64, Entry>,
|
||||
) -> Option<CachedBatch> {
|
||||
let mut batch = next_batch?;
|
||||
|
||||
// No need to filter
|
||||
if batch.size as usize == entries.len() {
|
||||
return Some(batch);
|
||||
}
|
||||
|
||||
let id = batch.id;
|
||||
|
||||
// Retain only requests that are still in entries
|
||||
batch.request_ids.retain(|id| entries.contains_key(id));
|
||||
|
||||
if batch.request_ids.is_empty() {
|
||||
// All requests have been filtered out
|
||||
// Next batch is now empty
|
||||
// Clear it from the Python shards cache
|
||||
// We unwrap here as we need to panic since we cannot recover if this method fails
|
||||
client.clear_cache(Some(id)).await.unwrap();
|
||||
None
|
||||
} else {
|
||||
// Filter Python shard cache
|
||||
// We unwrap here as we need to panic since we cannot recover if this method fails
|
||||
client.filter_batch(id, batch.request_ids).await.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
|
||||
/// and filter entries
|
||||
#[instrument(skip_all)]
|
||||
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
|
||||
generations.into_iter().for_each(|generation| {
|
||||
let id = generation.request_id;
|
||||
// Get entry
|
||||
// We can `expect` here as the request id should always be in the entries
|
||||
let entry = entries
|
||||
.get(&id)
|
||||
.expect("ID not found in entries. This is a bug.");
|
||||
|
||||
// Create and enter a span to link this function back to the entry
|
||||
let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
|
||||
// Send generation responses back to the infer task
|
||||
// If we receive an error from the Flume channel, it means that the client dropped the
|
||||
// request and we need to stop generating, hence the unwrap_or(true)
|
||||
let stopped = send_responses(generation, entry).inspect_err(|_err| {
|
||||
tracing::error!("Entry response channel error.");
|
||||
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
|
||||
}).unwrap_or(true);
|
||||
if stopped {
|
||||
entries.remove(&id).expect("ID not found in entries. This is a bug.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Send responses through the `entry` response channel
|
||||
fn send_responses(
|
||||
generation: Generation,
|
||||
entry: &Entry,
|
||||
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
|
||||
// Return directly if the channel is disconnected
|
||||
if entry.response_tx.is_closed() {
|
||||
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
let mut stopped = false;
|
||||
|
||||
if let Some(prefill_tokens) = generation.prefill_tokens {
|
||||
// Create Token objects
|
||||
// We do that here instead of in the Python code as Rust for loops are faster
|
||||
let prefill_tokens = prefill_tokens
|
||||
.ids
|
||||
.into_iter()
|
||||
.zip(prefill_tokens.logprobs)
|
||||
.zip(prefill_tokens.texts)
|
||||
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
|
||||
.collect();
|
||||
|
||||
// Send message
|
||||
entry
|
||||
.response_tx
|
||||
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
|
||||
}
|
||||
|
||||
// Create last Token
|
||||
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
|
||||
let n = tokens_.ids.len();
|
||||
metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
|
||||
let mut iterator = tokens_
|
||||
.ids
|
||||
.into_iter()
|
||||
.zip(tokens_.logprobs)
|
||||
.zip(tokens_.texts)
|
||||
.zip(tokens_.is_special)
|
||||
.enumerate()
|
||||
.peekable();
|
||||
while let Some((i, (((id, logprob), text), special))) = iterator.next() {
|
||||
let token = Token {
|
||||
id,
|
||||
text,
|
||||
logprob,
|
||||
special,
|
||||
};
|
||||
let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
|
||||
top_tokens_
|
||||
.ids
|
||||
.iter()
|
||||
.zip(top_tokens_.logprobs.iter())
|
||||
.zip(top_tokens_.texts.iter())
|
||||
.zip(top_tokens_.is_special.iter())
|
||||
.map(|(((&id, &logprob), text), &special)| Token {
|
||||
id,
|
||||
text: text.to_string(),
|
||||
logprob,
|
||||
special,
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
match (&generation.generated_text, iterator.peek()) {
|
||||
(Some(generated_text), None) => {
|
||||
// Generation has ended
|
||||
stopped = true;
|
||||
// Send message
|
||||
entry.response_tx.send(Ok(InferStreamResponse::End {
|
||||
token,
|
||||
top_tokens,
|
||||
generated_text: GeneratedText::from(generated_text.clone()),
|
||||
queued: entry.queue_time,
|
||||
start: entry.batch_time.unwrap(),
|
||||
}))?;
|
||||
}
|
||||
_ => {
|
||||
// Send message
|
||||
entry
|
||||
.response_tx
|
||||
.send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(stopped)
|
||||
}
|
||||
|
||||
/// Send errors to Infer for all `entries`
|
||||
#[instrument(skip_all)]
|
||||
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
||||
entries.drain().for_each(|(_, entry)| {
|
||||
// Create and enter a span to link this function back to the entry
|
||||
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
|
||||
let err = InferError::GenerationError(error.to_string());
|
||||
metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
|
||||
tracing::error!("{err}");
|
||||
|
||||
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||
entry
|
||||
.response_tx
|
||||
.send(Err(err))
|
||||
.unwrap_or(());
|
||||
});
|
||||
}
|
||||
|
||||
impl From<crate::client::GeneratedText> for GeneratedText {
|
||||
fn from(value: crate::client::GeneratedText) -> Self {
|
||||
let v2_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap();
|
||||
let finish_reason = match v2_finish_reason {
|
||||
crate::client::FinishReason::Length => FinishReason::Length,
|
||||
crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
|
||||
crate::client::FinishReason::StopSequence => FinishReason::StopSequence,
|
||||
};
|
||||
|
||||
Self {
|
||||
text: value.text,
|
||||
generated_tokens: value.generated_tokens,
|
||||
finish_reason,
|
||||
seed: value.seed,
|
||||
}
|
||||
}
|
||||
}
|
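The loop above only interrupts an in-flight decode batch to prefill new requests when one of two conditions holds: the running batch has already waited for `max_waiting_tokens` decode steps (then `min_size` is `None` and any queued batch is accepted), or the queued batch is at least `waiting_served_ratio` times the size of the running one. A minimal standalone sketch of that policy, with illustrative parameter names rather than the exact ones used above:

/// Sketch of the batch-concatenation policy driving the `queue.next_batch` calls above.
/// Returns the minimum size a new prefill batch must have, or `None` if any size is fine.
fn min_new_batch_size(
    current_batch_size: usize,
    waiting_tokens: usize,
    max_waiting_tokens: usize,
    waiting_served_ratio: f32,
) -> Option<usize> {
    if waiting_tokens >= max_waiting_tokens {
        // The running batch has waited long enough: accept any new batch ("wait_exceeded").
        None
    } else {
        // Otherwise only interrupt decoding for a sufficiently large batch ("backpressure").
        Some((current_batch_size as f32 * waiting_served_ratio).floor() as usize)
    }
}

fn main() {
    // With 10 running requests and a ratio of 1.2, a new batch needs at least 12 requests.
    assert_eq!(min_new_batch_size(10, 3, 20, 1.2), Some(12));
    // After 20 decode steps without onboarding anything, any new batch is accepted.
    assert_eq!(min_new_batch_size(10, 20, 20, 1.2), None);
}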
259
backends/v2/src/client/grpc_client.rs
Normal file
@ -0,0 +1,259 @@
|
||||
/// Single shard Client
|
||||
use crate::client::pb;
|
||||
use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64};
|
||||
use grpc_metadata::InjectTelemetryContext;
|
||||
use pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;
|
||||
use pb::generate::v2::*;
|
||||
use std::cmp::min;
|
||||
use std::time::Duration;
|
||||
use tonic::transport::{Channel, Uri};
|
||||
use tracing::instrument;
|
||||
|
||||
/// Text Generation Inference gRPC client
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Client {
|
||||
stub: TextGenerationServiceClient<Channel>,
|
||||
}
|
||||
|
||||
impl Client {
|
||||
/// Returns a client connected to the given url
|
||||
#[allow(dead_code)]
|
||||
pub async fn connect(uri: Uri) -> Result<Self> {
|
||||
let channel = Channel::builder(uri).connect().await?;
|
||||
|
||||
Ok(Self {
|
||||
stub: TextGenerationServiceClient::new(channel),
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns a client connected to the given unix socket
|
||||
pub async fn connect_uds(path: String) -> Result<Self> {
|
||||
let channel = Channel::from_shared("http://[::]:50051".to_string())
|
||||
.unwrap()
|
||||
.connect_with_connector(tower::service_fn(move |_: Uri| {
|
||||
tokio::net::UnixStream::connect(path.clone())
|
||||
}))
|
||||
.await?;
|
||||
|
||||
Ok(Self {
|
||||
stub: TextGenerationServiceClient::new(channel),
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns a list of uris or unix sockets of all shards
|
||||
#[instrument(skip(self))]
|
||||
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
|
||||
let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
|
||||
let response = self.stub.service_discovery(request).await.map_err(|_| {
|
||||
ClientError::Connection("Server does not support v2 interface".to_string())
|
||||
})?;
|
||||
let urls = response
|
||||
.into_inner()
|
||||
.urls
|
||||
.into_iter()
|
||||
// Remove unix socket prefix
|
||||
.map(|url| match url.strip_prefix("unix://") {
|
||||
None => url,
|
||||
Some(stripped_url) => stripped_url.to_string(),
|
||||
})
|
||||
.collect();
|
||||
Ok(urls)
|
||||
}
|
||||
|
||||
/// Get model info
|
||||
#[instrument(skip(self))]
|
||||
pub async fn info(&mut self) -> Result<InfoResponse> {
|
||||
let request = tonic::Request::new(InfoRequest {}).inject_context();
|
||||
let response = self.stub.info(request).await?.into_inner();
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// Get model health
|
||||
#[instrument(skip(self))]
|
||||
pub async fn health(&mut self) -> Result<HealthResponse> {
|
||||
let request = tonic::Request::new(HealthRequest {}).inject_context();
|
||||
let response = self.stub.health(request).await?.into_inner();
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// Clear the past generations cache
|
||||
#[instrument(skip(self))]
|
||||
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
||||
let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
|
||||
self.stub.clear_cache(request).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Filter a cached batch
|
||||
#[instrument(skip(self))]
|
||||
pub async fn filter_batch(
|
||||
&mut self,
|
||||
batch_id: u64,
|
||||
request_ids: Vec<u64>,
|
||||
) -> Result<Option<CachedBatch>> {
|
||||
let request = tonic::Request::new(FilterBatchRequest {
|
||||
batch_id,
|
||||
request_ids,
|
||||
})
|
||||
.inject_context();
|
||||
let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
|
||||
Ok(filtered_batch.batch)
|
||||
}
|
||||
|
||||
/// Warmup on a max size batch
|
||||
///
|
||||
/// Returns the maximum amount of tokens supported by the hardware
|
||||
#[instrument(skip_all)]
|
||||
pub async fn warmup(
|
||||
&mut self,
|
||||
max_input_length: u32,
|
||||
max_prefill_tokens: u32,
|
||||
max_total_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_batch_size: Option<usize>,
|
||||
) -> Result<Option<u32>> {
|
||||
let mut n_tokens = 0;
|
||||
let mut requests = Vec::new();
|
||||
// Create requests
|
||||
while n_tokens < max_prefill_tokens {
|
||||
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
|
||||
|
||||
let mut inputs = String::new();
|
||||
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
|
||||
if n_tokens == 0 {
|
||||
// 1 request is enough to test vision heads.
|
||||
// Sending images on other queries messes up easily with truncation.
|
||||
inputs.push_str(&format!(
|
||||
"",
|
||||
));
|
||||
}
|
||||
|
||||
requests.push(Request {
|
||||
id: 0,
|
||||
inputs,
|
||||
// We truncate the input on the server side to be sure that it has the correct size
|
||||
truncate,
|
||||
// Set sampling parameters to also take these ops into account in the max memory
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
temperature: 0.9,
|
||||
top_k: 10,
|
||||
top_p: 0.9,
|
||||
typical_p: 0.9,
|
||||
do_sample: false,
|
||||
seed: 0,
|
||||
repetition_penalty: 1.2,
|
||||
frequency_penalty: 0.1,
|
||||
watermark: true,
|
||||
grammar: String::new(),
|
||||
grammar_type: GrammarType::None as i32,
|
||||
}),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||
max_new_tokens: max_total_tokens - truncate,
|
||||
stop_sequences: vec![],
|
||||
ignore_eos_token: true,
|
||||
}),
|
||||
prefill_logprobs: true,
|
||||
top_n_tokens: 20,
|
||||
});
|
||||
n_tokens += max_input_length;
|
||||
|
||||
// Check max_batch_size
|
||||
if Some(requests.len()) == max_batch_size {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let batch = Batch {
|
||||
id: 0,
|
||||
size: requests.len() as u32,
|
||||
requests,
|
||||
max_tokens: 0,
|
||||
};
|
||||
|
||||
let request = tonic::Request::new(WarmupRequest {
|
||||
batch: Some(batch),
|
||||
max_input_length,
|
||||
max_prefill_tokens,
|
||||
max_total_tokens,
|
||||
max_batch_total_tokens,
|
||||
})
|
||||
.inject_context();
|
||||
let response = self.stub.warmup(request).await?.into_inner();
|
||||
Ok(response.max_supported_total_tokens)
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given batch
|
||||
///
|
||||
/// Returns Generation for each request in batch
|
||||
/// and the next cached batch
|
||||
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
|
||||
pub async fn prefill(
|
||||
&mut self,
|
||||
batch: Batch,
|
||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
|
||||
let response = self.stub.prefill(request).await?.into_inner();
|
||||
Ok((
|
||||
response.generations,
|
||||
response.batch,
|
||||
PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
|
||||
))
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given cached batches
|
||||
///
|
||||
/// Returns Generation for each request in batches
|
||||
/// and the next cached batch
|
||||
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
|
||||
pub async fn decode(
|
||||
&mut self,
|
||||
batches: Vec<CachedBatch>,
|
||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
|
||||
let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
|
||||
let response = self.stub.decode(request).await?.into_inner();
|
||||
Ok((
|
||||
response.generations,
|
||||
response.batch,
|
||||
DecodeTimings::new(
|
||||
response.concat_ns,
|
||||
response.forward_ns,
|
||||
response.decode_ns,
|
||||
response.total_ns,
|
||||
),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PrefillTimings {
|
||||
pub forward: Duration,
|
||||
pub decode: Duration,
|
||||
pub total: Duration,
|
||||
}
|
||||
|
||||
impl PrefillTimings {
|
||||
fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
|
||||
Self {
|
||||
forward: Duration::from_nanos(forward_ns),
|
||||
decode: Duration::from_nanos(decode_ns),
|
||||
total: Duration::from_nanos(total_ns),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DecodeTimings {
|
||||
pub concat: Option<Duration>,
|
||||
pub forward: Duration,
|
||||
pub decode: Duration,
|
||||
pub total: Duration,
|
||||
}
|
||||
|
||||
impl DecodeTimings {
|
||||
fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
|
||||
Self {
|
||||
concat: concat_ns.map(Duration::from_nanos),
|
||||
forward: Duration::from_nanos(forward_ns),
|
||||
decode: Duration::from_nanos(decode_ns),
|
||||
total: Duration::from_nanos(total_ns),
|
||||
}
|
||||
}
|
||||
}
|
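A minimal sketch of how this single-shard client is typically driven from elsewhere in the crate, assuming a Python shard is already listening on the given unix socket (error handling reduced to `?`):

// Sketch only: lives inside this crate, where `crate::client::Client` is the re-export of the type above.
use crate::client::Client;

async fn check_shard(uds_path: &str) -> Result<(), Box<dyn std::error::Error>> {
    // Connect to a single Python shard over its unix socket.
    let mut client = Client::connect_uds(uds_path.to_string()).await?;
    // Static shard properties: dtype, device, window size, speculation depth, ...
    let info = client.info().await?;
    println!("shard dtype: {}, device: {}", info.dtype, info.device_type);
    // Cheap liveness probe before doing any real work.
    client.health().await?;
    Ok(())
}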
68
backends/v2/src/client/mod.rs
Normal file
@ -0,0 +1,68 @@
|
||||
//! Text Generation gRPC client library
|
||||
|
||||
use async_trait::async_trait;
|
||||
use thiserror::Error;
|
||||
use tonic::transport;
|
||||
use tonic::Status;
|
||||
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
mod pb;
|
||||
|
||||
mod grpc_client;
|
||||
mod sharded_client;
|
||||
|
||||
pub use grpc_client::Client;
|
||||
pub use pb::generate::v2::{
|
||||
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, HealthResponse,
|
||||
InfoResponse, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
};
|
||||
pub use sharded_client::ShardedClient;
|
||||
|
||||
#[async_trait]
|
||||
pub trait Health {
|
||||
/// Check if a generate server is healthy by asking it to allocate a tensor on device
|
||||
async fn device_health(&self) -> Result<()>;
|
||||
|
||||
/// Check if a generate server is healthy by doing a forward pass.
|
||||
/// EXPENSIVE
|
||||
async fn model_health(&self) -> Result<()>;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ShardInfo {
|
||||
pub requires_padding: bool,
|
||||
pub dtype: String,
|
||||
pub device_type: String,
|
||||
pub window_size: Option<u32>,
|
||||
pub speculate: u32,
|
||||
}
|
||||
|
||||
#[derive(Error, Debug, Clone)]
|
||||
pub enum ClientError {
|
||||
#[error("Could not connect to Text Generation server: {0}")]
|
||||
Connection(String),
|
||||
#[error("Server error: {0}")]
|
||||
Generation(String),
|
||||
#[error("Sharded results are empty")]
|
||||
EmptyResults,
|
||||
}
|
||||
|
||||
impl From<Status> for ClientError {
|
||||
fn from(err: Status) -> Self {
|
||||
let err = Self::Generation(err.message().to_string());
|
||||
tracing::error!("{err}");
|
||||
err
|
||||
}
|
||||
}
|
||||
|
||||
impl From<transport::Error> for ClientError {
|
||||
fn from(err: transport::Error) -> Self {
|
||||
let err = Self::Connection(err.to_string());
|
||||
tracing::error!("{err}");
|
||||
err
|
||||
}
|
||||
}
|
||||
|
||||
static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
|
||||
|
||||
pub type Result<T> = std::result::Result<T, ClientError>;
|
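The two `From` impls above are what let every client method use `?` directly on tonic calls: a failed RPC (`tonic::Status`) or a transport error is converted into `ClientError` and logged on the way out. A reduced, standalone sketch of the same pattern with stand-in types (not the real tonic ones):

use thiserror::Error;

#[allow(dead_code)]
#[derive(Error, Debug)]
enum ClientError {
    #[error("Could not connect to Text Generation server: {0}")]
    Connection(String),
    #[error("Server error: {0}")]
    Generation(String),
}

// Stand-in for a lower-level error such as `tonic::Status`.
struct Status(String);

impl From<Status> for ClientError {
    fn from(err: Status) -> Self {
        ClientError::Generation(err.0)
    }
}

fn call_server(fail: bool) -> Result<String, Status> {
    if fail { Err(Status("boom".into())) } else { Ok("ok".into()) }
}

fn generate() -> Result<String, ClientError> {
    // `?` applies `From<Status> for ClientError` automatically.
    Ok(call_server(false)?)
}

fn main() {
    assert_eq!(generate().unwrap(), "ok");
}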
254
backends/v2/src/client/sharded_client.rs
Normal file
@ -0,0 +1,254 @@
|
||||
/// Multi shard Client
|
||||
use crate::client::{ClientError, Result};
|
||||
use crate::client::{Health, ShardInfo};
|
||||
|
||||
use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
|
||||
use crate::client::InfoResponse;
|
||||
use crate::client::{
|
||||
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
|
||||
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use futures::future::join_all;
|
||||
use tonic::transport::Uri;
|
||||
use tracing::instrument;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// Text Generation Inference gRPC multi client
|
||||
pub struct ShardedClient {
|
||||
clients: Vec<Client>,
|
||||
}
|
||||
|
||||
impl ShardedClient {
|
||||
fn new(clients: Vec<Client>) -> Self {
|
||||
Self { clients }
|
||||
}
|
||||
|
||||
/// Create a new ShardedClient from a master client. The master client will communicate with
|
||||
/// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
|
||||
async fn from_master_client(mut master_client: Client) -> Result<Self> {
|
||||
// Get all uris/unix sockets from the master client
|
||||
let uris = master_client.service_discovery().await?;
|
||||
let futures = uris.into_iter().map(Client::connect_uds);
|
||||
let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
|
||||
Ok(Self::new(clients?))
|
||||
}
|
||||
|
||||
/// Returns a client connected to the given uri
|
||||
#[allow(dead_code)]
|
||||
pub async fn connect(uri: Uri) -> Result<Self> {
|
||||
let master_client = Client::connect(uri).await?;
|
||||
Self::from_master_client(master_client).await
|
||||
}
|
||||
|
||||
/// Returns a client connected to the given unix socket
|
||||
pub async fn connect_uds(path: String) -> Result<Self> {
|
||||
let master_client = Client::connect_uds(path).await?;
|
||||
Self::from_master_client(master_client).await
|
||||
}
|
||||
|
||||
/// Get the model info
|
||||
#[instrument(skip(self))]
|
||||
pub async fn info(&mut self) -> Result<ShardInfo> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| client.info())
|
||||
.collect();
|
||||
join_all(futures).await.pop().unwrap().map(ShardInfo::from)
|
||||
}
|
||||
|
||||
/// GRPC health check
|
||||
#[instrument(skip(self))]
|
||||
pub async fn health(&mut self) -> Result<HealthResponse> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| client.health())
|
||||
.collect();
|
||||
join_all(futures).await.pop().unwrap()
|
||||
}
|
||||
|
||||
/// Clear the past generations cache
|
||||
#[instrument(skip(self))]
|
||||
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| client.clear_cache(batch_id))
|
||||
.collect();
|
||||
join_all(futures).await.into_iter().collect()
|
||||
}
|
||||
|
||||
/// Filter a cached batch
|
||||
#[instrument(skip(self))]
|
||||
pub async fn filter_batch(
|
||||
&mut self,
|
||||
batch_id: u64,
|
||||
request_ids: Vec<u64>,
|
||||
) -> Result<Option<CachedBatch>> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
|
||||
.collect();
|
||||
// all shards return the same message
|
||||
join_all(futures).await.pop().unwrap()
|
||||
}
|
||||
|
||||
/// Warmup on a max size batch
|
||||
///
|
||||
/// Returns the maximum amount of tokens supported by the hardware
|
||||
#[instrument(skip(self))]
|
||||
pub async fn warmup(
|
||||
&mut self,
|
||||
max_input_length: u32,
|
||||
max_prefill_tokens: u32,
|
||||
max_total_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_batch_size: Option<usize>,
|
||||
) -> Result<Option<u32>> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| {
|
||||
Box::pin(client.warmup(
|
||||
max_input_length,
|
||||
max_prefill_tokens,
|
||||
max_total_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_batch_size,
|
||||
))
|
||||
})
|
||||
.collect();
|
||||
// Take the minimum value
|
||||
let results = join_all(futures)
|
||||
.await
|
||||
.into_iter()
|
||||
.collect::<Result<Vec<Option<u32>>>>()?;
|
||||
Ok(results.into_iter().flatten().min())
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given batch
|
||||
///
|
||||
/// Returns Generation for each request in batch
|
||||
/// and the next cached batch
|
||||
#[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]
|
||||
pub async fn prefill(
|
||||
&mut self,
|
||||
batch: Batch,
|
||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| Box::pin(client.prefill(batch.clone())))
|
||||
.collect();
|
||||
#[allow(clippy::type_complexity)]
|
||||
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
|
||||
join_all(futures).await.into_iter().collect();
|
||||
let mut results = results?;
|
||||
|
||||
let (mut generations, next_batch, mut timings) =
|
||||
results.pop().ok_or(ClientError::EmptyResults)?;
|
||||
|
||||
// Merge generations from different model shards
|
||||
for (mut shard_generations, _, shard_timings) in results.into_iter() {
|
||||
generations.append(&mut shard_generations);
|
||||
// Return the timings of the slowest shard
|
||||
if shard_timings.total > timings.total {
|
||||
timings = shard_timings;
|
||||
}
|
||||
}
|
||||
Ok((generations, next_batch, timings))
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given cached batches
|
||||
///
|
||||
/// Returns Generation for each request in batches
|
||||
/// and the next cached batch
|
||||
#[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
|
||||
pub async fn decode(
|
||||
&mut self,
|
||||
batches: Vec<CachedBatch>,
|
||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| Box::pin(client.decode(batches.clone())))
|
||||
.collect();
|
||||
#[allow(clippy::type_complexity)]
|
||||
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
|
||||
join_all(futures).await.into_iter().collect();
|
||||
let mut results = results?;
|
||||
|
||||
let (mut generations, next_batch, mut timings) =
|
||||
results.pop().ok_or(ClientError::EmptyResults)?;
|
||||
|
||||
// Merge generations from different model shards
|
||||
for (mut shard_generations, _, shard_timings) in results.into_iter() {
|
||||
generations.append(&mut shard_generations);
|
||||
// Return the timings of the slowest shard
|
||||
if shard_timings.total > timings.total {
|
||||
timings = shard_timings;
|
||||
}
|
||||
}
|
||||
Ok((generations, next_batch, timings))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<InfoResponse> for ShardInfo {
|
||||
fn from(value: InfoResponse) -> Self {
|
||||
Self {
|
||||
requires_padding: value.requires_padding,
|
||||
dtype: value.dtype,
|
||||
device_type: value.device_type,
|
||||
window_size: value.window_size,
|
||||
speculate: value.speculate,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Health for ShardedClient {
|
||||
async fn device_health(&self) -> Result<()> {
|
||||
self.clone().health().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn model_health(&self) -> Result<()> {
|
||||
// Dummy batch of 1 token and 1 generated token
|
||||
let liveness_request = Request {
|
||||
id: u64::MAX,
|
||||
inputs: "liveness".to_string(),
|
||||
truncate: 10,
|
||||
prefill_logprobs: false,
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
temperature: 1.0,
|
||||
top_k: 0,
|
||||
top_p: 1.0,
|
||||
typical_p: 1.0,
|
||||
do_sample: false,
|
||||
seed: 0,
|
||||
repetition_penalty: 1.0,
|
||||
frequency_penalty: 0.0,
|
||||
watermark: false,
|
||||
grammar: String::new(),
|
||||
grammar_type: GrammarType::None as i32,
|
||||
}),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||
max_new_tokens: 1,
|
||||
stop_sequences: vec![],
|
||||
ignore_eos_token: false,
|
||||
}),
|
||||
top_n_tokens: 0,
|
||||
};
|
||||
let batch = Batch {
|
||||
id: u64::MAX,
|
||||
requests: vec![liveness_request],
|
||||
size: 1,
|
||||
max_tokens: 2,
|
||||
};
|
||||
self.clone().prefill(batch).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
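Every method on the multi-shard client above follows the same fan-out pattern: send the identical request to all shards with `join_all`, fail if any shard fails, then reduce the per-shard results (append generations, keep the slowest shard's timings, or take the minimum supported token count in `warmup`). A self-contained sketch of that pattern, with a stand-in for the per-shard RPC:

use futures::future::join_all;

// Stand-in for a per-shard gRPC call; here each shard reports a max token count.
async fn shard_warmup(shard_id: u32) -> Result<u32, String> {
    Ok(10_000 + shard_id)
}

async fn warmup_all(shards: &[u32]) -> Result<Option<u32>, String> {
    let futures: Vec<_> = shards.iter().map(|&id| shard_warmup(id)).collect();
    // Propagate the first error, otherwise collect every shard's answer.
    let results: Result<Vec<u32>, String> = join_all(futures).await.into_iter().collect();
    // Like `ShardedClient::warmup`, keep the most conservative (minimum) value.
    Ok(results?.into_iter().min())
}

#[tokio::main]
async fn main() {
    assert_eq!(warmup_all(&[0, 1, 2, 3]).await.unwrap(), Some(10_000));
}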
144
backends/v2/src/lib.rs
Normal file
@ -0,0 +1,144 @@
|
||||
mod backend;
|
||||
mod client;
|
||||
mod queue;
|
||||
|
||||
use crate::client::{ClientError, ShardedClient};
|
||||
pub(crate) use backend::BackendV2;
|
||||
use serde::Serialize;
|
||||
use thiserror::Error;
|
||||
use utoipa::ToSchema;
|
||||
|
||||
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||
pub struct BackendInfo {
|
||||
/// Mandatory
|
||||
#[schema(example = "cuda")]
|
||||
pub model_device_type: String,
|
||||
#[schema(example = "torch.float16")]
|
||||
pub model_dtype: String,
|
||||
|
||||
/// Backend parameters
|
||||
#[schema(example = "1")]
|
||||
pub speculate: usize,
|
||||
#[schema(example = "1.2")]
|
||||
pub waiting_served_ratio: f32,
|
||||
#[schema(example = "32000")]
|
||||
pub max_batch_total_tokens: u32,
|
||||
#[schema(example = "20")]
|
||||
pub max_waiting_tokens: usize,
|
||||
#[schema(nullable = true, example = "null")]
|
||||
pub max_batch_size: Option<usize>,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn connect_backend(
|
||||
max_input_tokens: usize,
|
||||
max_total_tokens: usize,
|
||||
master_shard_uds_path: String,
|
||||
waiting_served_ratio: f32,
|
||||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: Option<u32>,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
) -> Result<(BackendV2, BackendInfo), V2Error> {
|
||||
// Helper function
|
||||
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
|
||||
match max_supported_batch_total_tokens {
|
||||
// Older models do not support automatic max-batch-total-tokens
|
||||
None => {
|
||||
let max_batch_total_tokens = max_batch_total_tokens
|
||||
.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
|
||||
tracing::warn!("Model does not support automatic max batch total tokens");
|
||||
Ok(max_batch_total_tokens)
|
||||
}
|
||||
// Flash attention models return their max supported total tokens
|
||||
Some(max_supported_batch_total_tokens) => {
|
||||
// Warn if user added his own max-batch-total-tokens as we will ignore it
|
||||
if max_batch_total_tokens.is_some() {
|
||||
tracing::warn!(
|
||||
"`--max-batch-total-tokens` is deprecated for Flash \
|
||||
Attention models."
|
||||
);
|
||||
tracing::warn!(
|
||||
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
|
||||
);
|
||||
}
|
||||
if max_total_tokens as u32 > max_supported_batch_total_tokens {
|
||||
return Err(V2Error::NotEnoughMemory(max_total_tokens));
|
||||
}
|
||||
|
||||
Ok(max_supported_batch_total_tokens)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
|
||||
.await
|
||||
.map_err(V2Error::Connection)?;
|
||||
|
||||
// server is running on v2
|
||||
// Clear the cache; useful if the webserver rebooted
|
||||
sharded_client
|
||||
.clear_cache(None)
|
||||
.await
|
||||
.map_err(V2Error::Cache)?;
|
||||
// Get info from the shard
|
||||
let shard_info = sharded_client.info().await.map_err(V2Error::Info)?;
|
||||
|
||||
// Warmup model
|
||||
tracing::info!("Warming up model");
|
||||
let max_batch_total_tokens = check_max_batch_total_tokens(
|
||||
sharded_client
|
||||
.warmup(
|
||||
max_input_tokens as u32,
|
||||
max_batch_prefill_tokens,
|
||||
max_total_tokens as u32,
|
||||
max_batch_total_tokens.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))),
|
||||
max_batch_size,
|
||||
)
|
||||
.await
|
||||
.map_err(V2Error::Warmup)?,
|
||||
)?;
|
||||
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
|
||||
|
||||
let backend_info = BackendInfo {
|
||||
waiting_served_ratio,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
model_device_type: shard_info.device_type.clone(),
|
||||
model_dtype: shard_info.dtype.clone(),
|
||||
speculate: shard_info.speculate as usize,
|
||||
};
|
||||
|
||||
let backend = BackendV2::new(
|
||||
sharded_client,
|
||||
waiting_served_ratio,
|
||||
max_input_tokens as u32,
|
||||
max_total_tokens as u32,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
shard_info.requires_padding,
|
||||
shard_info.window_size,
|
||||
shard_info.speculate,
|
||||
);
|
||||
|
||||
tracing::info!("Using backend V3");
|
||||
|
||||
Ok((backend, backend_info))
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum V2Error {
|
||||
#[error("Unable to clear the Python model shards cache: {0}")]
|
||||
Cache(ClientError),
|
||||
#[error("Unable to connect to the Python model shards: {0}")]
|
||||
Connection(ClientError),
|
||||
#[error("Unable to get the Python model shards info: {0}")]
|
||||
Info(ClientError),
|
||||
#[error("Unable to warmup the Python model shards: {0}")]
|
||||
Warmup(ClientError),
|
||||
#[error("Not enough memory to handle `max_total_tokens={0}`")]
|
||||
NotEnoughMemory(usize),
|
||||
}
|
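`connect_backend` above resolves the effective max batch total tokens in two ways: models without a warmup estimate use the user-provided value, or max(16000, max_total_tokens, max_batch_prefill_tokens) when none was given, while flash-attention models report a supported maximum during warmup that overrides any user setting and must still fit a single max-length request. A standalone sketch of that rule (names illustrative):

/// Illustrative sketch of the max-batch-total-tokens resolution performed above.
fn resolve_max_batch_total_tokens(
    warmup_estimate: Option<u32>, // value inferred by the shards, if supported
    user_value: Option<u32>,      // --max-batch-total-tokens from the CLI
    max_total_tokens: u32,
    max_batch_prefill_tokens: u32,
) -> Result<u32, String> {
    match warmup_estimate {
        // Older models: trust the user value, or fall back to a conservative default.
        None => Ok(user_value
            .unwrap_or_else(|| 16000.max(max_total_tokens.max(max_batch_prefill_tokens)))),
        // Flash-attention models: the inferred value wins, but one request must still fit.
        Some(inferred) => {
            if max_total_tokens > inferred {
                return Err(format!("not enough memory for max_total_tokens={max_total_tokens}"));
            }
            Ok(inferred)
        }
    }
}

fn main() {
    assert_eq!(resolve_max_batch_total_tokens(None, None, 2048, 4096), Ok(16000));
    assert_eq!(resolve_max_batch_total_tokens(Some(90_000), Some(50_000), 2048, 4096), Ok(90_000));
}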
212
backends/v2/src/main.rs
Normal file
@ -0,0 +1,212 @@
|
||||
use clap::{Parser, Subcommand};
|
||||
use text_generation_router::{server, usage_stats};
|
||||
use text_generation_router_v2::{connect_backend, V2Error};
|
||||
use thiserror::Error;
|
||||
|
||||
/// App Configuration
|
||||
#[derive(Parser, Debug)]
|
||||
#[clap(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
#[command(subcommand)]
|
||||
command: Option<Commands>,
|
||||
|
||||
#[clap(default_value = "128", long, env)]
|
||||
max_concurrent_requests: usize,
|
||||
#[clap(default_value = "2", long, env)]
|
||||
max_best_of: usize,
|
||||
#[clap(default_value = "4", long, env)]
|
||||
max_stop_sequences: usize,
|
||||
#[clap(default_value = "5", long, env)]
|
||||
max_top_n_tokens: u32,
|
||||
#[clap(default_value = "1024", long, env)]
|
||||
max_input_tokens: usize,
|
||||
#[clap(default_value = "2048", long, env)]
|
||||
max_total_tokens: usize,
|
||||
#[clap(default_value = "1.2", long, env)]
|
||||
waiting_served_ratio: f32,
|
||||
#[clap(default_value = "4096", long, env)]
|
||||
max_batch_prefill_tokens: u32,
|
||||
#[clap(long, env)]
|
||||
max_batch_total_tokens: Option<u32>,
|
||||
#[clap(default_value = "20", long, env)]
|
||||
max_waiting_tokens: usize,
|
||||
#[clap(long, env)]
|
||||
max_batch_size: Option<usize>,
|
||||
#[clap(default_value = "0.0.0.0", long, env)]
|
||||
hostname: String,
|
||||
#[clap(default_value = "3000", long, short, env)]
|
||||
port: u16,
|
||||
#[clap(default_value = "/tmp/text-generation-server-0", long, env)]
|
||||
master_shard_uds_path: String,
|
||||
#[clap(default_value = "bigscience/bloom", long, env)]
|
||||
tokenizer_name: String,
|
||||
#[clap(long, env)]
|
||||
tokenizer_config_path: Option<String>,
|
||||
#[clap(long, env)]
|
||||
revision: Option<String>,
|
||||
#[clap(default_value = "2", long, env)]
|
||||
validation_workers: usize,
|
||||
#[clap(long, env)]
|
||||
api_key: Option<String>,
|
||||
#[clap(long, env)]
|
||||
json_output: bool,
|
||||
#[clap(long, env)]
|
||||
otlp_endpoint: Option<String>,
|
||||
#[clap(default_value = "text-generation-inference.router", long, env)]
|
||||
otlp_service_name: String,
|
||||
#[clap(long, env)]
|
||||
cors_allow_origin: Option<Vec<String>>,
|
||||
#[clap(long, env)]
|
||||
ngrok: bool,
|
||||
#[clap(long, env)]
|
||||
ngrok_authtoken: Option<String>,
|
||||
#[clap(long, env)]
|
||||
ngrok_edge: Option<String>,
|
||||
#[clap(long, env, default_value_t = false)]
|
||||
messages_api_enabled: bool,
|
||||
#[clap(long, env, default_value_t = false)]
|
||||
disable_grammar_support: bool,
|
||||
#[clap(default_value = "4", long, env)]
|
||||
max_client_batch_size: usize,
|
||||
#[clap(default_value = "on", long, env)]
|
||||
usage_stats: usage_stats::UsageStatsLevel,
|
||||
}
|
||||
|
||||
#[derive(Debug, Subcommand)]
|
||||
enum Commands {
|
||||
PrintSchema,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), RouterError> {
|
||||
// Get args
|
||||
let args = Args::parse();
|
||||
// Pattern match configuration
|
||||
let Args {
|
||||
command,
|
||||
max_concurrent_requests,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
hostname,
|
||||
port,
|
||||
master_shard_uds_path,
|
||||
tokenizer_name,
|
||||
tokenizer_config_path,
|
||||
revision,
|
||||
validation_workers,
|
||||
api_key,
|
||||
json_output,
|
||||
otlp_endpoint,
|
||||
otlp_service_name,
|
||||
cors_allow_origin,
|
||||
ngrok,
|
||||
ngrok_authtoken,
|
||||
ngrok_edge,
|
||||
messages_api_enabled,
|
||||
disable_grammar_support,
|
||||
max_client_batch_size,
|
||||
usage_stats,
|
||||
} = args;
|
||||
|
||||
if let Some(Commands::PrintSchema) = command {
|
||||
use utoipa::OpenApi;
|
||||
let api_doc = text_generation_router::server::ApiDoc::openapi();
|
||||
let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();
|
||||
println!("{}", api_doc);
|
||||
std::process::exit(0);
|
||||
};
|
||||
text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
|
||||
|
||||
// Validate args
|
||||
if max_input_tokens >= max_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
||||
));
|
||||
}
|
||||
if max_input_tokens as u32 > max_batch_prefill_tokens {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
|
||||
}
|
||||
|
||||
if validation_workers == 0 {
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
"`validation_workers` must be > 0".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
|
||||
if max_batch_prefill_tokens > *max_batch_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
|
||||
}
|
||||
if max_total_tokens as u32 > *max_batch_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(max_batch_size) = max_batch_size {
|
||||
if max_batch_size == 0 {
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
"`max_batch_size` must be > 0".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let (backend, _backend_info) = connect_backend(
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
master_shard_uds_path,
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Run server
|
||||
server::run(
|
||||
backend,
|
||||
max_concurrent_requests,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
validation_workers,
|
||||
api_key,
|
||||
tokenizer_name,
|
||||
tokenizer_config_path,
|
||||
revision,
|
||||
hostname,
|
||||
port,
|
||||
cors_allow_origin,
|
||||
ngrok,
|
||||
ngrok_authtoken,
|
||||
ngrok_edge,
|
||||
messages_api_enabled,
|
||||
disable_grammar_support,
|
||||
max_client_batch_size,
|
||||
usage_stats,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
enum RouterError {
|
||||
#[error("Argument validation error: {0}")]
|
||||
ArgumentValidation(String),
|
||||
#[error("Backend failed: {0}")]
|
||||
Backend(#[from] V2Error),
|
||||
#[error("WebServer error: {0}")]
|
||||
WebServer(#[from] server::WebServerError),
|
||||
#[error("Tokio runtime failed to start: {0}")]
|
||||
Tokio(#[from] std::io::Error),
|
||||
}
|
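The router binary above validates its limits before connecting to the shards: max_input_tokens must stay below max_total_tokens, the prefill budget must fit the largest input, at least one validation worker is required, and an explicit max_batch_total_tokens must bound both the prefill budget and max_total_tokens. A sketch collecting those checks in one place (mirroring, not replacing, the code above):

fn validate_limits(
    max_input_tokens: usize,
    max_total_tokens: usize,
    max_batch_prefill_tokens: u32,
    max_batch_total_tokens: Option<u32>,
    validation_workers: usize,
) -> Result<(), String> {
    if max_input_tokens >= max_total_tokens {
        return Err("`max_input_tokens` must be < `max_total_tokens`".into());
    }
    if max_input_tokens as u32 > max_batch_prefill_tokens {
        return Err("`max_batch_prefill_tokens` must be >= `max_input_tokens`".into());
    }
    if validation_workers == 0 {
        return Err("`validation_workers` must be > 0".into());
    }
    if let Some(max_batch_total_tokens) = max_batch_total_tokens {
        if max_batch_prefill_tokens > max_batch_total_tokens
            || max_total_tokens as u32 > max_batch_total_tokens
        {
            return Err("`max_batch_total_tokens` must be >= prefill and total token limits".into());
        }
    }
    Ok(())
}

fn main() {
    // Defaults from the CLI above: 1024 input, 2048 total, 4096 prefill budget.
    assert!(validate_limits(1024, 2048, 4096, None, 2).is_ok());
    assert!(validate_limits(4096, 2048, 4096, None, 2).is_err());
}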
@ -1,15 +1,14 @@
|
||||
/// Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
|
||||
|
||||
use crate::infer::InferError;
|
||||
use crate::infer::InferStreamResponse;
|
||||
use crate::validation::ValidGenerateRequest;
|
||||
use crate::client::{
|
||||
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
};
|
||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||
use std::cmp::min;
|
||||
use std::cmp::{Eq, Ord, PartialEq, PartialOrd};
|
||||
use std::collections::BinaryHeap;
|
||||
use std::env;
|
||||
use std::time::Duration;
|
||||
use text_generation_client::{Batch, Request};
|
||||
use std::collections::VecDeque;
|
||||
use text_generation_router::infer::InferError;
|
||||
use text_generation_router::infer::InferStreamResponse;
|
||||
use text_generation_router::validation::{
|
||||
ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
|
||||
};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tokio::time::Instant;
|
||||
use tracing::{info_span, instrument, Span};
|
||||
@ -41,11 +40,11 @@ pub(crate) struct Queue {
|
||||
impl Queue {
|
||||
pub(crate) fn new(
|
||||
requires_padding: bool,
|
||||
max_input_length: u32,
|
||||
max_total_tokens: u32,
|
||||
block_size: u32,
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
max_input_tokens: u32,
|
||||
max_total_tokens: u32,
|
||||
) -> Self {
|
||||
// Create channel
|
||||
let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
|
||||
@ -53,18 +52,17 @@ impl Queue {
|
||||
// Launch background queue task
|
||||
tokio::spawn(queue_task(
|
||||
requires_padding,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
block_size,
|
||||
window_size,
|
||||
speculate,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
queue_receiver,
|
||||
));
|
||||
|
||||
Self { queue_sender }
|
||||
}
|
||||
|
||||
/// Append an entry to the queue
|
||||
#[instrument(skip_all)]
|
||||
pub(crate) fn append(&self, entry: Entry) {
|
||||
// Send append command to the background task managing the state
|
||||
@ -106,27 +104,27 @@ impl Queue {
|
||||
// Background task responsible of the queue state
|
||||
async fn queue_task(
|
||||
requires_padding: bool,
|
||||
max_input_length: u32,
|
||||
max_total_tokens: u32,
|
||||
block_size: u32,
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
max_input_tokens: u32,
|
||||
max_total_tokens: u32,
|
||||
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
|
||||
) {
|
||||
let mut state = State::new(
|
||||
requires_padding,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
block_size,
|
||||
window_size,
|
||||
speculate
|
||||
speculate,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
);
|
||||
|
||||
while let Some(cmd) = receiver.recv().await {
|
||||
match cmd {
|
||||
QueueCommand::Append(entry, span) => {
|
||||
span.in_scope(|| state.append(*entry));
|
||||
metrics::increment_gauge!("tgi_queue_size", 1.0);
|
||||
metrics::gauge!("tgi_queue_size").increment(1.0);
|
||||
}
|
||||
QueueCommand::NextBatch {
|
||||
min_size,
|
||||
@ -139,110 +137,17 @@ async fn queue_task(
|
||||
let next_batch =
|
||||
state.next_batch(min_size, max_size, prefill_token_budget, token_budget);
|
||||
response_sender.send(next_batch).unwrap();
|
||||
metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
|
||||
metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct IdentifiableEntry(u64, Entry);
|
||||
|
||||
impl Eq for IdentifiableEntry {}
|
||||
|
||||
impl PartialEq for IdentifiableEntry {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.0 == other.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Ord for IdentifiableEntry {
|
||||
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||
let ordering = match self
|
||||
.1
|
||||
.request
|
||||
.input_length
|
||||
.cmp(&other.1.request.input_length)
|
||||
{
|
||||
std::cmp::Ordering::Equal => self.0.cmp(&other.0),
|
||||
any => any,
|
||||
};
|
||||
|
||||
// inverse to get min heap
|
||||
return ordering.reverse();
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialOrd for IdentifiableEntry {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct QueueImpl {
|
||||
regular_entries: BinaryHeap<IdentifiableEntry>,
|
||||
overdue_entries: BinaryHeap<IdentifiableEntry>,
|
||||
overdue_threshold: Duration,
|
||||
}
|
||||
|
||||
impl QueueImpl {
|
||||
fn new(capacity: usize, overdue_threshold: Duration) -> Self {
|
||||
Self {
|
||||
regular_entries: BinaryHeap::with_capacity(capacity),
|
||||
overdue_entries: BinaryHeap::with_capacity(capacity),
|
||||
overdue_threshold,
|
||||
}
|
||||
}
|
||||
|
||||
fn update(&mut self) {
|
||||
if self.regular_entries.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut left = BinaryHeap::with_capacity(self.regular_entries.capacity());
|
||||
|
||||
for entry in self.regular_entries.drain() {
|
||||
if entry.1.queue_time.elapsed() > self.overdue_threshold {
|
||||
self.overdue_entries.push(entry);
|
||||
} else {
|
||||
left.push(entry);
|
||||
}
|
||||
}
|
||||
|
||||
self.regular_entries = left;
|
||||
}
|
||||
|
||||
fn push(&mut self, entry: IdentifiableEntry) {
|
||||
if entry.1.queue_time.elapsed() > self.overdue_threshold {
|
||||
self.overdue_entries.push(entry);
|
||||
} else {
|
||||
self.regular_entries.push(entry);
|
||||
}
|
||||
}
|
||||
|
||||
fn pop(&mut self) -> Option<IdentifiableEntry> {
|
||||
if !self.overdue_entries.is_empty() {
|
||||
self.overdue_entries.pop()
|
||||
} else {
|
||||
self.regular_entries.pop()
|
||||
}
|
||||
}
|
||||
|
||||
fn is_empty(&self) -> bool {
|
||||
self.regular_entries.is_empty() && self.overdue_entries.is_empty()
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.regular_entries.len() + self.overdue_entries.len()
|
||||
}
|
||||
}
|
||||
|
||||
/// Queue State
|
||||
#[derive(Debug)]
|
||||
struct State {
|
||||
/// Queue entries
|
||||
entries: QueueImpl,
|
||||
/// Queue entries organized in a Vec
|
||||
entries: VecDeque<(u64, Entry)>,
|
||||
|
||||
/// Id of the next entry
|
||||
next_id: u64,
|
||||
@ -253,12 +158,6 @@ struct State {
|
||||
/// Whether the model is using padding
|
||||
requires_padding: bool,
|
||||
|
||||
/// Maximum input length, required for padding scenario
|
||||
max_input_length: u32,
|
||||
|
||||
/// Maximum input and output length, required for padding scenario
|
||||
max_total_tokens: u32,
|
||||
|
||||
/// Paged Attention block size
|
||||
block_size: u32,
|
||||
|
||||
@ -267,33 +166,33 @@ struct State {
|
||||
|
||||
/// Speculation amount
|
||||
speculate: u32,
|
||||
|
||||
/// max input tokens
|
||||
max_input_tokens: u32,
|
||||
|
||||
/// max total tokens,
|
||||
max_total_tokens: u32,
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn new(
|
||||
requires_padding: bool,
|
||||
max_input_length: u32,
|
||||
max_total_tokens: u32,
|
||||
block_size: u32,
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
max_input_tokens: u32,
|
||||
max_total_tokens: u32,
|
||||
) -> Self {
|
||||
let default_threshold: u64 = 120;
|
||||
let threshold: u64 = match env::var("QUEUE_THRESHOLD_MS") {
|
||||
Ok(val) => val.parse().unwrap_or(default_threshold),
|
||||
Err(_) => default_threshold,
|
||||
};
|
||||
|
||||
Self {
|
||||
entries: QueueImpl::new(128, Duration::from_millis(threshold)),
|
||||
entries: VecDeque::with_capacity(128),
|
||||
next_id: 0,
|
||||
next_batch_id: 0,
|
||||
requires_padding,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
block_size,
|
||||
window_size,
|
||||
speculate,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
}
|
||||
}
|
||||
|
||||
@ -304,7 +203,7 @@ impl State {
|
||||
entry.temp_span = Some(queue_span);
|
||||
|
||||
// Push entry in the queue
|
||||
self.entries.push(IdentifiableEntry(self.next_id, entry));
|
||||
self.entries.push_back((self.next_id, entry));
|
||||
self.next_id += 1;
|
||||
}
|
||||
|
||||
@ -329,11 +228,20 @@ impl State {
|
||||
}
|
||||
}
|
||||
|
||||
self.entries.update();
|
||||
if let Some(max_size) = max_size {
|
||||
if max_size == 0 {
|
||||
tracing::debug!("No capacity");
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
// Pad prefill_token_budget to be a multiple of block size
|
||||
let prefill_token_budget =
|
||||
((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
|
||||
|
||||
// Create span for this batch to add context to inference calls
|
||||
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
|
||||
next_batch_span.follows_from(&Span::current());
|
||||
next_batch_span.follows_from(Span::current());
|
||||
|
||||
let mut batch_requests = Vec::with_capacity(self.entries.len());
|
||||
let mut batch_entries =
|
||||
@ -343,11 +251,11 @@ impl State {
|
||||
let mut decode_tokens: u32 = 0;
|
||||
|
||||
// Pop entries starting from the front of the queue
|
||||
while let Some(IdentifiableEntry(id, mut entry)) = self.entries.pop() {
|
||||
while let Some((id, mut entry)) = self.entries.pop_front() {
|
||||
// Filter entries where the response receiver was dropped (== entries where the request
|
||||
// was dropped by the client)
|
||||
if entry.response_tx.is_closed() {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
|
||||
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
|
||||
tracing::debug!("Dropping entry");
|
||||
continue;
|
||||
}
|
||||
@ -355,7 +263,7 @@ impl State {
|
||||
if self.requires_padding {
|
||||
// We pad to max input length in the Python shards
|
||||
// We need to take these padding tokens into the equation
|
||||
prefill_tokens = (batch_requests.len() + 1) as u32 * self.max_input_length;
|
||||
prefill_tokens = (batch_requests.len() + 1) as u32 * self.max_input_tokens
|
||||
} else {
|
||||
// pad to block size
|
||||
prefill_tokens += ((entry.request.input_length + self.block_size - 1)
|
||||
@ -364,9 +272,7 @@ impl State {
|
||||
}
|
||||
|
||||
if self.requires_padding {
|
||||
// We pad to max total tokens in the Python shards
|
||||
// We need to take these padding tokens into the equation
|
||||
decode_tokens = (batch_requests.len() + 1) as u32 * (self.max_total_tokens - self.max_input_length);
|
||||
decode_tokens = (batch_requests.len() + 1) as u32 * (self.max_total_tokens - self.max_input_tokens);
|
||||
} else {
|
||||
let max_new_tokens = match self.window_size {
|
||||
None => entry.request.stopping_parameters.max_new_tokens,
|
||||
@ -387,7 +293,7 @@ impl State {
|
||||
// Entry is over budget
|
||||
// Add it back to the front
|
||||
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
|
||||
self.entries.push(IdentifiableEntry(id, entry));
|
||||
self.entries.push_front((id, entry));
|
||||
break;
|
||||
}
|
||||
|
||||
@ -403,10 +309,14 @@ impl State {
|
||||
batch_requests.push(Request {
|
||||
id,
|
||||
prefill_logprobs: entry.request.decoder_input_details,
|
||||
inputs: entry.request.inputs.clone(),
|
||||
inputs: entry.request.inputs.chunks_to_string(),
|
||||
truncate: entry.request.truncate,
|
||||
parameters: Some(entry.request.parameters.clone()),
|
||||
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
|
||||
parameters: Some(NextTokenChooserParameters::from(
|
||||
entry.request.parameters.clone(),
|
||||
)),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters::from(
|
||||
entry.request.stopping_parameters.clone(),
|
||||
)),
|
||||
top_n_tokens: entry.request.top_n_tokens,
|
||||
});
|
||||
// Set batch_time
|
||||
@ -422,7 +332,7 @@ impl State {
|
||||
|
||||
// Empty batch
|
||||
if batch_requests.is_empty() {
|
||||
tracing::debug!("Filterered out all entries");
|
||||
tracing::debug!("Filtered out all entries");
|
||||
return None;
|
||||
}
|
||||
|
||||
@ -434,7 +344,7 @@ impl State {
|
||||
for r in batch_requests.into_iter().rev() {
|
||||
let id = r.id;
|
||||
let entry = batch_entries.remove(&id).unwrap();
|
||||
self.entries.push(IdentifiableEntry(id, entry));
|
||||
self.entries.push_front((id, entry));
|
||||
}
|
||||
|
||||
return None;
|
||||
@ -454,7 +364,7 @@ impl State {
|
||||
// Increment batch id
|
||||
self.next_batch_id += 1;
|
||||
|
||||
metrics::histogram!("tgi_batch_next_size", batch.size as f64);
|
||||
metrics::histogram!("tgi_batch_next_size").record(batch.size as f64);
|
||||
|
||||
Some((batch_entries, batch, next_batch_span))
|
||||
}
|
||||
@ -475,26 +385,49 @@ enum QueueCommand {
|
||||
},
|
||||
}
|
||||
|
||||
impl From<ValidParameters> for NextTokenChooserParameters {
|
||||
fn from(value: ValidParameters) -> Self {
|
||||
let (grammar, grammar_type) = match value.grammar {
|
||||
None => (String::new(), GrammarType::None),
|
||||
|
||||
Some(grammar) => match grammar {
|
||||
ValidGrammar::Json(grammar_string) => (grammar_string, GrammarType::Json),
|
||||
ValidGrammar::Regex(grammar_string) => (grammar_string, GrammarType::Regex),
|
||||
},
|
||||
};
|
||||
|
||||
Self {
|
||||
temperature: value.temperature,
|
||||
top_k: value.top_k,
|
||||
top_p: value.top_p,
|
||||
typical_p: value.typical_p,
|
||||
do_sample: value.do_sample,
|
||||
seed: value.seed,
|
||||
repetition_penalty: value.repetition_penalty,
|
||||
frequency_penalty: value.frequency_penalty,
|
||||
watermark: value.watermark,
|
||||
grammar,
|
||||
grammar_type: grammar_type.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
|
||||
fn from(value: ValidStoppingParameters) -> Self {
|
||||
Self {
|
||||
max_new_tokens: value.max_new_tokens,
|
||||
stop_sequences: value.stop_sequences,
|
||||
ignore_eos_token: value.ignore_eos_token,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use text_generation_client::{
|
||||
GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use tracing::info_span;
|
||||
|
||||
fn default_queue() -> Queue {
|
||||
Queue::new(
|
||||
true, 1, 2, 1, None, 0
|
||||
)
|
||||
}
|
||||
|
||||
fn default_state() -> State {
|
||||
State::new(
|
||||
true, 1, 2, 1, None, 0
|
||||
)
|
||||
}
|
||||
|
||||
fn default_entry() -> (
|
||||
Entry,
|
||||
mpsc::UnboundedReceiver<Result<InferStreamResponse, InferError>>,
|
||||
@ -503,11 +436,13 @@ mod tests {
|
||||
|
||||
let entry = Entry {
|
||||
request: ValidGenerateRequest {
|
||||
inputs: String::new(),
|
||||
inputs: vec![],
|
||||
input_ids: Some(Arc::new(vec![])),
|
||||
input_length: 0,
|
||||
add_special_tokens: true,
|
||||
truncate: 0,
|
||||
decoder_input_details: false,
|
||||
parameters: NextTokenChooserParameters {
|
||||
parameters: ValidParameters {
|
||||
temperature: 0.0,
|
||||
top_k: 0,
|
||||
top_p: 0.0,
|
||||
@ -517,15 +452,15 @@ mod tests {
|
||||
repetition_penalty: 0.0,
|
||||
frequency_penalty: 0.0,
|
||||
watermark: false,
|
||||
grammar: String::new(),
|
||||
grammar_type: ProtoGrammarType::None as i32,
|
||||
grammar: None,
|
||||
},
|
||||
stopping_parameters: StoppingCriteriaParameters {
|
||||
stopping_parameters: ValidStoppingParameters {
|
||||
ignore_eos_token: false,
|
||||
max_new_tokens: 1,
|
||||
stop_sequences: vec![],
|
||||
},
|
||||
top_n_tokens: 0,
|
||||
adapter_id: None,
|
||||
},
|
||||
response_tx,
|
||||
span: info_span!("entry"),
|
||||
@ -538,7 +473,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_append() {
|
||||
let mut state = default_state();
|
||||
let mut state = State::new(false, 1, None, 0);
|
||||
let (entry, _guard) = default_entry();
|
||||
|
||||
assert_eq!(state.next_id, 0);
|
||||
@ -548,13 +483,13 @@ mod tests {
|
||||
|
||||
assert_eq!(state.next_id, 1);
|
||||
assert_eq!(state.entries.len(), 1);
|
||||
let id = state.entries.pop().unwrap().0;
|
||||
let (id, _) = state.entries.remove(0).unwrap();
|
||||
assert_eq!(id, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_next_batch_empty() {
|
||||
let mut state = default_state();
|
||||
let mut state = State::new(false, 1, None, 0);
|
||||
|
||||
assert!(state.next_batch(None, None, 1, 1).is_none());
|
||||
assert!(state.next_batch(Some(1), None, 1, 1).is_none());
|
||||
@ -562,13 +497,13 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_next_batch_min_size() {
|
||||
let mut state = default_state();
|
||||
let mut state = State::new(false, 1, None, 0);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
state.append(entry2);
|
||||
|
||||
let (entries, batch, _) = state.next_batch(None, None, 2, 4).unwrap();
|
||||
let (entries, batch, _) = state.next_batch(None, None, 2, 2).unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert!(entries.contains_key(&1));
|
||||
@ -588,13 +523,13 @@ mod tests {
|
||||
|
||||
assert_eq!(state.next_id, 3);
|
||||
assert_eq!(state.entries.len(), 1);
|
||||
let IdentifiableEntry(id, _) = state.entries.pop().unwrap();
|
||||
let (id, _) = state.entries.remove(0).unwrap();
|
||||
assert_eq!(id, 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_next_batch_max_size() {
|
||||
let mut state = default_state();
|
||||
let mut state = State::new(false, 1, None, 0);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
@ -614,13 +549,13 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_next_batch_token_budget() {
|
||||
let mut state = default_state();
|
||||
let mut state = State::new(false, 1, None, 0);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
state.append(entry2);
|
||||
|
||||
let (entries, batch, _) = state.next_batch(None, None, 1, 2).unwrap();
|
||||
let (entries, batch, _) = state.next_batch(None, None, 1, 1).unwrap();
|
||||
assert_eq!(entries.len(), 1);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert_eq!(batch.id, 0);
|
||||
@ -633,7 +568,7 @@ mod tests {
|
||||
let (entry3, _guard3) = default_entry();
|
||||
state.append(entry3);
|
||||
|
||||
let (entries, batch, _) = state.next_batch(None, None, 3, 6).unwrap();
|
||||
let (entries, batch, _) = state.next_batch(None, None, 3, 3).unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&1));
|
||||
assert!(entries.contains_key(&2));
|
||||
@ -647,14 +582,14 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_append() {
|
||||
let queue = default_queue();
|
||||
let queue = Queue::new(false, 1, None, 0);
|
||||
let (entry, _guard) = default_entry();
|
||||
queue.append(entry);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_empty() {
|
||||
let queue = default_queue();
|
||||
let queue = Queue::new(false, 1, None, 0);
|
||||
|
||||
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
|
||||
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
|
||||
@ -662,13 +597,13 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_min_size() {
|
||||
let queue = default_queue();
|
||||
let queue = Queue::new(false, 1, None, 0);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
queue.append(entry2);
|
||||
|
||||
let (entries, batch, _) = queue.next_batch(None, None, 2, 4).await.unwrap();
|
||||
let (entries, batch, _) = queue.next_batch(None, None, 2, 2).await.unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert!(entries.contains_key(&1));
|
||||
@ -685,7 +620,7 @@ mod tests {
|
||||
// Not enough token budget
|
||||
assert!(queue.next_batch(Some(1), None, 0, 0).await.is_none());
|
||||
// Ok
|
||||
let (entries2, batch2, _) = queue.next_batch(Some(1), None, 2, 4).await.unwrap();
|
||||
let (entries2, batch2, _) = queue.next_batch(Some(1), None, 2, 2).await.unwrap();
|
||||
assert_eq!(entries2.len(), 1);
|
||||
assert!(entries2.contains_key(&2));
|
||||
assert!(entries2.get(&2).unwrap().batch_time.is_some());
|
||||
@ -695,7 +630,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_max_size() {
|
||||
let queue = default_queue();
|
||||
let queue = Queue::new(false, 1, None, 0);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
@ -711,13 +646,13 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_token_budget() {
|
||||
let queue = default_queue();
|
||||
let queue = Queue::new(false, 1, None, 0);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
queue.append(entry2);
|
||||
|
||||
let (entries, batch, _) = queue.next_batch(None, None, 1, 2).await.unwrap();
|
||||
let (entries, batch, _) = queue.next_batch(None, None, 1, 1).await.unwrap();
|
||||
assert_eq!(entries.len(), 1);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert_eq!(batch.id, 0);
|
||||
@ -726,7 +661,7 @@ mod tests {
|
||||
let (entry3, _guard3) = default_entry();
|
||||
queue.append(entry3);
|
||||
|
||||
let (entries, batch, _) = queue.next_batch(None, None, 3, 6).await.unwrap();
|
||||
let (entries, batch, _) = queue.next_batch(None, None, 3, 3).await.unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&1));
|
||||
assert!(entries.contains_key(&2));
|
||||
@ -736,7 +671,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_token_speculate() {
|
||||
let queue = Queue::new(true, 1, 2, 1, None, 2);
|
||||
let queue = Queue::new(false, 1, None, 2);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
@ -755,7 +690,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_dropped_receiver() {
|
||||
let queue = default_queue();
|
||||
let queue = Queue::new(false, 1, None, 0);
|
||||
let (entry, _) = default_entry();
|
||||
queue.append(entry);
|
||||
|
83
backends/v3/Cargo.toml
Normal file
@ -0,0 +1,83 @@
|
||||
[package]
|
||||
name = "text-generation-router-v3"
|
||||
description = "Text Generation Webserver"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
homepage.workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/lib.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "text-generation-router"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
async-trait = "0.1.74"
|
||||
async-stream = "0.3.5"
|
||||
axum = { version = "0.7", features = ["json"] }
|
||||
axum-tracing-opentelemetry = "0.16"
|
||||
text-generation-router = { path = "../../router" }
|
||||
clap = { version = "4.4.5", features = ["derive", "env"] }
|
||||
grpc-metadata = { path = "../grpc-metadata" }
|
||||
futures = "0.3.28"
|
||||
hf-hub = { workspace = true }
|
||||
jsonschema = { version = "0.17.1", features = ["draft202012"] }
|
||||
metrics = { workspace = true }
|
||||
metrics-exporter-prometheus = { workspace = true }
|
||||
nohash-hasher = "0.2.0"
|
||||
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
|
||||
opentelemetry-otlp = "0.13.0"
|
||||
rand = "0.8.5"
|
||||
reqwest = { version = "0.11.20", features = [] }
|
||||
serde = "1.0.188"
|
||||
serde_json = "1.0.107"
|
||||
slotmap = "1.0.7"
|
||||
thiserror = "1.0.48"
|
||||
tokenizers = { workspace = true }
|
||||
tokio = { version = "1.32.0", features = [
|
||||
"rt",
|
||||
"rt-multi-thread",
|
||||
"parking_lot",
|
||||
"signal",
|
||||
"sync",
|
||||
] }
|
||||
tokio-stream = "0.1.14"
|
||||
tower-http = { version = "0.5.1", features = ["cors"] }
|
||||
tracing = "0.1.37"
|
||||
tracing-opentelemetry = "0.21.0"
|
||||
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|
||||
utoipa = { version = "4.2.0", features = ["axum_extras"] }
|
||||
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
|
||||
init-tracing-opentelemetry = { version = "0.14.1", features = [
|
||||
"opentelemetry-otlp",
|
||||
] }
|
||||
minijinja = { workspace = true }
|
||||
minijinja-contrib = { workspace = true }
|
||||
futures-util = "0.3.30"
|
||||
regex = "1.10.3"
|
||||
once_cell = "1.19.0"
|
||||
image = "0.25.1"
|
||||
base64 = { workspace = true }
|
||||
prost = "^0.12"
|
||||
tonic = "^0.10"
|
||||
tower = "^0.4"
|
||||
|
||||
[build-dependencies]
|
||||
tonic-build = "0.10.1"
|
||||
prost-build = "0.12.1"
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = "0.3"
|
||||
itertools = "0.13"
|
||||
|
||||
[features]
|
||||
default = ["ngrok"]
|
||||
ngrok = ["text-generation-router/ngrok"]
|
||||
google = ["text-generation-router/google"]
|
||||
kserve = ["text-generation-router/kserve"]
|
||||
|
||||
[[bench]]
|
||||
name = "prefix_cache"
|
||||
harness = false
|
47
backends/v3/benches/prefix_cache.rs
Normal file
@ -0,0 +1,47 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use criterion::{black_box, criterion_group, criterion_main, Criterion};
|
||||
use rand::Rng;
|
||||
|
||||
use text_generation_router_v3::block_allocator::Allocator;
|
||||
use text_generation_router_v3::radix::RadixAllocator;
|
||||
|
||||
fn prefix_cache_benchmark(c: &mut Criterion) {
|
||||
// let prefixes: Vec<Vec<u32>> = (0..8192)
|
||||
// .chunks(256)
|
||||
// .into_iter()
|
||||
// .map(|c| c.collect())
|
||||
// .collect();
|
||||
|
||||
let mut cache = RadixAllocator::new(1, 262144, None);
|
||||
|
||||
c.bench_function("Radix allocator", |b| {
|
||||
b.iter_batched(
|
||||
|| {
|
||||
//prefixes
|
||||
// .choose_multiple(&mut rand::thread_rng(), 5)
|
||||
// .fold(Vec::new(), |mut v, s| {
|
||||
// v.extend(s);
|
||||
// v
|
||||
// })
|
||||
|
||||
(0..7936)
|
||||
.map(|_| rand::thread_rng().gen_range(0..1024))
|
||||
.collect::<Vec<u32>>()
|
||||
},
|
||||
|prefill| {
|
||||
let alloc = cache.allocate(
|
||||
prefill.len() as u32 + 13,
|
||||
Some(Arc::new(black_box(prefill))),
|
||||
);
|
||||
if let Some(alloc) = alloc {
|
||||
cache.free(alloc.blocks.clone(), alloc.allocation_id);
|
||||
}
|
||||
},
|
||||
criterion::BatchSize::SmallInput,
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
criterion_group!(benches, prefix_cache_benchmark);
|
||||
criterion_main!(benches);
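Since the Cargo.toml above declares this bench with `harness = false`, criterion supplies its own `main`, and the benchmark is run with `cargo bench --bench prefix_cache`. Outside of criterion, the same allocate-then-free round trip can be exercised directly; a hedged sketch reusing the public API imported above:

use std::sync::Arc;

use text_generation_router_v3::block_allocator::Allocator;
use text_generation_router_v3::radix::RadixAllocator;

fn main() {
    // Same configuration as the benchmark: block size 1, 262144 blocks, no window.
    let mut cache = RadixAllocator::new(1, 262144, None);
    let prefill: Vec<u32> = (0..128).collect();
    // Ask for a few extra decode slots on top of the prefill, as the benchmark does.
    if let Some(alloc) = cache.allocate(prefill.len() as u32 + 13, Some(Arc::new(prefill))) {
        cache.free(alloc.blocks.clone(), alloc.allocation_id);
    }
}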
|
19
backends/v3/build.rs
Normal file
@ -0,0 +1,19 @@
|
||||
use std::fs;
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
println!("cargo:rerun-if-changed=../../proto/");
|
||||
|
||||
fs::create_dir_all("src/client/pb").unwrap_or(());
|
||||
let mut config = prost_build::Config::new();
|
||||
config.protoc_arg("--experimental_allow_proto3_optional");
|
||||
|
||||
tonic_build::configure()
|
||||
.build_client(true)
|
||||
.build_server(false)
|
||||
.out_dir("src/client/pb")
|
||||
.include_file("mod.rs")
|
||||
.compile_with_config(config, &["../../proto/v3/generate.proto"], &["../../proto"])
|
||||
.unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
|
||||
|
||||
Ok(())
|
||||
}
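The build script regenerates the tonic/prost bindings into `src/client/pb` whenever the proto definitions change; the generated module is then surfaced through a plain module declaration in the client code, along the lines of what appears later in this diff:

// backends/v3/src/client/mod.rs (shown further down)
#[allow(clippy::derive_partial_eq_without_eq)]
mod pb;
// Generated types are then reachable under pb::generate::v3, e.g. in grpc_client.rs:
// use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;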
|
518
backends/v3/src/backend.rs
Normal file
@ -0,0 +1,518 @@
|
||||
use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};
|
||||
/// Batching and inference logic
|
||||
use crate::queue::{Entry, Queue};
|
||||
use async_trait::async_trait;
|
||||
use nohash_hasher::IntMap;
|
||||
use std::sync::Arc;
|
||||
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
||||
use text_generation_router::validation::ValidGenerateRequest;
|
||||
use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
|
||||
use tokio::sync::mpsc::error::SendError;
|
||||
use tokio::sync::{mpsc, Notify};
|
||||
use tokio::time::Instant;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::{info_span, instrument, Instrument, Span};
|
||||
|
||||
pub struct BackendV3 {
|
||||
/// Request queue
|
||||
queue: Queue,
|
||||
/// Notify batcher on queue appends
|
||||
batching_task_notifier: Arc<Notify>,
|
||||
/// Client clone, used for health checks to skip the queue
|
||||
client: ShardedClient,
|
||||
}
|
||||
|
||||
impl BackendV3 {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn new(
|
||||
client: ShardedClient,
|
||||
waiting_served_ratio: f32,
|
||||
max_input_tokens: u32,
|
||||
max_total_tokens: u32,
|
||||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
requires_padding: bool,
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
) -> Self {
|
||||
let prefix_caching =
|
||||
std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var");
|
||||
let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
|
||||
let attention: String = std::env::var("ATTENTION").expect("attention env var");
|
||||
|
||||
let attention: Attention = attention
|
||||
.parse()
|
||||
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"));
|
||||
let block_size = attention.block_size();
|
||||
|
||||
let queue = Queue::new(
|
||||
requires_padding,
|
||||
block_size,
|
||||
prefix_caching,
|
||||
window_size,
|
||||
speculate,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
max_batch_total_tokens,
|
||||
);
|
||||
let batching_task_notifier = Arc::new(Notify::new());
|
||||
|
||||
// Spawn batching background task that contains all the inference logic
|
||||
tokio::spawn(batching_task(
|
||||
client.clone(),
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
queue.clone(),
|
||||
batching_task_notifier.clone(),
|
||||
));
|
||||
|
||||
Self {
|
||||
queue,
|
||||
batching_task_notifier,
|
||||
client,
|
||||
}
|
||||
}
|
||||
}
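Note that `BackendV3::new` hard-requires the `USE_PREFIX_CACHING` and `ATTENTION` environment variables (both lookups use `expect`), so they must be exported before the backend is constructed, normally by the launcher. A hedged sketch of what a caller would set up; the attention value is illustrative and must be a string accepted by `Attention`'s `FromStr` impl:

fn main() {
    // Example values, not the only accepted ones.
    std::env::set_var("USE_PREFIX_CACHING", "1"); // parsed with matches!(.., "true" | "1")
    std::env::set_var("ATTENTION", "paged");      // must parse into text_generation_router::Attention
    // ... BackendV3::new(...) can now read both variables without panicking.
}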
|
||||
|
||||
#[async_trait]
|
||||
impl Backend for BackendV3 {
|
||||
#[instrument(skip_all)]
|
||||
fn schedule(
|
||||
&self,
|
||||
request: ValidGenerateRequest,
|
||||
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
|
||||
// MPSC channel to communicate with the background batching task
|
||||
let (response_tx, response_rx) = mpsc::unbounded_channel();
|
||||
|
||||
// Append the request to the queue
|
||||
self.queue.append(Entry {
|
||||
request,
|
||||
response_tx,
|
||||
span: Span::current(),
|
||||
temp_span: None,
|
||||
queue_time: Instant::now(),
|
||||
batch_time: None,
|
||||
block_allocation: None,
|
||||
});
|
||||
|
||||
// Notify the background task that we have a new entry in the queue that needs
|
||||
// to be batched
|
||||
self.batching_task_notifier.notify_one();
|
||||
|
||||
// Return stream
|
||||
Ok(UnboundedReceiverStream::new(response_rx))
|
||||
}
|
||||
|
||||
async fn health(&self, current_health: bool) -> bool {
|
||||
if current_health {
|
||||
// Generation is healthy, we only check that the shards can allocate on device
|
||||
self.client.device_health().await
|
||||
} else {
|
||||
self.client.model_health().await
|
||||
}
|
||||
.is_ok()
|
||||
}
|
||||
}
|
||||
|
||||
/// Batching logic
|
||||
/// Will be launched in a background Tokio task
|
||||
///
|
||||
/// Batches requests and sends them to the inference server
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn batching_task(
|
||||
mut client: ShardedClient,
|
||||
waiting_served_ratio: f32,
|
||||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
queue: Queue,
|
||||
notifier: Arc<Notify>,
|
||||
) {
|
||||
// Infinite loop
|
||||
loop {
|
||||
// Wait for a notification from the Infer struct
|
||||
notifier.notified().await;
|
||||
|
||||
// Get the next batch from the queue
|
||||
// This batch might be smaller than the maximum batch size if there are not enough requests
|
||||
// waiting in the queue
|
||||
while let Some((mut entries, batch, span)) = queue
|
||||
.next_batch(
|
||||
None,
|
||||
max_batch_size,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
)
|
||||
.await
|
||||
{
|
||||
let mut cached_batch = prefill(&mut client, batch, &mut entries)
|
||||
.instrument(span)
|
||||
.await;
|
||||
let mut waiting_tokens = 1;
|
||||
|
||||
// We loop until we do not receive any cached batch from the inference server (== until
|
||||
// all requests have met their stopping criteria)
|
||||
while let Some(batch) = cached_batch {
|
||||
// Get current batch info
|
||||
let batch_size = batch.size;
|
||||
let batch_max_tokens = batch.max_tokens;
|
||||
let mut batches = vec![batch];
|
||||
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
|
||||
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
|
||||
|
||||
let min_size = if waiting_tokens >= max_waiting_tokens {
|
||||
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
|
||||
// to add a new batch even though its size might be small
|
||||
None
|
||||
} else {
|
||||
// Minimum batch size
|
||||
// TODO: temporarily disable to avoid incorrect deallocation +
|
||||
// reallocation when using prefix caching.
|
||||
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
|
||||
};
|
||||
|
||||
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
||||
let max_size =
|
||||
max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
|
||||
|
||||
// Try to get a new batch
|
||||
if let Some((mut new_entries, new_batch, span)) = queue
|
||||
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
|
||||
.await
|
||||
{
|
||||
// Tracking metrics
|
||||
if min_size.is_some() {
|
||||
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
|
||||
.increment(1);
|
||||
} else {
|
||||
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
|
||||
.increment(1);
|
||||
}
|
||||
|
||||
entries.iter_mut().for_each(|(_, entry)| {
|
||||
// Create a new span to add the info that this entry is waiting
|
||||
// because a new batch is being computed
|
||||
let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
|
||||
// Add relationships
|
||||
span.follows_from(&entry_waiting_span);
|
||||
entry_waiting_span.follows_from(&span);
|
||||
// Update entry
|
||||
entry.temp_span = Some(entry_waiting_span);
|
||||
});
|
||||
|
||||
// Generate one token for this new batch to have the attention past in cache
|
||||
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
|
||||
.instrument(span)
|
||||
.await;
|
||||
// Reset waiting counter
|
||||
waiting_tokens = 1;
|
||||
// Extend current batch with the new batch
|
||||
if let Some(new_cached_batch) = new_cached_batch {
|
||||
entries.extend(new_entries);
|
||||
batches.push(new_cached_batch);
|
||||
}
|
||||
}
|
||||
|
||||
// Create span for this batch to add context to inference calls
|
||||
let next_batch_size = entries.len();
|
||||
let next_batch_span =
|
||||
info_span!(parent: None, "batch", batch_size = next_batch_size);
|
||||
entries.iter_mut().for_each(|(_, entry)| {
|
||||
// Create a new span to link the batch back to this entry
|
||||
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
||||
// Add relationships
|
||||
next_batch_span.follows_from(&entry_batch_span);
|
||||
entry_batch_span.follows_from(&next_batch_span);
|
||||
// Update entry
|
||||
entry.temp_span = Some(entry_batch_span);
|
||||
});
|
||||
|
||||
cached_batch = decode(&mut client, batches, &mut entries)
|
||||
.instrument(next_batch_span)
|
||||
.await;
|
||||
waiting_tokens += 1;
|
||||
}
|
||||
metrics::gauge!("tgi_batch_current_size").set(0.0);
|
||||
metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
|
||||
}
|
||||
}
|
||||
}
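To make the back-pressure arithmetic above concrete, here is a hedged, self-contained sketch of the `min_size` / `token_budget` computation with illustrative numbers (the real values come from the launcher flags):

fn main() {
    let waiting_served_ratio: f32 = 1.2;
    let max_waiting_tokens: usize = 20;
    let max_batch_total_tokens: u32 = 32_768;

    let batch_size: u32 = 32;        // current running batch
    let batch_max_tokens: u32 = 8_192;
    let waiting_tokens: usize = 5;   // decode steps since the last prefill

    let min_size = if waiting_tokens >= max_waiting_tokens {
        None // waited long enough: accept any new batch, however small
    } else {
        Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
    };
    let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);

    assert_eq!(min_size, Some(38));   // a new batch must bring at least 38 requests
    assert_eq!(token_budget, 24_576); // tokens left over for that new batch
}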
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn prefill(
|
||||
client: &mut ShardedClient,
|
||||
batch: Batch,
|
||||
entries: &mut IntMap<u64, Entry>,
|
||||
) -> Option<CachedBatch> {
|
||||
let start_time = Instant::now();
|
||||
let batch_id = batch.id;
|
||||
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
|
||||
|
||||
match client.prefill(batch).await {
|
||||
Ok((generations, next_batch, timings)) => {
|
||||
let start_filtering_time = Instant::now();
|
||||
// Send generated tokens and filter stopped entries
|
||||
filter_send_generations(generations, entries);
|
||||
|
||||
// Filter next batch and remove requests that were stopped
|
||||
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||
|
||||
metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill")
|
||||
.record(timings.forward.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
|
||||
.record(timings.decode.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
|
||||
.record(start_filtering_time.elapsed().as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill")
|
||||
.record(start_time.elapsed().as_secs_f64());
|
||||
metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
|
||||
next_batch
|
||||
}
|
||||
// If we have an error, we discard the whole batch
|
||||
Err(err) => {
|
||||
let _ = client.clear_cache(Some(batch_id)).await;
|
||||
send_errors(err, entries);
|
||||
metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn decode(
|
||||
client: &mut ShardedClient,
|
||||
batches: Vec<CachedBatch>,
|
||||
entries: &mut IntMap<u64, Entry>,
|
||||
) -> Option<CachedBatch> {
|
||||
let start_time = Instant::now();
|
||||
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
|
||||
metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
|
||||
|
||||
match client.decode(batches).await {
|
||||
Ok((generations, next_batch, timings)) => {
|
||||
let start_filtering_time = Instant::now();
|
||||
// Send generated tokens and filter stopped entries
|
||||
filter_send_generations(generations, entries);
|
||||
|
||||
// Filter next batch and remove requests that were stopped
|
||||
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||
|
||||
if let Some(concat_duration) = timings.concat {
|
||||
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
|
||||
.record(concat_duration.as_secs_f64());
|
||||
}
|
||||
metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
|
||||
.record(timings.forward.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
|
||||
.record(timings.decode.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
|
||||
.record(start_filtering_time.elapsed().as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
|
||||
.record(start_time.elapsed().as_secs_f64());
|
||||
metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
|
||||
next_batch
|
||||
}
|
||||
// If we have an error, we discard the whole batch
|
||||
Err(err) => {
|
||||
for id in batch_ids {
|
||||
let _ = client.clear_cache(Some(id)).await;
|
||||
}
|
||||
send_errors(err, entries);
|
||||
metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Filter a `batch` and remove all requests not present in `entries`
|
||||
#[instrument(skip_all)]
|
||||
async fn filter_batch(
|
||||
client: &mut ShardedClient,
|
||||
next_batch: Option<CachedBatch>,
|
||||
entries: &IntMap<u64, Entry>,
|
||||
) -> Option<CachedBatch> {
|
||||
let mut batch = next_batch?;
|
||||
|
||||
// No need to filter
|
||||
if batch.size as usize == entries.len() {
|
||||
return Some(batch);
|
||||
}
|
||||
|
||||
let id = batch.id;
|
||||
|
||||
// Retain only requests that are still in entries
|
||||
batch.request_ids.retain(|id| entries.contains_key(id));
|
||||
|
||||
if batch.request_ids.is_empty() {
|
||||
// All requests have been filtered out
|
||||
// Next batch is now empty
|
||||
// Clear it from the Python shards cache
|
||||
// We unwrap here as we need to panic since we cannot recover if this method fails
|
||||
client.clear_cache(Some(id)).await.unwrap();
|
||||
None
|
||||
} else {
|
||||
// Filter Python shard cache
|
||||
// We unwrap here as we need to panic since we cannot recover if this method fails
|
||||
client.filter_batch(id, batch.request_ids).await.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
|
||||
/// and filter entries
|
||||
#[instrument(skip_all)]
|
||||
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
|
||||
generations.into_iter().for_each(|generation| {
|
||||
let id = generation.request_id;
|
||||
// Get entry
|
||||
// We can `expect` here as the request id should always be in the entries
|
||||
let entry = entries
|
||||
.get(&id)
|
||||
.expect("ID not found in entries. This is a bug.");
|
||||
|
||||
// Create and enter a span to link this function back to the entry
|
||||
let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
|
||||
// Send generation responses back to the infer task
|
||||
// If we receive an error from the response channel, it means that the client dropped the
|
||||
// request and we need to stop generating hence why we unwrap_or(true)
|
||||
let stopped = send_responses(generation, entry).inspect_err(|_err| {
|
||||
tracing::error!("Entry response channel error.");
|
||||
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
|
||||
}).unwrap_or(true);
|
||||
if stopped {
|
||||
entries.remove(&id).expect("ID not found in entries. This is a bug.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Send responses through the `entry` response channel
|
||||
fn send_responses(
|
||||
generation: Generation,
|
||||
entry: &Entry,
|
||||
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
|
||||
// Return directly if the channel is disconnected
|
||||
if entry.response_tx.is_closed() {
|
||||
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
let mut stopped = false;
|
||||
|
||||
if let Some(prefill_tokens) = generation.prefill_tokens {
|
||||
// Create Token objects
|
||||
// We do that here instead of in the Python code as Rust for loops are faster
|
||||
let prefill_tokens = prefill_tokens
|
||||
.ids
|
||||
.into_iter()
|
||||
.zip(prefill_tokens.logprobs)
|
||||
.zip(prefill_tokens.texts)
|
||||
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
|
||||
.collect();
|
||||
|
||||
// Send message
|
||||
entry
|
||||
.response_tx
|
||||
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
|
||||
}
|
||||
|
||||
// Create last Token
|
||||
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
|
||||
let n = tokens_.ids.len();
|
||||
metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
|
||||
let mut iterator = tokens_
|
||||
.ids
|
||||
.into_iter()
|
||||
.zip(tokens_.logprobs)
|
||||
.zip(tokens_.texts)
|
||||
.zip(tokens_.is_special)
|
||||
.enumerate()
|
||||
.peekable();
|
||||
while let Some((i, (((id, logprob), text), special))) = iterator.next() {
|
||||
let token = Token {
|
||||
id,
|
||||
text,
|
||||
logprob,
|
||||
special,
|
||||
};
|
||||
let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
|
||||
top_tokens_
|
||||
.ids
|
||||
.iter()
|
||||
.zip(top_tokens_.logprobs.iter())
|
||||
.zip(top_tokens_.texts.iter())
|
||||
.zip(top_tokens_.is_special.iter())
|
||||
.map(|(((&id, &logprob), text), &special)| Token {
|
||||
id,
|
||||
text: text.to_string(),
|
||||
logprob,
|
||||
special,
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
match (&generation.generated_text, iterator.peek()) {
|
||||
(Some(generated_text), None) => {
|
||||
// Generation has ended
|
||||
stopped = true;
|
||||
// Send message
|
||||
entry.response_tx.send(Ok(InferStreamResponse::End {
|
||||
token,
|
||||
top_tokens,
|
||||
generated_text: GeneratedText::from(generated_text.clone()),
|
||||
queued: entry.queue_time,
|
||||
start: entry.batch_time.unwrap(),
|
||||
}))?;
|
||||
}
|
||||
_ => {
|
||||
// Send message
|
||||
entry
|
||||
.response_tx
|
||||
.send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(stopped)
|
||||
}
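The `peekable()` iterator above is how the final generated token is detected: only when `iterator.peek()` returns `None` alongside a finished `generated_text` is `InferStreamResponse::End` sent; every other token goes out as `Intermediate`. A stand-alone illustration of that last-element check:

fn main() {
    let tokens = vec!["Hello", ",", " world"];
    let mut it = tokens.into_iter().enumerate().peekable();
    while let Some((i, tok)) = it.next() {
        let is_last = it.peek().is_none();
        // Only the final token would be sent as `End`; the rest as `Intermediate`.
        println!("{i}: {tok} (last: {is_last})");
    }
}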
|
||||
|
||||
/// Send errors to Infer for all `entries`
|
||||
#[instrument(skip_all)]
|
||||
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
||||
entries.drain().for_each(|(_, entry)| {
|
||||
// Create and enter a span to link this function back to the entry
|
||||
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
|
||||
let err = InferError::GenerationError(error.to_string());
|
||||
metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
|
||||
tracing::error!("{err}");
|
||||
|
||||
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||
entry
|
||||
.response_tx
|
||||
.send(Err(err))
|
||||
.unwrap_or(());
|
||||
});
|
||||
}
|
||||
|
||||
impl From<crate::client::GeneratedText> for GeneratedText {
|
||||
fn from(value: crate::client::GeneratedText) -> Self {
|
||||
let v3_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap();
|
||||
let finish_reason = match v3_finish_reason {
|
||||
crate::client::FinishReason::Length => FinishReason::Length,
|
||||
crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
|
||||
crate::client::FinishReason::StopSequence => FinishReason::StopSequence,
|
||||
};
|
||||
|
||||
Self {
|
||||
text: value.text,
|
||||
generated_tokens: value.generated_tokens,
|
||||
finish_reason,
|
||||
seed: value.seed,
|
||||
}
|
||||
}
|
||||
}
|
209
backends/v3/src/block_allocator.rs
Normal file
@ -0,0 +1,209 @@
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
|
||||
use crate::radix::RadixAllocator;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BlockAllocation {
|
||||
pub allocation_id: u64,
|
||||
pub blocks: Vec<u32>,
|
||||
pub slots: Vec<u32>,
|
||||
|
||||
/// Prefix that was cached and for which the KV does not have to
|
||||
/// be recomputed.
|
||||
pub prefix_len: u32,
|
||||
|
||||
pub(crate) block_allocator: Option<BlockAllocator>,
|
||||
}
|
||||
|
||||
impl Drop for BlockAllocation {
|
||||
fn drop(&mut self) {
|
||||
if let Some(block_allocator) = self.block_allocator.as_mut() {
|
||||
block_allocator.free(self.blocks.clone(), self.allocation_id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BlockAllocator {
|
||||
/// Channel to communicate with the background task
|
||||
block_allocator: mpsc::UnboundedSender<BlockAllocatorCommand>,
|
||||
}
|
||||
|
||||
impl BlockAllocator {
|
||||
pub(crate) fn new(
|
||||
max_batch_total_tokens: u32,
|
||||
block_size: u32,
|
||||
prefix_caching: bool,
|
||||
window_size: Option<u32>,
|
||||
) -> Self {
|
||||
// Create channel
|
||||
let (sender, receiver) = mpsc::unbounded_channel();
|
||||
|
||||
// Launch background queue task
|
||||
tokio::spawn(block_allocator_task(
|
||||
max_batch_total_tokens / block_size,
|
||||
block_size,
|
||||
prefix_caching,
|
||||
window_size,
|
||||
receiver,
|
||||
));
|
||||
|
||||
Self {
|
||||
block_allocator: sender,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn allocate(
|
||||
&self,
|
||||
tokens: u32,
|
||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
) -> Option<BlockAllocation> {
|
||||
let (response_sender, response_receiver) = oneshot::channel();
|
||||
self.block_allocator
|
||||
.send(BlockAllocatorCommand::Allocate {
|
||||
tokens,
|
||||
prefill_tokens,
|
||||
response_sender,
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
response_receiver.await.unwrap().map(|mut allocation| {
|
||||
allocation.block_allocator = Some(self.clone());
|
||||
allocation
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn free(&self, blocks: Vec<u32>, allocation_id: u64) {
|
||||
self.block_allocator
|
||||
.send(BlockAllocatorCommand::Free {
|
||||
allocation_id,
|
||||
blocks,
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
}
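Because a `BlockAllocation` keeps a clone of its `BlockAllocator` and frees its blocks in `Drop`, callers normally never invoke `free` by hand; releasing the allocation is enough. A hedged, crate-internal usage sketch (the helper name and token count are hypothetical):

// Hypothetical helper inside this crate; `allocator` would come from
// BlockAllocator::new(max_batch_total_tokens, block_size, prefix_caching, window_size).
async fn reserve_and_release(allocator: &BlockAllocator) {
    if let Some(allocation) = allocator.allocate(128, None).await {
        // Hand allocation.blocks / allocation.slots to the request being scheduled.
        let _slots = allocation.slots.len();
    } // `allocation` is dropped here: Drop sends BlockAllocatorCommand::Free for us
}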
|
||||
|
||||
async fn block_allocator_task(
|
||||
blocks: u32,
|
||||
block_size: u32,
|
||||
prefix_caching: bool,
|
||||
window_size: Option<u32>,
|
||||
mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
|
||||
) {
|
||||
let mut allocator: Box<dyn Allocator + Send> = if prefix_caching {
|
||||
Box::new(RadixAllocator::new(block_size, blocks, window_size))
|
||||
} else {
|
||||
Box::new(SimpleAllocator::new(blocks, block_size, window_size))
|
||||
};
|
||||
while let Some(cmd) = receiver.recv().await {
|
||||
match cmd {
|
||||
BlockAllocatorCommand::Free {
|
||||
blocks,
|
||||
allocation_id,
|
||||
} => allocator.free(blocks, allocation_id),
|
||||
BlockAllocatorCommand::Allocate {
|
||||
tokens,
|
||||
prefill_tokens,
|
||||
response_sender,
|
||||
} => {
|
||||
response_sender
|
||||
.send(allocator.allocate(tokens, prefill_tokens))
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum BlockAllocatorCommand {
|
||||
Free {
|
||||
blocks: Vec<u32>,
|
||||
allocation_id: u64,
|
||||
},
|
||||
Allocate {
|
||||
tokens: u32,
|
||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
response_sender: oneshot::Sender<Option<BlockAllocation>>,
|
||||
},
|
||||
}
|
||||
|
||||
pub trait Allocator {
|
||||
fn allocate(
|
||||
&mut self,
|
||||
tokens: u32,
|
||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
) -> Option<BlockAllocation>;
|
||||
|
||||
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
|
||||
}
|
||||
pub struct SimpleAllocator {
|
||||
free_blocks: Vec<u32>,
|
||||
block_size: u32,
|
||||
window_size: Option<u32>,
|
||||
}
|
||||
|
||||
impl SimpleAllocator {
|
||||
fn new(blocks: u32, block_size: u32, window_size: Option<u32>) -> Self {
|
||||
SimpleAllocator {
|
||||
block_size,
|
||||
// Block 0 is reserved for health checks
|
||||
free_blocks: (1..blocks).collect(),
|
||||
window_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Allocator for SimpleAllocator {
|
||||
fn allocate(
|
||||
&mut self,
|
||||
tokens: u32,
|
||||
_prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
) -> Option<BlockAllocation> {
|
||||
// Apply window size
|
||||
let (required_blocks, repeats) = {
|
||||
let (tokens, repeats) = match self.window_size {
|
||||
None => (tokens, 1),
|
||||
Some(window_size) => {
|
||||
let repeats = (tokens + window_size - 1) / window_size;
|
||||
let tokens = core::cmp::min(tokens, window_size);
|
||||
(tokens, repeats as usize)
|
||||
}
|
||||
};
|
||||
// Pad to a multiple of block size
|
||||
let required_blocks = (tokens + self.block_size - 1) / self.block_size;
|
||||
(required_blocks, repeats)
|
||||
};
|
||||
|
||||
let tokens = tokens as usize;
|
||||
if required_blocks > self.free_blocks.len() as u32 {
|
||||
None
|
||||
} else {
|
||||
let blocks = self
|
||||
.free_blocks
|
||||
.split_off(self.free_blocks.len() - required_blocks as usize);
|
||||
let mut slots =
|
||||
Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);
|
||||
|
||||
'slots: for block_id in blocks.repeat(repeats).iter() {
|
||||
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
|
||||
slots.push(s);
|
||||
if slots.len() == tokens {
|
||||
break 'slots;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(BlockAllocation {
|
||||
allocation_id: 0,
|
||||
blocks,
|
||||
slots,
|
||||
prefix_len: 0,
|
||||
block_allocator: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn free(&mut self, blocks: Vec<u32>, _allocation_id: u64) {
|
||||
self.free_blocks.extend(blocks)
|
||||
}
|
||||
}
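A quick worked example of the padding arithmetic in `SimpleAllocator::allocate`, using made-up numbers (block_size = 16, no window): 70 tokens need ceil(70 / 16) = 5 blocks, and the slot list is cut off at exactly 70 entries. The hedged sketch below just re-derives those numbers:

fn main() {
    let block_size: u32 = 16;
    let tokens: u32 = 70;
    // Same rounding-up used in the allocator: pad to a whole number of blocks.
    let required_blocks = (tokens + block_size - 1) / block_size;
    assert_eq!(required_blocks, 5);
    // Slots are emitted block by block but stop once `tokens` slots exist.
    let slots: Vec<u32> = (0..required_blocks * block_size).take(tokens as usize).collect();
    assert_eq!(slots.len(), 70);
}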
|
288
backends/v3/src/client/grpc_client.rs
Normal file
@ -0,0 +1,288 @@
|
||||
/// Single shard Client
|
||||
use crate::client::{pb, Chunk};
|
||||
use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64};
|
||||
use base64::engine::general_purpose::STANDARD;
|
||||
use base64::Engine;
|
||||
use grpc_metadata::InjectTelemetryContext;
|
||||
use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
|
||||
use pb::generate::v3::*;
|
||||
use std::cmp::min;
|
||||
use std::time::Duration;
|
||||
use tonic::transport::{Channel, Uri};
|
||||
use tracing::instrument;
|
||||
|
||||
/// Text Generation Inference gRPC client
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Client {
|
||||
stub: TextGenerationServiceClient<Channel>,
|
||||
}
|
||||
|
||||
impl Client {
|
||||
/// Returns a client connected to the given url
|
||||
#[allow(dead_code)]
|
||||
pub async fn connect(uri: Uri) -> Result<Self> {
|
||||
let channel = Channel::builder(uri).connect().await?;
|
||||
|
||||
Ok(Self {
|
||||
stub: TextGenerationServiceClient::new(channel),
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns a client connected to the given unix socket
|
||||
pub async fn connect_uds(path: String) -> Result<Self> {
|
||||
let channel = Channel::from_shared("http://[::]:50051".to_string())
|
||||
.unwrap()
|
||||
.connect_with_connector(tower::service_fn(move |_: Uri| {
|
||||
tokio::net::UnixStream::connect(path.clone())
|
||||
}))
|
||||
.await?;
|
||||
|
||||
Ok(Self {
|
||||
stub: TextGenerationServiceClient::new(channel),
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns a list of uris or unix sockets of all shards
|
||||
#[instrument(skip(self))]
|
||||
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
|
||||
let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
|
||||
let response = self.stub.service_discovery(request).await.map_err(|_| {
|
||||
ClientError::Connection("Server does not support v3 interface".to_string())
|
||||
})?;
|
||||
let urls = response
|
||||
.into_inner()
|
||||
.urls
|
||||
.into_iter()
|
||||
// Remove unix socket prefix
|
||||
.map(|url| match url.strip_prefix("unix://") {
|
||||
None => url,
|
||||
Some(stripped_url) => stripped_url.to_string(),
|
||||
})
|
||||
.collect();
|
||||
Ok(urls)
|
||||
}
|
||||
|
||||
/// Get model info
|
||||
#[instrument(skip(self))]
|
||||
pub async fn info(&mut self) -> Result<InfoResponse> {
|
||||
let request = tonic::Request::new(InfoRequest {}).inject_context();
|
||||
let response = self.stub.info(request).await?.into_inner();
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// Get model health
|
||||
#[instrument(skip(self))]
|
||||
pub async fn health(&mut self) -> Result<HealthResponse> {
|
||||
let request = tonic::Request::new(HealthRequest {}).inject_context();
|
||||
let response = self.stub.health(request).await?.into_inner();
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// Clear the past generations cache
|
||||
#[instrument(skip(self))]
|
||||
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
||||
let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
|
||||
self.stub.clear_cache(request).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Filter a cached batch
|
||||
#[instrument(skip(self))]
|
||||
pub async fn filter_batch(
|
||||
&mut self,
|
||||
batch_id: u64,
|
||||
request_ids: Vec<u64>,
|
||||
) -> Result<Option<CachedBatch>> {
|
||||
let request = tonic::Request::new(FilterBatchRequest {
|
||||
batch_id,
|
||||
request_ids,
|
||||
})
|
||||
.inject_context();
|
||||
let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
|
||||
Ok(filtered_batch.batch)
|
||||
}
|
||||
|
||||
/// Warmup on a max size batch
|
||||
///
|
||||
/// Returns the maximum number of tokens supported by the hardware
|
||||
#[instrument(skip_all)]
|
||||
pub async fn warmup(
|
||||
&mut self,
|
||||
max_input_length: u32,
|
||||
max_prefill_tokens: u32,
|
||||
max_total_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_batch_size: Option<usize>,
|
||||
) -> Result<Option<u32>> {
|
||||
let mut n_tokens = 0;
|
||||
let mut requests = Vec::new();
|
||||
// Create requests
|
||||
while n_tokens < max_prefill_tokens {
|
||||
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
|
||||
|
||||
let mut input_chunks = Vec::new();
|
||||
input_chunks
|
||||
.push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
|
||||
if n_tokens == 0 {
|
||||
input_chunks.push(
|
||||
Chunk::Image(Image {
|
||||
// Safe unwrap, because we control the data.
|
||||
data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(),
|
||||
mimetype: "image/jpeg;base64".to_string(),
|
||||
})
|
||||
.into(),
|
||||
);
|
||||
}
|
||||
|
||||
// Send stringly-typed inputs for compatibility with backends that haven't
|
||||
// been updated to support chunks.
|
||||
|
||||
let mut inputs = String::new();
|
||||
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
|
||||
if n_tokens == 0 {
|
||||
// 1 request is enough to test vision heads.
|
||||
// Sending images on other queries messes up easily with truncation.
|
||||
inputs.push_str(&format!(
    "![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})",
));
|
||||
}
|
||||
|
||||
requests.push(Request {
|
||||
id: 0,
|
||||
inputs,
|
||||
add_special_tokens: true,
|
||||
input_chunks: Some(Input {
|
||||
chunks: input_chunks,
|
||||
}),
|
||||
// We truncate the input on the server side to be sure that it has the correct size
|
||||
truncate,
|
||||
// Blocks and slots will be set on the server side if we use paged attention
|
||||
blocks: vec![],
|
||||
slots: vec![],
|
||||
prefix_len: 0,
|
||||
// Set sampling parameters to also take these ops into account in the max memory
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
temperature: 0.9,
|
||||
top_k: 10,
|
||||
top_p: 0.9,
|
||||
typical_p: 0.9,
|
||||
do_sample: false,
|
||||
seed: 0,
|
||||
repetition_penalty: 1.2,
|
||||
frequency_penalty: 0.1,
|
||||
watermark: true,
|
||||
grammar: String::new(),
|
||||
grammar_type: GrammarType::None as i32,
|
||||
}),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||
max_new_tokens: max_total_tokens - truncate,
|
||||
stop_sequences: vec![],
|
||||
ignore_eos_token: true,
|
||||
}),
|
||||
prefill_logprobs: true,
|
||||
top_n_tokens: 20,
|
||||
adapter_id: None,
|
||||
});
|
||||
n_tokens += max_input_length;
|
||||
|
||||
// Check max_batch_size
|
||||
if Some(requests.len()) == max_batch_size {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let batch = Batch {
|
||||
id: 0,
|
||||
size: requests.len() as u32,
|
||||
requests,
|
||||
max_tokens: max_input_length,
|
||||
max_blocks: 0,
|
||||
};
|
||||
|
||||
let request = tonic::Request::new(WarmupRequest {
|
||||
batch: Some(batch),
|
||||
max_input_length,
|
||||
max_prefill_tokens,
|
||||
max_total_tokens,
|
||||
max_batch_total_tokens,
|
||||
})
|
||||
.inject_context();
|
||||
let response = self.stub.warmup(request).await?.into_inner();
|
||||
Ok(response.max_supported_total_tokens)
|
||||
}
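The warmup loop above keeps adding synthetic requests until the prefill token budget is exhausted (or `max_batch_size` is hit); for example, with `max_input_length = 1024` and `max_prefill_tokens = 4096` it builds four requests, each truncated to 1024 tokens. A hedged re-derivation of that count with those example limits:

fn main() {
    let max_input_length: u32 = 1024;
    let max_prefill_tokens: u32 = 4096;
    let mut n_tokens = 0;
    let mut requests = 0;
    while n_tokens < max_prefill_tokens {
        let truncate = max_input_length.min(max_prefill_tokens - n_tokens);
        assert_eq!(truncate, 1024); // evenly divisible in this example
        requests += 1;
        n_tokens += max_input_length;
    }
    assert_eq!(requests, 4);
}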
|
||||
|
||||
/// Generate one token for each request in the given batch
|
||||
///
|
||||
/// Returns Generation for each request in batch
|
||||
/// and the next cached batch
|
||||
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
|
||||
pub async fn prefill(
|
||||
&mut self,
|
||||
batch: Batch,
|
||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
|
||||
let response = self.stub.prefill(request).await?.into_inner();
|
||||
Ok((
|
||||
response.generations,
|
||||
response.batch,
|
||||
PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
|
||||
))
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given cached batches
|
||||
///
|
||||
/// Returns Generation for each request in batches
|
||||
/// and the next cached batch
|
||||
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
|
||||
pub async fn decode(
|
||||
&mut self,
|
||||
batches: Vec<CachedBatch>,
|
||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
|
||||
let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
|
||||
let response = self.stub.decode(request).await?.into_inner();
|
||||
Ok((
|
||||
response.generations,
|
||||
response.batch,
|
||||
DecodeTimings::new(
|
||||
response.concat_ns,
|
||||
response.forward_ns,
|
||||
response.decode_ns,
|
||||
response.total_ns,
|
||||
),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PrefillTimings {
|
||||
pub forward: Duration,
|
||||
pub decode: Duration,
|
||||
pub total: Duration,
|
||||
}
|
||||
|
||||
impl PrefillTimings {
|
||||
fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
|
||||
Self {
|
||||
forward: Duration::from_nanos(forward_ns),
|
||||
decode: Duration::from_nanos(decode_ns),
|
||||
total: Duration::from_nanos(total_ns),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DecodeTimings {
|
||||
pub concat: Option<Duration>,
|
||||
pub forward: Duration,
|
||||
pub decode: Duration,
|
||||
pub total: Duration,
|
||||
}
|
||||
|
||||
impl DecodeTimings {
|
||||
fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
|
||||
Self {
|
||||
concat: concat_ns.map(Duration::from_nanos),
|
||||
forward: Duration::from_nanos(forward_ns),
|
||||
decode: Duration::from_nanos(decode_ns),
|
||||
total: Duration::from_nanos(total_ns),
|
||||
}
|
||||
}
|
||||
}
|
76
backends/v3/src/client/mod.rs
Normal file
@ -0,0 +1,76 @@
|
||||
//! Text Generation gRPC client library
|
||||
|
||||
use async_trait::async_trait;
|
||||
use thiserror::Error;
|
||||
use tonic::transport;
|
||||
use tonic::Status;
|
||||
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
mod pb;
|
||||
|
||||
mod grpc_client;
|
||||
mod sharded_client;
|
||||
|
||||
pub use grpc_client::Client;
|
||||
pub use pb::generate::v3::{
|
||||
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
|
||||
HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
|
||||
StoppingCriteriaParameters,
|
||||
};
|
||||
pub use sharded_client::ShardedClient;
|
||||
|
||||
#[async_trait]
|
||||
pub trait Health {
|
||||
/// Check if a generate server is healthy by asking it to allocate a tensor on device
|
||||
async fn device_health(&self) -> Result<()>;
|
||||
|
||||
/// Check if a generate server is healthy by doing a forward pass.
|
||||
/// EXPENSIVE
|
||||
async fn model_health(&self) -> Result<()>;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ShardInfo {
|
||||
pub requires_padding: bool,
|
||||
pub dtype: String,
|
||||
pub device_type: String,
|
||||
pub window_size: Option<u32>,
|
||||
pub speculate: u32,
|
||||
}
|
||||
|
||||
#[derive(Error, Debug, Clone)]
|
||||
pub enum ClientError {
|
||||
#[error("Could not connect to Text Generation server: {0}")]
|
||||
Connection(String),
|
||||
#[error("Server error: {0}")]
|
||||
Generation(String),
|
||||
#[error("Sharded results are empty")]
|
||||
EmptyResults,
|
||||
}
|
||||
|
||||
impl From<Status> for ClientError {
|
||||
fn from(err: Status) -> Self {
|
||||
let err = Self::Generation(err.message().to_string());
|
||||
tracing::error!("{err}");
|
||||
err
|
||||
}
|
||||
}
|
||||
|
||||
impl From<transport::Error> for ClientError {
|
||||
fn from(err: transport::Error) -> Self {
|
||||
let err = Self::Connection(err.to_string());
|
||||
tracing::error!("{err}");
|
||||
err
|
||||
}
|
||||
}
|
||||
|
||||
// Small convenience re-wrapping of `Chunk`.
|
||||
impl From<Chunk> for InputChunk {
|
||||
fn from(chunk: Chunk) -> Self {
|
||||
InputChunk { chunk: Some(chunk) }
|
||||
}
|
||||
}
|
||||
|
||||
static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
|
||||
|
||||
pub type Result<T> = std::result::Result<T, ClientError>;
|
264
backends/v3/src/client/sharded_client.rs
Normal file
@ -0,0 +1,264 @@
|
||||
use crate::client::{ClientError, Result};
|
||||
/// Multi shard Client
|
||||
use crate::client::{Health, ShardInfo};
|
||||
|
||||
use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
|
||||
use crate::client::{
|
||||
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
|
||||
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
};
|
||||
use crate::client::{Chunk, InfoResponse, Input};
|
||||
use async_trait::async_trait;
|
||||
use futures::future::join_all;
|
||||
use tonic::transport::Uri;
|
||||
use tracing::instrument;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// Text Generation Inference gRPC multi client
|
||||
pub struct ShardedClient {
|
||||
clients: Vec<Client>,
|
||||
}
|
||||
|
||||
impl ShardedClient {
|
||||
fn new(clients: Vec<Client>) -> Self {
|
||||
Self { clients }
|
||||
}
|
||||
|
||||
/// Create a new ShardedClient from a master client. The master client will communicate with
|
||||
/// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
|
||||
async fn from_master_client(mut master_client: Client) -> Result<Self> {
|
||||
// Get all uris/unix sockets from the master client
|
||||
let uris = master_client.service_discovery().await?;
|
||||
let futures = uris.into_iter().map(Client::connect_uds);
|
||||
let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
|
||||
Ok(Self::new(clients?))
|
||||
}
|
||||
|
||||
/// Returns a client connected to the given uri
|
||||
#[allow(dead_code)]
|
||||
pub async fn connect(uri: Uri) -> Result<Self> {
|
||||
let master_client = Client::connect(uri).await?;
|
||||
Self::from_master_client(master_client).await
|
||||
}
|
||||
|
||||
/// Returns a client connected to the given unix socket
|
||||
pub async fn connect_uds(path: String) -> Result<Self> {
|
||||
let master_client = Client::connect_uds(path).await?;
|
||||
Self::from_master_client(master_client).await
|
||||
}
|
||||
|
||||
/// Get the model info
|
||||
#[instrument(skip(self))]
|
||||
pub async fn info(&mut self) -> Result<ShardInfo> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| client.info())
|
||||
.collect();
|
||||
join_all(futures).await.pop().unwrap().map(ShardInfo::from)
|
||||
}
|
||||
|
||||
/// GRPC health check
|
||||
#[instrument(skip(self))]
|
||||
pub async fn health(&mut self) -> Result<HealthResponse> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| client.health())
|
||||
.collect();
|
||||
join_all(futures).await.pop().unwrap()
|
||||
}
|
||||
|
||||
/// Clear the past generations cache
|
||||
#[instrument(skip(self))]
|
||||
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| client.clear_cache(batch_id))
|
||||
.collect();
|
||||
join_all(futures).await.into_iter().collect()
|
||||
}
|
||||
|
||||
/// Filter a cached batch
|
||||
#[instrument(skip(self))]
|
||||
pub async fn filter_batch(
|
||||
&mut self,
|
||||
batch_id: u64,
|
||||
request_ids: Vec<u64>,
|
||||
) -> Result<Option<CachedBatch>> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
|
||||
.collect();
|
||||
// all shards return the same message
|
||||
join_all(futures).await.pop().unwrap()
|
||||
}
|
||||
|
||||
/// Warmup on a max size batch
|
||||
///
|
||||
/// Returns the maximum number of tokens supported by the hardware
|
||||
#[instrument(skip(self))]
|
||||
pub async fn warmup(
|
||||
&mut self,
|
||||
max_input_length: u32,
|
||||
max_prefill_tokens: u32,
|
||||
max_total_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_batch_size: Option<usize>,
|
||||
) -> Result<Option<u32>> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| {
|
||||
Box::pin(client.warmup(
|
||||
max_input_length,
|
||||
max_prefill_tokens,
|
||||
max_total_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_batch_size,
|
||||
))
|
||||
})
|
||||
.collect();
|
||||
// Take the minimum value
|
||||
let results = join_all(futures)
|
||||
.await
|
||||
.into_iter()
|
||||
.collect::<Result<Vec<Option<u32>>>>()?;
|
||||
Ok(results.into_iter().flatten().min())
|
||||
}
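Each shard reports its own `max_supported_total_tokens`, and the sharded client keeps the smallest one so the scheduler never exceeds the weakest shard. A tiny illustration of that reduction with made-up per-shard values:

fn main() {
    // One shard could not determine a limit (None); the others report theirs.
    let per_shard: Vec<Option<u32>> = vec![Some(81_920), None, Some(65_536)];
    let effective = per_shard.into_iter().flatten().min();
    assert_eq!(effective, Some(65_536));
}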
|
||||
|
||||
/// Generate one token for each request in the given batch
|
||||
///
|
||||
/// Returns Generation for each request in batch
|
||||
/// and the next cached batch
|
||||
#[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]
|
||||
pub async fn prefill(
|
||||
&mut self,
|
||||
batch: Batch,
|
||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| Box::pin(client.prefill(batch.clone())))
|
||||
.collect();
|
||||
#[allow(clippy::type_complexity)]
|
||||
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
|
||||
join_all(futures).await.into_iter().collect();
|
||||
let mut results = results?;
|
||||
|
||||
let (mut generations, next_batch, mut timings) =
|
||||
results.pop().ok_or(ClientError::EmptyResults)?;
|
||||
|
||||
// Merge generations from different model shards
|
||||
for (mut shard_generations, _, shard_timings) in results.into_iter() {
|
||||
generations.append(&mut shard_generations);
|
||||
// Return the timings of the slowest shard
|
||||
if shard_timings.total > timings.total {
|
||||
timings = shard_timings;
|
||||
}
|
||||
}
|
||||
Ok((generations, next_batch, timings))
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given cached batches
|
||||
///
|
||||
/// Returns Generation for each request in batches
|
||||
/// and the next cached batch
|
||||
#[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
|
||||
pub async fn decode(
|
||||
&mut self,
|
||||
batches: Vec<CachedBatch>,
|
||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| Box::pin(client.decode(batches.clone())))
|
||||
.collect();
|
||||
#[allow(clippy::type_complexity)]
|
||||
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
|
||||
join_all(futures).await.into_iter().collect();
|
||||
let mut results = results?;
|
||||
|
||||
let (mut generations, next_batch, mut timings) =
|
||||
results.pop().ok_or(ClientError::EmptyResults)?;
|
||||
|
||||
// Merge generations from different model shards
|
||||
for (mut shard_generations, _, shard_timings) in results.into_iter() {
|
||||
generations.append(&mut shard_generations);
|
||||
// Return the timings of the slowest shard
|
||||
if shard_timings.total > timings.total {
|
||||
timings = shard_timings;
|
||||
}
|
||||
}
|
||||
Ok((generations, next_batch, timings))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<InfoResponse> for ShardInfo {
|
||||
fn from(value: InfoResponse) -> Self {
|
||||
Self {
|
||||
requires_padding: value.requires_padding,
|
||||
dtype: value.dtype,
|
||||
device_type: value.device_type,
|
||||
window_size: value.window_size,
|
||||
speculate: value.speculate,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Health for ShardedClient {
|
||||
async fn device_health(&self) -> Result<()> {
|
||||
self.clone().health().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn model_health(&self) -> Result<()> {
|
||||
// Dummy batch of 1 token and 1 generated token
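// Unlike `device_health`, this runs a real prefill end to end (input chunking,
// scheduling, one forward pass), so it catches model-level failures that a plain
// gRPC health ping would miss.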
|
||||
let liveness_request = Request {
|
||||
id: u64::MAX,
|
||||
inputs: "liveness".to_string(),
|
||||
input_chunks: Some(Input {
|
||||
chunks: vec![Chunk::Text("liveness".into()).into()],
|
||||
}),
|
||||
truncate: 10,
|
||||
add_special_tokens: true,
|
||||
prefill_logprobs: false,
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
temperature: 1.0,
|
||||
top_k: 0,
|
||||
top_p: 1.0,
|
||||
typical_p: 1.0,
|
||||
do_sample: false,
|
||||
seed: 0,
|
||||
repetition_penalty: 1.0,
|
||||
frequency_penalty: 0.0,
|
||||
watermark: false,
|
||||
grammar: String::new(),
|
||||
grammar_type: GrammarType::None as i32,
|
||||
}),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||
max_new_tokens: 1,
|
||||
stop_sequences: vec![],
|
||||
ignore_eos_token: false,
|
||||
}),
|
||||
top_n_tokens: 0,
|
||||
// Block 0 is reserved for health checks
|
||||
blocks: vec![0],
|
||||
slots: (0..16).collect(),
|
||||
prefix_len: 0,
|
||||
adapter_id: None,
|
||||
};
|
||||
let batch = Batch {
|
||||
id: u64::MAX,
|
||||
requests: vec![liveness_request],
|
||||
size: 1,
|
||||
max_tokens: 2,
|
||||
max_blocks: 1,
|
||||
};
|
||||
self.clone().prefill(batch).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
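// Illustrative sketch (not part of the original change): how a caller might
// combine the two probes above. `client` is a hypothetical `ShardedClient`,
// e.g. obtained through `ShardedClient::connect_uds`.
//
// async fn startup_probe(client: &ShardedClient) -> Result<()> {
//     client.device_health().await?; // cheap gRPC liveness check on every shard
//     client.model_health().await?;  // full one-token prefill on the dummy batch
//     Ok(())
// }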
|
147
backends/v3/src/lib.rs
Normal file
@ -0,0 +1,147 @@
|
||||
mod backend;
|
||||
pub mod block_allocator;
|
||||
mod client;
|
||||
mod queue;
|
||||
pub mod radix;
|
||||
|
||||
use crate::client::{ClientError, ShardedClient};
|
||||
pub(crate) use backend::BackendV3;
|
||||
use serde::Serialize;
|
||||
use thiserror::Error;
|
||||
use utoipa::ToSchema;
|
||||
|
||||
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||
pub struct BackendInfo {
|
||||
/// Mandatory
|
||||
#[schema(example = "cuda")]
|
||||
pub model_device_type: String,
|
||||
#[schema(example = "torch.float16")]
|
||||
pub model_dtype: String,
|
||||
|
||||
/// Backend parameters
|
||||
#[schema(example = "1")]
|
||||
pub speculate: usize,
|
||||
#[schema(example = "1.2")]
|
||||
pub waiting_served_ratio: f32,
|
||||
#[schema(example = "32000")]
|
||||
pub max_batch_total_tokens: u32,
|
||||
#[schema(example = "20")]
|
||||
pub max_waiting_tokens: usize,
|
||||
#[schema(nullable = true, example = "null")]
|
||||
pub max_batch_size: Option<usize>,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn connect_backend(
|
||||
max_input_tokens: usize,
|
||||
max_total_tokens: usize,
|
||||
master_shard_uds_path: String,
|
||||
waiting_served_ratio: f32,
|
||||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: Option<u32>,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
) -> Result<(BackendV3, BackendInfo), V3Error> {
|
||||
// Helper function
|
||||
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
|
||||
match max_supported_batch_total_tokens {
|
||||
// Older models do not support automatic max-batch-total-tokens
|
||||
None => {
|
||||
let max_batch_total_tokens = max_batch_total_tokens
|
||||
.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
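// e.g. (illustrative, using the router defaults max_total_tokens = 2048 and
// max_batch_prefill_tokens = 4096) the fallback is max(16000, max(2048, 4096)) = 16000;
// it only grows beyond 16000 when one of those limits does.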
|
||||
tracing::warn!("Model does not support automatic max batch total tokens");
|
||||
Ok(max_batch_total_tokens)
|
||||
}
|
||||
// Flash attention models return their max supported total tokens
|
||||
Some(max_supported_batch_total_tokens) => {
|
||||
// Warn if the user supplied their own max-batch-total-tokens, as we will ignore it
|
||||
if max_batch_total_tokens.is_some() {
|
||||
tracing::warn!(
|
||||
"`--max-batch-total-tokens` is deprecated for Flash \
|
||||
Attention models."
|
||||
);
|
||||
tracing::warn!(
|
||||
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
|
||||
);
|
||||
}
|
||||
if max_total_tokens as u32 > max_supported_batch_total_tokens {
|
||||
return Err(V3Error::NotEnoughMemory(max_total_tokens));
|
||||
}
|
||||
|
||||
Ok(max_supported_batch_total_tokens)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
|
||||
.await
|
||||
.map_err(V3Error::Connection)?;
|
||||
|
||||
// server is running on v3
|
||||
// Clear the cache; useful if the webserver rebooted
|
||||
sharded_client
|
||||
.clear_cache(None)
|
||||
.await
|
||||
.map_err(V3Error::Cache)?;
|
||||
// Get info from the shard
|
||||
let shard_info = sharded_client.info().await.map_err(V3Error::Info)?;
|
||||
|
||||
// Warmup model
|
||||
tracing::info!("Warming up model");
|
||||
let max_batch_total_tokens = check_max_batch_total_tokens(
|
||||
sharded_client
|
||||
.warmup(
|
||||
max_input_tokens as u32,
|
||||
max_batch_prefill_tokens,
|
||||
max_total_tokens as u32,
|
||||
max_batch_total_tokens.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))),
|
||||
max_batch_size,
|
||||
)
|
||||
.await
|
||||
.map_err(V3Error::Warmup)?,
|
||||
)?;
|
||||
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
|
||||
metrics::gauge!("tgi_batch_max_total_tokens").set(max_batch_total_tokens);
|
||||
|
||||
let backend_info = BackendInfo {
|
||||
waiting_served_ratio,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
model_device_type: shard_info.device_type.clone(),
|
||||
model_dtype: shard_info.dtype.clone(),
|
||||
speculate: shard_info.speculate as usize,
|
||||
};
|
||||
|
||||
let backend = BackendV3::new(
|
||||
sharded_client,
|
||||
waiting_served_ratio,
|
||||
max_input_tokens as u32,
|
||||
max_total_tokens as u32,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
shard_info.requires_padding,
|
||||
shard_info.window_size,
|
||||
shard_info.speculate,
|
||||
);
|
||||
|
||||
tracing::info!("Using backend V3");
|
||||
|
||||
Ok((backend, backend_info))
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum V3Error {
|
||||
#[error("Unable to clear the Python model shards cache: {0}")]
|
||||
Cache(ClientError),
|
||||
#[error("Unable to connect to the Python model shards: {0}")]
|
||||
Connection(ClientError),
|
||||
#[error("Unable to get the Python model shards info: {0}")]
|
||||
Info(ClientError),
|
||||
#[error("Unable to warmup the Python model shards: {0}")]
|
||||
Warmup(ClientError),
|
||||
#[error("Not enough memory to handle `max_total_tokens={0}`")]
|
||||
NotEnoughMemory(usize),
|
||||
}
|
212
backends/v3/src/main.rs
Normal file
@ -0,0 +1,212 @@
|
||||
use clap::{Parser, Subcommand};
|
||||
use text_generation_router::{server, usage_stats};
|
||||
use text_generation_router_v3::{connect_backend, V3Error};
|
||||
use thiserror::Error;
|
||||
|
||||
/// App Configuration
|
||||
#[derive(Parser, Debug)]
|
||||
#[clap(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
#[command(subcommand)]
|
||||
command: Option<Commands>,
|
||||
|
||||
#[clap(default_value = "128", long, env)]
|
||||
max_concurrent_requests: usize,
|
||||
#[clap(default_value = "2", long, env)]
|
||||
max_best_of: usize,
|
||||
#[clap(default_value = "4", long, env)]
|
||||
max_stop_sequences: usize,
|
||||
#[clap(default_value = "5", long, env)]
|
||||
max_top_n_tokens: u32,
|
||||
#[clap(default_value = "1024", long, env)]
|
||||
max_input_tokens: usize,
|
||||
#[clap(default_value = "2048", long, env)]
|
||||
max_total_tokens: usize,
|
||||
#[clap(default_value = "1.2", long, env)]
|
||||
waiting_served_ratio: f32,
|
||||
#[clap(default_value = "4096", long, env)]
|
||||
max_batch_prefill_tokens: u32,
|
||||
#[clap(long, env)]
|
||||
max_batch_total_tokens: Option<u32>,
|
||||
#[clap(default_value = "20", long, env)]
|
||||
max_waiting_tokens: usize,
|
||||
#[clap(long, env)]
|
||||
max_batch_size: Option<usize>,
|
||||
#[clap(default_value = "0.0.0.0", long, env)]
|
||||
hostname: String,
|
||||
#[clap(default_value = "3000", long, short, env)]
|
||||
port: u16,
|
||||
#[clap(default_value = "/tmp/text-generation-server-0", long, env)]
|
||||
master_shard_uds_path: String,
|
||||
#[clap(default_value = "bigscience/bloom", long, env)]
|
||||
tokenizer_name: String,
|
||||
#[clap(long, env)]
|
||||
tokenizer_config_path: Option<String>,
|
||||
#[clap(long, env)]
|
||||
revision: Option<String>,
|
||||
#[clap(default_value = "2", long, env)]
|
||||
validation_workers: usize,
|
||||
#[clap(long, env)]
|
||||
api_key: Option<String>,
|
||||
#[clap(long, env)]
|
||||
json_output: bool,
|
||||
#[clap(long, env)]
|
||||
otlp_endpoint: Option<String>,
|
||||
#[clap(default_value = "text-generation-inference.router", long, env)]
|
||||
otlp_service_name: String,
|
||||
#[clap(long, env)]
|
||||
cors_allow_origin: Option<Vec<String>>,
|
||||
#[clap(long, env)]
|
||||
ngrok: bool,
|
||||
#[clap(long, env)]
|
||||
ngrok_authtoken: Option<String>,
|
||||
#[clap(long, env)]
|
||||
ngrok_edge: Option<String>,
|
||||
#[clap(long, env, default_value_t = false)]
|
||||
messages_api_enabled: bool,
|
||||
#[clap(long, env, default_value_t = false)]
|
||||
disable_grammar_support: bool,
|
||||
#[clap(default_value = "4", long, env)]
|
||||
max_client_batch_size: usize,
|
||||
#[clap(default_value = "on", long, env)]
|
||||
usage_stats: usage_stats::UsageStatsLevel,
|
||||
}
|
||||
|
||||
#[derive(Debug, Subcommand)]
|
||||
enum Commands {
|
||||
PrintSchema,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), RouterError> {
|
||||
// Get args
|
||||
let args = Args::parse();
|
||||
// Pattern match configuration
|
||||
let Args {
|
||||
command,
|
||||
max_concurrent_requests,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
hostname,
|
||||
port,
|
||||
master_shard_uds_path,
|
||||
tokenizer_name,
|
||||
tokenizer_config_path,
|
||||
revision,
|
||||
validation_workers,
|
||||
api_key,
|
||||
json_output,
|
||||
otlp_endpoint,
|
||||
otlp_service_name,
|
||||
cors_allow_origin,
|
||||
ngrok,
|
||||
ngrok_authtoken,
|
||||
ngrok_edge,
|
||||
messages_api_enabled,
|
||||
disable_grammar_support,
|
||||
max_client_batch_size,
|
||||
usage_stats,
|
||||
} = args;
|
||||
|
||||
if let Some(Commands::PrintSchema) = command {
|
||||
use utoipa::OpenApi;
|
||||
let api_doc = text_generation_router::server::ApiDoc::openapi();
|
||||
let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();
|
||||
println!("{}", api_doc);
|
||||
std::process::exit(0);
|
||||
};
|
||||
text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
|
||||
|
||||
// Validate args
|
||||
if max_input_tokens >= max_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
||||
));
|
||||
}
|
||||
if max_input_tokens as u32 > max_batch_prefill_tokens {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
|
||||
}
|
||||
|
||||
if validation_workers == 0 {
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
"`validation_workers` must be > 0".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
|
||||
if max_batch_prefill_tokens > *max_batch_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
|
||||
}
|
||||
if max_total_tokens as u32 > *max_batch_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(max_batch_size) = max_batch_size {
|
||||
if max_batch_size == 0 {
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
"`max_batch_size` must be > 0".to_string(),
|
||||
));
|
||||
}
|
||||
}
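// Summary of the checks above, using the clap defaults as illustrative values:
//   max_input_tokens (1024) < max_total_tokens (2048),
//   max_input_tokens (1024) <= max_batch_prefill_tokens (4096),
//   and, when --max-batch-total-tokens is given, both max_batch_prefill_tokens
//   and max_total_tokens must fit inside it.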
|
||||
|
||||
let (backend, _backend_info) = connect_backend(
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
master_shard_uds_path,
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Run server
|
||||
server::run(
|
||||
backend,
|
||||
max_concurrent_requests,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
validation_workers,
|
||||
api_key,
|
||||
tokenizer_name,
|
||||
tokenizer_config_path,
|
||||
revision,
|
||||
hostname,
|
||||
port,
|
||||
cors_allow_origin,
|
||||
ngrok,
|
||||
ngrok_authtoken,
|
||||
ngrok_edge,
|
||||
messages_api_enabled,
|
||||
disable_grammar_support,
|
||||
max_client_batch_size,
|
||||
usage_stats,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
enum RouterError {
|
||||
#[error("Argument validation error: {0}")]
|
||||
ArgumentValidation(String),
|
||||
#[error("Backend failed: {0}")]
|
||||
Backend(#[from] V3Error),
|
||||
#[error("WebServer error: {0}")]
|
||||
WebServer(#[from] server::WebServerError),
|
||||
#[error("Tokio runtime failed to start: {0}")]
|
||||
Tokio(#[from] std::io::Error),
|
||||
}
|
824
backends/v3/src/queue.rs
Normal file
@ -0,0 +1,824 @@
|
||||
use crate::block_allocator::{BlockAllocation, BlockAllocator};
|
||||
use crate::client;
|
||||
use crate::client::{
|
||||
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
};
|
||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||
use std::cmp::{max, min};
|
||||
use std::collections::VecDeque;
|
||||
use text_generation_router::infer::InferError;
|
||||
use text_generation_router::infer::InferStreamResponse;
|
||||
use text_generation_router::validation::{
|
||||
Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
|
||||
ValidStoppingParameters,
|
||||
};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tokio::time::Instant;
|
||||
use tracing::{info_span, instrument, Instrument, Span};
|
||||
|
||||
/// Queue entry
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct Entry {
|
||||
/// Request
|
||||
pub request: ValidGenerateRequest,
|
||||
/// Response sender to communicate between the Infer struct and the batching_task
|
||||
pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>,
|
||||
/// Span that will live as long as entry
|
||||
pub span: Span,
|
||||
/// Temporary span used as a guard when logging inference, wait times...
|
||||
pub temp_span: Option<Span>,
|
||||
/// Instant when this entry was queued
|
||||
pub queue_time: Instant,
|
||||
/// Instant when this entry was added to a batch
|
||||
pub batch_time: Option<Instant>,
|
||||
/// Block Allocation
|
||||
pub block_allocation: Option<BlockAllocation>,
|
||||
}
|
||||
|
||||
/// Request Queue
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct Queue {
|
||||
/// Channel to communicate with the background queue task
|
||||
queue_sender: mpsc::UnboundedSender<QueueCommand>,
|
||||
}
|
||||
|
||||
impl Queue {
|
||||
pub(crate) fn new(
|
||||
requires_padding: bool,
|
||||
block_size: u32,
|
||||
prefix_caching: bool,
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
max_input_tokens: u32,
|
||||
max_total_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
) -> Self {
|
||||
// Create channel
|
||||
let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
|
||||
|
||||
// Launch background queue task
|
||||
tokio::spawn(queue_task(
|
||||
requires_padding,
|
||||
block_size,
|
||||
prefix_caching,
|
||||
window_size,
|
||||
speculate,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
max_batch_total_tokens,
|
||||
queue_receiver,
|
||||
));
|
||||
|
||||
Self { queue_sender }
|
||||
}
|
||||
|
||||
/// Append an entry to the queue
|
||||
#[instrument(skip_all)]
|
||||
pub(crate) fn append(&self, entry: Entry) {
|
||||
// Send append command to the background task managing the state
|
||||
// Unwrap is safe here
|
||||
self.queue_sender
|
||||
.send(QueueCommand::Append(Box::new(entry), Span::current()))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// Get the next batch
|
||||
#[instrument(skip(self))]
|
||||
pub(crate) async fn next_batch(
|
||||
&self,
|
||||
min_size: Option<usize>,
|
||||
max_size: Option<usize>,
|
||||
prefill_token_budget: u32,
|
||||
token_budget: u32,
|
||||
) -> Option<NextBatch> {
|
||||
// Create response channel
|
||||
let (response_sender, response_receiver) = oneshot::channel();
|
||||
// Send next batch command to the background task managing the state
|
||||
// Unwrap is safe here
|
||||
self.queue_sender
|
||||
.send(QueueCommand::NextBatch {
|
||||
min_size,
|
||||
max_size,
|
||||
prefill_token_budget,
|
||||
token_budget,
|
||||
response_sender,
|
||||
span: Span::current(),
|
||||
})
|
||||
.unwrap();
|
||||
// Await on response channel
|
||||
// Unwrap is safe here
|
||||
response_receiver.await.unwrap()
|
||||
}
|
||||
}
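// Note (added for clarity): `Queue` is only a handle around an mpsc sender; all
// mutable state lives in `queue_task` below. `append` is fire-and-forget, while
// `next_batch` ships a oneshot sender inside the command and awaits the reply,
// so every mutation of `State` is serialized in that single background task.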
|
||||
|
||||
// Background task responsible for the queue state
|
||||
async fn queue_task(
|
||||
requires_padding: bool,
|
||||
block_size: u32,
|
||||
prefix_caching: bool,
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
max_input_tokens: u32,
|
||||
max_total_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
|
||||
) {
|
||||
let mut state = State::new(
|
||||
requires_padding,
|
||||
block_size,
|
||||
prefix_caching,
|
||||
window_size,
|
||||
speculate,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
max_batch_total_tokens,
|
||||
);
|
||||
|
||||
while let Some(cmd) = receiver.recv().await {
|
||||
match cmd {
|
||||
QueueCommand::Append(entry, span) => {
|
||||
span.in_scope(|| state.append(*entry));
|
||||
metrics::gauge!("tgi_queue_size").increment(1.0);
|
||||
}
|
||||
QueueCommand::NextBatch {
|
||||
min_size,
|
||||
max_size,
|
||||
prefill_token_budget,
|
||||
token_budget,
|
||||
response_sender,
|
||||
span,
|
||||
} => {
|
||||
let next_batch = state
|
||||
.next_batch(min_size, max_size, prefill_token_budget, token_budget)
|
||||
.instrument(span)
|
||||
.await;
|
||||
response_sender.send(next_batch).unwrap();
|
||||
metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Queue State
|
||||
#[derive(Debug)]
|
||||
struct State {
|
||||
/// Queue entries organized in a VecDeque
|
||||
entries: VecDeque<(u64, Entry)>,
|
||||
|
||||
/// Id of the next entry
|
||||
next_id: u64,
|
||||
|
||||
/// Id of the next batch
|
||||
next_batch_id: u64,
|
||||
|
||||
/// Paged Attention block size
|
||||
block_size: u32,
|
||||
|
||||
/// Sliding window
|
||||
window_size: Option<u32>,
|
||||
|
||||
/// Speculation amount
|
||||
speculate: u32,
|
||||
|
||||
/// Paged Attention Block Allocation
|
||||
block_allocator: Option<BlockAllocator>,
|
||||
|
||||
/// Require padding
|
||||
requires_padding: bool,
|
||||
|
||||
/// max input tokens
|
||||
max_input_tokens: u32,
|
||||
|
||||
/// max total tokens
|
||||
max_total_tokens: u32,
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn new(
|
||||
requires_padding: bool,
|
||||
block_size: u32,
|
||||
prefix_caching: bool,
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
max_input_tokens: u32,
|
||||
max_total_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
) -> Self {
|
||||
let block_allocator = (!requires_padding).then(|| {
|
||||
BlockAllocator::new(
|
||||
max_batch_total_tokens,
|
||||
block_size,
|
||||
prefix_caching,
|
||||
window_size,
|
||||
)
|
||||
});
|
||||
|
||||
Self {
|
||||
entries: VecDeque::with_capacity(128),
|
||||
next_id: 0,
|
||||
next_batch_id: 0,
|
||||
block_size,
|
||||
window_size,
|
||||
speculate,
|
||||
block_allocator,
|
||||
requires_padding,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
}
|
||||
}
|
||||
|
||||
/// Append an entry to the queue
|
||||
fn append(&mut self, mut entry: Entry) {
|
||||
// Create a span that will live as long as the entry is in the queue waiting to be batched
|
||||
let queue_span = info_span!(parent: &entry.span, "queued");
|
||||
entry.temp_span = Some(queue_span);
|
||||
|
||||
// Push entry in the queue
|
||||
self.entries.push_back((self.next_id, entry));
|
||||
self.next_id += 1;
|
||||
}
|
||||
|
||||
// Get the next batch
|
||||
async fn next_batch(
|
||||
&mut self,
|
||||
min_size: Option<usize>,
|
||||
max_size: Option<usize>,
|
||||
prefill_token_budget: u32,
|
||||
token_budget: u32,
|
||||
) -> Option<NextBatch> {
|
||||
if self.entries.is_empty() {
|
||||
tracing::debug!("No queue");
|
||||
return None;
|
||||
}
|
||||
|
||||
// Check if we have enough entries
|
||||
if let Some(min_size) = min_size {
|
||||
if self.entries.len() < min_size {
|
||||
tracing::debug!("Not enough entries");
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(max_size) = max_size {
|
||||
if max_size == 0 {
|
||||
tracing::debug!("No capacity");
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
// Pad prefill_token_budget to be a multiple of block size
|
||||
let prefill_token_budget =
|
||||
((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
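// e.g. (illustrative) with block_size = 16, a budget of 100 is rounded up to
// ((100 + 15) / 16) * 16 = 112.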
|
||||
|
||||
// Create span for this batch to add context to inference calls
|
||||
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
|
||||
next_batch_span.follows_from(Span::current());
|
||||
|
||||
let mut batch = Vec::with_capacity(self.entries.len());
|
||||
let mut max_input_length = 0;
|
||||
let mut prefill_tokens: u32 = 0;
|
||||
let mut decode_tokens: u32 = 0;
|
||||
let mut max_blocks = 0;
|
||||
|
||||
// Pop entries starting from the front of the queue
|
||||
'entry_loop: while let Some((id, entry)) = self.entries.pop_front() {
|
||||
// Filter entries where the response receiver was dropped (== entries where the request
|
||||
// was dropped by the client)
|
||||
if entry.response_tx.is_closed() {
|
||||
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
|
||||
tracing::debug!("Dropping entry");
|
||||
continue;
|
||||
}
|
||||
|
||||
let block_allocation = match &self.block_allocator {
|
||||
None => {
|
||||
// We pad to max input length in the Python shards
|
||||
// We need to take these padding tokens into account
|
||||
if self.requires_padding {
|
||||
prefill_tokens = (batch.len() + 1) as u32 * self.max_input_tokens;
|
||||
} else {
|
||||
max_input_length = max_input_length.max(entry.request.input_length);
|
||||
prefill_tokens = (batch.len() + 1) as u32 * max_input_length;
|
||||
}
|
||||
|
||||
if self.requires_padding {
|
||||
decode_tokens = (batch.len() + 1) as u32 * (self.max_total_tokens - self.max_input_tokens);
|
||||
} else {
|
||||
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
|
||||
}
|
||||
|
||||
let total_tokens = prefill_tokens + decode_tokens + self.speculate;
|
||||
|
||||
if prefill_tokens > prefill_token_budget || total_tokens > token_budget {
|
||||
// Entry is over budget
|
||||
// Add it back to the front
|
||||
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
|
||||
self.entries.push_front((id, entry));
|
||||
break 'entry_loop;
|
||||
}
|
||||
None
|
||||
}
|
||||
Some(_block_allocator) => {
|
||||
prefill_tokens += entry.request.input_length;
|
||||
let max_new_tokens = match self.window_size {
|
||||
None => entry.request.stopping_parameters.max_new_tokens,
|
||||
Some(window_size) => min(
|
||||
window_size.saturating_sub(entry.request.input_length),
|
||||
entry.request.stopping_parameters.max_new_tokens,
|
||||
),
|
||||
};
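// e.g. (illustrative) with window_size = 1024, input_length = 1000 and
// max_new_tokens = 512, the decode budget is clamped to min(1024 - 1000, 512) = 24 tokens.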
|
||||
decode_tokens += max_new_tokens;
|
||||
|
||||
if prefill_tokens > prefill_token_budget
|
||||
|| (prefill_tokens + decode_tokens + self.speculate) > token_budget
|
||||
{
|
||||
// Entry is over budget
|
||||
// Add it back to the front
|
||||
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
|
||||
self.entries.push_front((id, entry));
|
||||
break;
|
||||
}
|
||||
|
||||
let tokens = entry.request.input_length
|
||||
+ entry.request.stopping_parameters.max_new_tokens
|
||||
+ self.speculate
|
||||
- 1;
|
||||
|
||||
// If the user wants the prefill logprobs, we cannot reuse the cache.
|
||||
// So no input_ids for the radix tree.
|
||||
let input_ids = if entry.request.decoder_input_details {
|
||||
None
|
||||
} else {
|
||||
entry.request.input_ids.clone()
|
||||
};
|
||||
|
||||
Some((tokens, input_ids))
|
||||
}
|
||||
};
|
||||
batch.push((id, entry, block_allocation));
|
||||
if Some(batch.len()) == max_size {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Empty batch
|
||||
if batch.is_empty() {
|
||||
tracing::debug!("Filterered out all entries");
|
||||
return None;
|
||||
}
|
||||
|
||||
// XXX We haven't allocated yet, so we're allowed to ditch the results.
|
||||
// Check if our batch is big enough
|
||||
if let Some(min_size) = min_size {
|
||||
// Batch is too small
|
||||
if batch.len() < min_size {
|
||||
// Add back entries to the queue in the correct order
|
||||
for (id, entry, _) in batch.into_iter().rev() {
|
||||
self.entries.push_front((id, entry));
|
||||
}
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
let mut batch_requests = Vec::with_capacity(self.entries.len());
|
||||
let mut batch_entries =
|
||||
IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
|
||||
|
||||
for (id, mut entry, block_allocation) in batch {
|
||||
let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) =
|
||||
(block_allocation, &self.block_allocator)
|
||||
{
|
||||
tracing::debug!("Allocating {tokens} with {input_ids:?}");
|
||||
match block_allocator.allocate(tokens, input_ids).await {
|
||||
None => {
|
||||
// Entry is over budget
|
||||
// Add it back to the front
|
||||
tracing::debug!("Over budget: not enough free blocks");
|
||||
self.entries.push_front((id, entry));
|
||||
continue;
|
||||
}
|
||||
Some(block_allocation) => {
|
||||
tracing::debug!("Allocation: {block_allocation:?}");
|
||||
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
|
||||
Some(block_allocation)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
tracing::debug!("Accepting entry");
|
||||
// Create a new span to link the batch back to this entry
|
||||
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
||||
// Add relationships
|
||||
next_batch_span.follows_from(&entry_batch_span);
|
||||
entry_batch_span.follows_from(&next_batch_span);
|
||||
// Update entry
|
||||
entry.temp_span = Some(entry_batch_span);
|
||||
|
||||
let (blocks, slots, prefix_len) = match &block_allocation {
|
||||
None => (Vec::new(), Vec::new(), 0),
|
||||
Some(block_allocation) => (
|
||||
block_allocation.blocks.clone(),
|
||||
block_allocation.slots.clone(),
|
||||
block_allocation.prefix_len,
|
||||
),
|
||||
};
|
||||
|
||||
entry.block_allocation = block_allocation;
|
||||
|
||||
batch_requests.push(Request {
|
||||
id,
|
||||
prefill_logprobs: entry.request.decoder_input_details,
|
||||
input_chunks: Some(client::Input {
|
||||
chunks: entry
|
||||
.request
|
||||
.inputs
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|c| client::InputChunk {
|
||||
chunk: Some(match c {
|
||||
Chunk::Text(text) => client::Chunk::Text(text),
|
||||
Chunk::Image(image) => client::Chunk::Image(client::Image {
|
||||
data: image.data,
|
||||
mimetype: image.mimetype,
|
||||
}),
|
||||
}),
|
||||
})
|
||||
.collect(),
|
||||
}),
|
||||
inputs: entry.request.inputs.chunks_to_string(),
|
||||
truncate: entry.request.truncate,
|
||||
add_special_tokens: entry.request.add_special_tokens,
|
||||
parameters: Some(NextTokenChooserParameters::from(
|
||||
entry.request.parameters.clone(),
|
||||
)),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters::from(
|
||||
entry.request.stopping_parameters.clone(),
|
||||
)),
|
||||
top_n_tokens: entry.request.top_n_tokens,
|
||||
blocks,
|
||||
slots,
|
||||
prefix_len,
|
||||
adapter_id: entry.request.adapter_id.clone(),
|
||||
});
|
||||
// Set batch_time
|
||||
entry.batch_time = Some(Instant::now());
|
||||
// Insert in batch_entries IntMap
|
||||
batch_entries.insert(id, entry);
|
||||
}
|
||||
|
||||
// Empty batch
|
||||
if batch_requests.is_empty() {
|
||||
tracing::debug!("Filterered out all entries");
|
||||
return None;
|
||||
}
|
||||
|
||||
// Final batch size
|
||||
let size = batch_requests.len() as u32;
|
||||
next_batch_span.record("batch_size", size);
|
||||
|
||||
let batch = Batch {
|
||||
id: self.next_batch_id,
|
||||
requests: batch_requests,
|
||||
size,
|
||||
max_tokens: (prefill_tokens + decode_tokens),
|
||||
max_blocks,
|
||||
};
|
||||
// Increment batch id
|
||||
self.next_batch_id += 1;
|
||||
|
||||
metrics::histogram!("tgi_batch_next_size").record(batch.size as f64);
|
||||
|
||||
Some((batch_entries, batch, next_batch_span))
|
||||
}
|
||||
}
|
||||
|
||||
type NextBatch = (IntMap<u64, Entry>, Batch, Span);
|
||||
|
||||
#[derive(Debug)]
|
||||
enum QueueCommand {
|
||||
Append(Box<Entry>, Span),
|
||||
NextBatch {
|
||||
min_size: Option<usize>,
|
||||
max_size: Option<usize>,
|
||||
prefill_token_budget: u32,
|
||||
token_budget: u32,
|
||||
response_sender: oneshot::Sender<Option<NextBatch>>,
|
||||
span: Span,
|
||||
},
|
||||
}
|
||||
|
||||
impl From<ValidParameters> for NextTokenChooserParameters {
|
||||
fn from(value: ValidParameters) -> Self {
|
||||
let (grammar, grammar_type) = match value.grammar {
|
||||
None => (String::new(), GrammarType::None),
|
||||
|
||||
Some(grammar) => match grammar {
|
||||
ValidGrammar::Json(grammar_string) => (grammar_string, GrammarType::Json),
|
||||
ValidGrammar::Regex(grammar_string) => (grammar_string, GrammarType::Regex),
|
||||
},
|
||||
};
|
||||
|
||||
Self {
|
||||
temperature: value.temperature,
|
||||
top_k: value.top_k,
|
||||
top_p: value.top_p,
|
||||
typical_p: value.typical_p,
|
||||
do_sample: value.do_sample,
|
||||
seed: value.seed,
|
||||
repetition_penalty: value.repetition_penalty,
|
||||
frequency_penalty: value.frequency_penalty,
|
||||
watermark: value.watermark,
|
||||
grammar,
|
||||
grammar_type: grammar_type.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
|
||||
fn from(value: ValidStoppingParameters) -> Self {
|
||||
Self {
|
||||
max_new_tokens: value.max_new_tokens,
|
||||
stop_sequences: value.stop_sequences,
|
||||
ignore_eos_token: value.ignore_eos_token,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::*;
|
||||
use tracing::info_span;
|
||||
|
||||
fn default_entry() -> (
|
||||
Entry,
|
||||
mpsc::UnboundedReceiver<Result<InferStreamResponse, InferError>>,
|
||||
) {
|
||||
let (response_tx, receiver_tx) = mpsc::unbounded_channel();
|
||||
|
||||
let entry = Entry {
|
||||
request: ValidGenerateRequest {
|
||||
inputs: vec![],
|
||||
input_ids: Some(Arc::new(vec![])),
|
||||
input_length: 0,
|
||||
add_special_tokens: true,
|
||||
truncate: 0,
|
||||
decoder_input_details: false,
|
||||
parameters: ValidParameters {
|
||||
temperature: 0.0,
|
||||
top_k: 0,
|
||||
top_p: 0.0,
|
||||
typical_p: 0.0,
|
||||
do_sample: false,
|
||||
seed: 0,
|
||||
repetition_penalty: 0.0,
|
||||
frequency_penalty: 0.0,
|
||||
watermark: false,
|
||||
grammar: None,
|
||||
},
|
||||
stopping_parameters: ValidStoppingParameters {
|
||||
ignore_eos_token: false,
|
||||
max_new_tokens: 1,
|
||||
stop_sequences: vec![],
|
||||
},
|
||||
top_n_tokens: 0,
|
||||
adapter_id: None,
|
||||
},
|
||||
response_tx,
|
||||
span: info_span!("entry"),
|
||||
temp_span: None,
|
||||
queue_time: Instant::now(),
|
||||
batch_time: None,
|
||||
block_allocation: None,
|
||||
};
|
||||
(entry, receiver_tx)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_append() {
|
||||
// Note: the two extra arguments (1, 2) are placeholder max_input_tokens /
// max_total_tokens; they are only read when requires_padding is true.
let mut state = State::new(false, 1, false, None, 0, 1, 2, 16);
|
||||
let (entry, _guard) = default_entry();
|
||||
|
||||
assert_eq!(state.next_id, 0);
|
||||
assert_eq!(state.entries.len(), 0);
|
||||
|
||||
state.append(entry);
|
||||
|
||||
assert_eq!(state.next_id, 1);
|
||||
assert_eq!(state.entries.len(), 1);
|
||||
let (id, _) = state.entries.remove(0).unwrap();
|
||||
assert_eq!(id, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_next_batch_empty() {
|
||||
let mut state = State::new(false, 1, false, None, 0, 1, 2, 16);
|
||||
|
||||
assert!(state.next_batch(None, None, 1, 1).await.is_none());
|
||||
assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_next_batch_min_size() {
|
||||
let mut state = State::new(false, 1, false, None, 0, 1, 2, 16);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
state.append(entry2);
|
||||
|
||||
let (entries, batch, _) = state.next_batch(None, None, 2, 2).await.unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert!(entries.contains_key(&1));
|
||||
assert!(entries.get(&0).unwrap().batch_time.is_some());
|
||||
assert!(entries.get(&1).unwrap().batch_time.is_some());
|
||||
assert_eq!(batch.id, 0);
|
||||
assert_eq!(batch.size, 2);
|
||||
|
||||
assert_eq!(state.next_id, 2);
|
||||
assert_eq!(state.entries.len(), 0);
|
||||
assert_eq!(state.next_batch_id, 1);
|
||||
|
||||
let (entry3, _guard3) = default_entry();
|
||||
state.append(entry3);
|
||||
|
||||
assert!(state.next_batch(Some(2), None, 2, 2).await.is_none());
|
||||
|
||||
assert_eq!(state.next_id, 3);
|
||||
assert_eq!(state.entries.len(), 1);
|
||||
let (id, _) = state.entries.remove(0).unwrap();
|
||||
assert_eq!(id, 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_next_batch_max_size() {
|
||||
let mut state = State::new(false, 1, false, None, 0, 1, 2, 16);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
state.append(entry2);
|
||||
|
||||
let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).await.unwrap();
|
||||
assert_eq!(entries.len(), 1);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert!(entries.get(&0).unwrap().batch_time.is_some());
|
||||
assert_eq!(batch.id, 0);
|
||||
assert_eq!(batch.size, 1);
|
||||
|
||||
assert_eq!(state.next_id, 2);
|
||||
assert_eq!(state.entries.len(), 1);
|
||||
assert_eq!(state.next_batch_id, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_next_batch_token_budget() {
|
||||
let mut state = State::new(false, 1, false, None, 0, 1, 2, 2);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
state.append(entry2);
|
||||
|
||||
let (entries, batch, _) = state.next_batch(None, None, 1, 1).await.unwrap();
|
||||
assert_eq!(entries.len(), 1);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert_eq!(batch.id, 0);
|
||||
assert_eq!(batch.size, 1);
|
||||
|
||||
assert_eq!(state.next_id, 2);
|
||||
assert_eq!(state.entries.len(), 1);
|
||||
assert_eq!(state.next_batch_id, 1);
|
||||
|
||||
let (entry3, _guard3) = default_entry();
|
||||
state.append(entry3);
|
||||
|
||||
let (entries, batch, _) = state.next_batch(None, None, 3, 3).await.unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&1));
|
||||
assert!(entries.contains_key(&2));
|
||||
assert_eq!(batch.id, 1);
|
||||
assert_eq!(batch.size, 2);
|
||||
|
||||
assert_eq!(state.next_id, 3);
|
||||
assert_eq!(state.entries.len(), 0);
|
||||
assert_eq!(state.next_batch_id, 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_append() {
|
||||
// Placeholder max_input_tokens / max_total_tokens (1, 2), unused without padding.
let queue = Queue::new(false, 1, false, None, 0, 1, 2, 16);
|
||||
let (entry, _guard) = default_entry();
|
||||
queue.append(entry);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_empty() {
|
||||
let queue = Queue::new(false, 1, false, None, 0, 1, 2, 16);
|
||||
|
||||
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
|
||||
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_min_size() {
|
||||
let queue = Queue::new(false, 1, false, None, 0, 1, 2, 16);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
queue.append(entry2);
|
||||
|
||||
let (entries, batch, _) = queue.next_batch(None, None, 2, 2).await.unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert!(entries.contains_key(&1));
|
||||
assert!(entries.get(&0).unwrap().batch_time.is_some());
|
||||
assert!(entries.get(&1).unwrap().batch_time.is_some());
|
||||
assert_eq!(batch.id, 0);
|
||||
assert_eq!(batch.size, 2);
|
||||
|
||||
let (entry3, _guard3) = default_entry();
|
||||
queue.append(entry3);
|
||||
|
||||
// Not enough requests pending
|
||||
assert!(queue.next_batch(Some(2), None, 2, 2).await.is_none());
|
||||
// Not enough token budget
|
||||
assert!(queue.next_batch(Some(1), None, 0, 0).await.is_none());
|
||||
// Ok
|
||||
let (entries2, batch2, _) = queue.next_batch(Some(1), None, 2, 2).await.unwrap();
|
||||
assert_eq!(entries2.len(), 1);
|
||||
assert!(entries2.contains_key(&2));
|
||||
assert!(entries2.get(&2).unwrap().batch_time.is_some());
|
||||
assert_eq!(batch2.id, 1);
|
||||
assert_eq!(batch2.size, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_max_size() {
|
||||
let queue = Queue::new(false, 1, false, None, 0, 1, 2, 16);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
queue.append(entry2);
|
||||
|
||||
let (entries, batch, _) = queue.next_batch(None, Some(1), 2, 2).await.unwrap();
|
||||
assert_eq!(entries.len(), 1);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert!(entries.get(&0).unwrap().batch_time.is_some());
|
||||
assert_eq!(batch.id, 0);
|
||||
assert_eq!(batch.size, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_token_budget() {
|
||||
let queue = Queue::new(false, 1, false, None, 0, 1, 2, 16);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
queue.append(entry2);
|
||||
|
||||
let (entries, batch, _) = queue.next_batch(None, None, 1, 1).await.unwrap();
|
||||
assert_eq!(entries.len(), 1);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert_eq!(batch.id, 0);
|
||||
assert_eq!(batch.size, 1);
|
||||
|
||||
let (entry3, _guard3) = default_entry();
|
||||
queue.append(entry3);
|
||||
|
||||
let (entries, batch, _) = queue.next_batch(None, None, 3, 3).await.unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&1));
|
||||
assert!(entries.contains_key(&2));
|
||||
assert_eq!(batch.id, 1);
|
||||
assert_eq!(batch.size, 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_token_speculate() {
|
||||
let queue = Queue::new(false, 1, false, None, 2, 1, 2, 16);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
queue.append(entry2);
|
||||
|
||||
// Budget of 1 is not enough
|
||||
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
|
||||
|
||||
let (entries, batch, _) = queue.next_batch(None, None, 6, 6).await.unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert!(entries.contains_key(&1));
|
||||
assert_eq!(batch.id, 0);
|
||||
assert_eq!(batch.size, 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_dropped_receiver() {
|
||||
let queue = Queue::new(false, 1, false, None, 0, 1, 2, 16);
|
||||
let (entry, _) = default_entry();
|
||||
queue.append(entry);
|
||||
|
||||
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
|
||||
}
|
||||
}
|
876
backends/v3/src/radix.rs
Normal file
@ -0,0 +1,876 @@
|
||||
use crate::block_allocator::{Allocator, BlockAllocation};
|
||||
use slotmap::{DefaultKey, SlotMap};
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::{
|
||||
collections::{BTreeSet, HashMap},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
fn hash(slice: &[u32]) -> u64 {
|
||||
assert!(!slice.is_empty());
|
||||
if slice.len() == 1 {
|
||||
slice[0] as u64
|
||||
} else {
|
||||
let mut s = std::hash::DefaultHasher::new();
|
||||
slice.hash(&mut s);
|
||||
s.finish()
|
||||
}
|
||||
}
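// Added note: child maps in the trie are keyed by this hash of one block worth of
// tokens; a single token hashes to itself (hash(&[42]) == 42), while longer slices
// go through std's DefaultHasher.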
|
||||
|
||||
pub struct RadixAllocator {
|
||||
allocation_id: u64,
|
||||
|
||||
allocations: HashMap<u64, RadixAllocation>,
|
||||
|
||||
cache_blocks: RadixTrie,
|
||||
|
||||
/// Blocks that are immediately available for allocation.
|
||||
free_blocks: Vec<u32>,
|
||||
|
||||
#[allow(dead_code)]
|
||||
// This isn't used because the prefix needs to match without the windowing
|
||||
// mechanism. At worst this over-allocates; it is not necessarily wrong.
|
||||
window_size: Option<u32>,
|
||||
|
||||
block_size: u32,
|
||||
}
|
||||
|
||||
impl RadixAllocator {
|
||||
pub fn new(block_size: u32, n_blocks: u32, window_size: Option<u32>) -> Self {
|
||||
RadixAllocator {
|
||||
allocation_id: 0,
|
||||
allocations: HashMap::new(),
|
||||
cache_blocks: RadixTrie::new(block_size as usize),
|
||||
|
||||
// Block 0 is reserved for health checks.
|
||||
free_blocks: (1..n_blocks).collect(),
|
||||
window_size,
|
||||
block_size,
|
||||
}
|
||||
}
|
||||
|
||||
fn alloc_or_reclaim(&mut self, n_blocks_needed: usize) -> Option<Vec<u32>> {
|
||||
if self.free_blocks.len() < n_blocks_needed {
|
||||
// This is a bit annoying, we first extend the free list and then
|
||||
// split it off again below. This is because we need to put it on
|
||||
// the free list if we cannot allocate enough blocks. This is only
|
||||
// temporary, the trie needs to be able to report whether it can
|
||||
// allocate the requested amount. Just not implemented yet.
|
||||
tracing::debug!(
|
||||
"Free blocks {} need {n_blocks_needed}",
|
||||
self.free_blocks.len()
|
||||
);
|
||||
self.free_blocks.extend(
|
||||
self.cache_blocks
|
||||
.evict(n_blocks_needed - self.free_blocks.len()),
|
||||
);
|
||||
}
|
||||
|
||||
if self.free_blocks.len() >= n_blocks_needed {
|
||||
Some(
|
||||
self.free_blocks
|
||||
.split_off(self.free_blocks.len() - n_blocks_needed),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Allocator trait
|
||||
impl Allocator for RadixAllocator {
|
||||
fn allocate(
|
||||
&mut self,
|
||||
tokens: u32,
|
||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
) -> Option<BlockAllocation> {
|
||||
let mut blocks = vec![];
|
||||
let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() {
|
||||
let node_id = self
|
||||
.cache_blocks
|
||||
.find(prefill_tokens.as_slice(), &mut blocks);
|
||||
node_id
|
||||
} else {
|
||||
self.cache_blocks.root_id()
|
||||
};
|
||||
|
||||
// Even if this allocation fails below, we need to increase the
|
||||
// refcount to ensure that the prefix that was found is not evicted.
|
||||
self.cache_blocks
|
||||
.incref(prefix_node)
|
||||
.expect("Failed to increment refcount");
|
||||
|
||||
let prefix_len = blocks.len() * self.block_size as usize;
|
||||
let suffix_len = tokens - prefix_len as u32;
|
||||
|
||||
let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size;
|
||||
|
||||
tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}");
|
||||
|
||||
match self.alloc_or_reclaim(suffix_blocks as usize) {
|
||||
Some(suffix_blocks) => blocks.extend(suffix_blocks),
|
||||
None => {
|
||||
tracing::debug!("Cannot allocate {:?}", self.cache_blocks);
|
||||
tracing::debug!("Found {prefix_len} prefix tokens need {suffix_blocks} suffix blocks for {tokens} tokens");
|
||||
tracing::debug!("Block size {}", self.block_size);
|
||||
self.cache_blocks
|
||||
.decref(prefix_node)
|
||||
.expect("Failed to decrement refcount");
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
// 1:1 mapping of blocks and slots.
|
||||
let slots = if self.block_size == 1 {
|
||||
blocks.clone()
|
||||
} else {
|
||||
let mut slots = Vec::with_capacity(blocks.len() * self.block_size as usize);
|
||||
'slots: for block_id in &blocks {
|
||||
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
|
||||
slots.push(s);
|
||||
if slots.len() as u32 == tokens {
|
||||
break 'slots;
|
||||
}
|
||||
}
|
||||
}
|
||||
slots
|
||||
};
|
||||
|
||||
let allocation = RadixAllocation {
|
||||
prefix_node,
|
||||
cached_prefix_len: prefix_len,
|
||||
prefill_tokens: prefill_tokens.clone(),
|
||||
};
|
||||
|
||||
self.allocation_id += 1;
|
||||
self.allocations.insert(self.allocation_id, allocation);
|
||||
|
||||
Some(BlockAllocation {
|
||||
allocation_id: self.allocation_id,
|
||||
block_allocator: None,
|
||||
blocks,
|
||||
slots,
|
||||
prefix_len: prefix_len as u32,
|
||||
})
|
||||
}
|
||||
|
||||
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) {
|
||||
let allocation = match self.allocations.remove(&allocation_id) {
|
||||
Some(allocation) => allocation,
|
||||
None => unreachable!("Tried to free an unknown allocation."),
|
||||
};
|
||||
|
||||
self.cache_blocks
|
||||
.decref(allocation.prefix_node)
|
||||
.expect("Failed to decrement refcount");
|
||||
|
||||
if let Some(prefill_tokens) = allocation.prefill_tokens {
|
||||
let prefill_tokens = prefill_tokens.as_slice();
|
||||
|
||||
// If there are prefill tokens that did not come from the cache,
|
||||
// add them to the cache.
|
||||
if prefill_tokens.len() > allocation.cached_prefix_len {
|
||||
let aligned =
|
||||
(prefill_tokens.len() / self.block_size as usize) * self.block_size as usize;
|
||||
if aligned > 0 {
|
||||
let prefix_len = self
|
||||
.cache_blocks
|
||||
.insert(
|
||||
&prefill_tokens[..aligned],
|
||||
&blocks[..aligned / self.block_size as usize],
|
||||
)
|
||||
// Unwrap, failing is a programming error.
|
||||
.expect("Failed to store prefill tokens");
|
||||
// We can have a prefill with the following structure:
|
||||
//
|
||||
// |---| From the prefix cache.
|
||||
// A B C D E F G
|
||||
//|--------| Found in the trie during insertion.
|
||||
//
|
||||
// This means that while processing this request there was a
|
||||
// partially overlapping request that had A..=E in its
|
||||
// prefill. In this case we need to free the blocks D E.
|
||||
if prefix_len > allocation.cached_prefix_len {
|
||||
self.free_blocks.extend(
|
||||
&blocks[allocation.cached_prefix_len / self.block_size as usize
|
||||
..prefix_len / self.block_size as usize],
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Free non-prefill blocks.
|
||||
self.free_blocks
|
||||
.extend(&blocks[prefill_tokens.len() / self.block_size as usize..]);
|
||||
} else {
|
||||
self.free_blocks.extend(blocks);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct RadixAllocation {
|
||||
prefix_node: NodeId,
|
||||
cached_prefix_len: usize,
|
||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
}
|
||||
|
||||
// Radix trie that is heavily inspired by radix attention from sglang.
|
||||
//
|
||||
// The trie is optimized for prefix caching:
|
||||
//
|
||||
// - A normal radix trie stores discrete values. In this radix trie,
|
||||
// inserting *abc* with value *xyz* will also enable lookup for
|
||||
// *a* (*x*) and *ab* (*xy*).
|
||||
// - As a result, every value is required to have the same length as
|
||||
// the key.
|
||||
// - We store additional information in each node, such as last access
|
||||
// time and a reference count.
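//
// Illustrative sketch (not part of the original change), assuming block_size = 1:
//
// let mut trie = RadixTrie::new(1);
// trie.insert(&[1, 2, 3], &[10, 11, 12]).unwrap(); // one block per token
// let mut blocks = Vec::new();
// let _node = trie.find(&[1, 2], &mut blocks); // longest-prefix lookup
// assert_eq!(blocks, vec![10, 11]); // blocks covering the shared prefix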
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum TrieError {
|
||||
InvalidNodeId,
|
||||
RefCountUnderflow,
|
||||
}
|
||||
|
||||
pub type NodeId = DefaultKey;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RadixTrie {
|
||||
/// Identifier of the root node.
|
||||
root: DefaultKey,
|
||||
|
||||
/// Leaf node identifiers ordered by increasing recency.
|
||||
leaves: BTreeSet<(u64, NodeId)>,
|
||||
|
||||
/// All trie nodes.
|
||||
nodes: SlotMap<NodeId, TrieNode>,
|
||||
|
||||
/// Time as a monotonically increasing counter to avoid the system
|
||||
/// call that a real time lookup would require.
|
||||
time: u64,
|
||||
|
||||
/// All blocks need to be aligned with this
|
||||
block_size: usize,
|
||||
}
|
||||
|
||||
impl RadixTrie {
|
||||
/// Construct a new radix trie.
|
||||
pub fn new(block_size: usize) -> Self {
|
||||
let root = TrieNode::new(vec![], vec![], 0, None);
|
||||
let mut nodes = SlotMap::new();
|
||||
let root = nodes.insert(root);
|
||||
RadixTrie {
|
||||
leaves: BTreeSet::new(),
|
||||
nodes,
|
||||
root,
|
||||
time: 0,
|
||||
block_size,
|
||||
}
|
||||
}
|
||||
|
||||
/// Find the prefix of the given tokens.
|
||||
///
|
||||
/// The blocks corresponding to the part of the prefix that could be found
|
||||
/// are written to `blocks`. The number of blocks is in `0..=tokens.len()`.
|
||||
/// Returns the identifier of the trie node that contains the longest
|
||||
/// prefix. The node identifier can be used by callers to e.g. increase its
|
||||
/// reference count.
|
||||
///
|
||||
/// Using this method will update the access time of the traversed nodes.
|
||||
pub fn find(&mut self, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
|
||||
self.time += 1;
|
||||
self.find_(self.root, key, blocks)
|
||||
}
|
||||
|
||||
/// Find worker.
|
||||
fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
|
||||
let node = &self.nodes[node_id];
|
||||
|
||||
if key.len() >= self.block_size {
|
||||
let node_key = hash(&key[..self.block_size]);
|
||||
if let Some(&child_id) = node.children.get(&node_key) {
|
||||
self.update_access_time(child_id);
|
||||
let child = self.nodes.get(child_id).expect("Invalid child identifier");
|
||||
let shared_prefix_len = shared_prefix(&child.key, key, self.block_size);
|
||||
assert_eq!(shared_prefix_len % self.block_size, 0);
|
||||
blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]);
|
||||
|
||||
let key = &key[shared_prefix_len..];
|
||||
if !key.is_empty() {
|
||||
node_id = self.find_(child_id, key, blocks);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
node_id
|
||||
}
|
||||
|
||||
/// Decrease the reference count of a node.
|
||||
pub fn decref(&mut self, node_id: NodeId) -> Result<(), TrieError> {
|
||||
// We don't care about refcounting for root, since it will never
|
||||
// be evicted.
|
||||
if node_id == self.root {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let node = self
|
||||
.nodes
|
||||
.get_mut(node_id)
|
||||
.ok_or(TrieError::InvalidNodeId)?;
|
||||
if node.ref_count == 0 {
|
||||
return Err(TrieError::RefCountUnderflow);
|
||||
}
|
||||
|
||||
node.ref_count -= 1;
|
||||
if node.ref_count == 0 {
|
||||
assert!(
|
||||
node.children.is_empty(),
|
||||
"Nodes with children must have refcount > 0"
|
||||
);
|
||||
|
||||
self.leaves.insert((node.last_accessed, node_id));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Increase the reference count of a node.
|
||||
pub fn incref(&mut self, node_id: NodeId) -> Result<(), TrieError> {
|
||||
if node_id == self.root {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let node = self
|
||||
.nodes
|
||||
.get_mut(node_id)
|
||||
.ok_or(TrieError::InvalidNodeId)?;
|
||||
if node.ref_count == 0 {
|
||||
self.leaves.remove(&(node.last_accessed, node_id));
|
||||
}
|
||||
node.ref_count += 1;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Evict `n_blocks` from the trie.
|
||||
///
|
||||
/// Returns the evicted blocks. When the length is less than `n_blocks`,
|
||||
/// not enough blocks could be evicted.
|
||||
pub fn evict(&mut self, n_blocks: usize) -> Vec<u32> {
|
||||
// NOTE: we don't return Result here. If any of the unwrapping fails,
|
||||
// it's a programming error in the trie implementation, not a user
|
||||
// error caused by e.g. an invalid argument.
|
||||
|
||||
// TODO: add some bookkeeping in the future to check whether we can
|
||||
// evict n_blocks and return `None` if we can't. We are now needlessly
|
||||
// evicting prefixes from the cache in such a case.
|
||||
let mut evicted = Vec::new();
|
||||
tracing::debug!("Evicting in search of {n_blocks}");
|
||||
|
||||
while let Some((last_access, node_id)) = self.leaves.pop_first() {
|
||||
let blocks_needed = n_blocks.saturating_sub(evicted.len());
|
||||
tracing::debug!("Evicting node {node_id:?} ");
|
||||
|
||||
let node = self.nodes.get(node_id).expect("Leave does not exist");
|
||||
assert_eq!(
|
||||
node.ref_count, 0,
|
||||
"Leaf must have refcount of 0, got {}",
|
||||
node.ref_count
|
||||
);
|
||||
|
||||
if blocks_needed >= node.blocks.len() {
|
||||
// We need to evict the whole node if we need more blocks than it has.
|
||||
let node = self.remove_node(node_id);
|
||||
evicted.extend(node.blocks);
|
||||
|
||||
if evicted.len() >= n_blocks {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
// The node has more blocks than needed, so we'll just remove
|
||||
// the required number of blocks and leave the remaining blocks
|
||||
// untouched.
|
||||
let node = self.nodes.get_mut(node_id).expect("Leave does not exist");
|
||||
|
||||
let truncate_blocks = node.blocks.len() - blocks_needed;
|
||||
let truncate_tokens = truncate_blocks * self.block_size;
|
||||
node.key.truncate(truncate_tokens);
|
||||
evicted.extend(node.blocks.split_off(truncate_blocks));
|
||||
self.leaves.insert((last_access, node_id));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
evicted
|
||||
}
|
||||
|
||||
/// Insert a prefill along with its blocks.
|
||||
///
|
||||
/// This method returns the length of the prefix that was already
|
||||
/// in the trie. E.g. if the length is 10, this means that for
|
||||
/// the first 10 elements of the tree **the blocks are not updated**.
|
||||
pub fn insert(&mut self, tokens: &[u32], blocks: &[u32]) -> Result<usize, TrieError> {
|
||||
self.time += 1;
|
||||
let common = self.insert_(self.root, tokens, blocks)?;
|
||||
Ok(common)
|
||||
}
|
||||
|
||||
/// Insertion worker.
|
||||
fn insert_(
|
||||
&mut self,
|
||||
node_id: NodeId,
|
||||
tokens: &[u32],
|
||||
blocks: &[u32],
|
||||
) -> Result<usize, TrieError> {
|
||||
// TODO: in the future we may want to check that the blocks match for
|
||||
// the part of the prefix that is already in the trie to detect
|
||||
// mismatches.
|
||||
|
||||
assert_eq!(tokens.len(), blocks.len() * self.block_size);
|
||||
|
||||
let node_key = hash(&tokens[..self.block_size]);
|
||||
if let Some(&child_id) = self.nodes[node_id].children.get(&node_key) {
|
||||
self.update_access_time(child_id);
|
||||
let child = self
|
||||
.nodes
|
||||
.get_mut(child_id)
|
||||
// Unwrap here, since failure is a bug.
|
||||
.expect("Child node does not exist");
|
||||
let shared_prefix_len = shared_prefix(&child.key, tokens, self.block_size);
|
||||
|
||||
// We are done, the prefix is already in the trie.
|
||||
if shared_prefix_len == tokens.len() || shared_prefix_len == 0 {
|
||||
return Ok(shared_prefix_len);
|
||||
}
|
||||
|
||||
// The node's prefix is a prefix of the insertion prefix.
|
||||
if shared_prefix_len == child.key.len() {
|
||||
return Ok(shared_prefix_len
|
||||
+ self.insert_(
|
||||
child_id,
|
||||
&tokens[shared_prefix_len..],
|
||||
&blocks[shared_prefix_len / self.block_size..],
|
||||
)?);
|
||||
}
|
||||
|
||||
// The node's prefix and the insertion prefix only match partially,
|
||||
// split the node to just contain the matching part. Then insert the
|
||||
// remainder of the prefix into the node again.
|
||||
let child_id = self.split_node(child_id, shared_prefix_len);
|
||||
let key = &tokens[shared_prefix_len..];
|
||||
let blocks = &blocks[shared_prefix_len / self.block_size..];
|
||||
Ok(shared_prefix_len + self.insert_(child_id, key, blocks)?)
|
||||
} else {
|
||||
self.add_node(node_id, tokens, blocks);
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
|
||||
fn split_node(&mut self, node_id: NodeId, prefix_len: usize) -> NodeId {
|
||||
// We have to make the current node a child to ensure that its
|
||||
// properties and node id stay the same.
|
||||
|
||||
// This function unwraps; an invalid node_id is a programming error.
|
||||
|
||||
let node = self
|
||||
.nodes
|
||||
.get_mut(node_id)
|
||||
.expect("Node to-be split does not exist");
|
||||
let mut parent_key = node.key.split_off(prefix_len);
|
||||
let prefix_blocks = prefix_len / self.block_size;
|
||||
let mut parent_blocks = node.blocks.split_off(prefix_blocks);
|
||||
|
||||
// Move first part of the prefix to the parent. We swap to avoid
|
||||
// an allocation + copy for both splits of the key/blocks.
|
||||
std::mem::swap(&mut node.key, &mut parent_key);
|
||||
std::mem::swap(&mut node.blocks, &mut parent_blocks);
|
||||
|
||||
let node_key = hash(&node.key[..self.block_size]);
|
||||
|
||||
let grandparent_id = node.parent.expect("Node does not have a parent");
|
||||
let parent_id = self.add_node(grandparent_id, parent_key, parent_blocks);
|
||||
self.add_node_to_parent(parent_id, node_key, node_id);
|
||||
|
||||
// Reborrow to make the borrow checker happy.
|
||||
let node = self
|
||||
.nodes
|
||||
.get_mut(node_id)
|
||||
.expect("Node to-be split does not exist");
|
||||
node.parent = Some(parent_id);
|
||||
|
||||
parent_id
|
||||
}
|
||||
|
||||
/// Create a node and add it to the parent.
|
||||
fn add_node(
|
||||
&mut self,
|
||||
parent_id: NodeId,
|
||||
key: impl Into<Vec<u32>>,
|
||||
blocks: impl Into<Vec<u32>>,
|
||||
) -> NodeId {
|
||||
let key = key.into();
|
||||
let blocks = blocks.into();
|
||||
let first = hash(&key[..self.block_size]);
|
||||
|
||||
let child = TrieNode::new(key, blocks, self.time, Some(parent_id));
|
||||
let child_id = self.nodes.insert(child);
|
||||
|
||||
self.add_node_to_parent(parent_id, first, child_id);
|
||||
self.leaves.insert((self.time, child_id));
|
||||
|
||||
child_id
|
||||
}
|
||||
|
||||
/// Add a node to the parent.
|
||||
fn add_node_to_parent(&mut self, parent_id: NodeId, hash: u64, child_id: NodeId) {
|
||||
// Unwrap here, passing in an unknown id is a programming error.
|
||||
let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
|
||||
if parent.children.insert(hash, child_id).is_none() {
|
||||
// Only increase reference count if child does not replace another child.
|
||||
self.incref(parent_id)
|
||||
.expect("Failed to increase parent refcount");
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a node from the trie.
|
||||
fn remove_node(&mut self, node_id: NodeId) -> TrieNode {
|
||||
// Unwrap here, passing in an unknown id is a programming error.
|
||||
let node = self.nodes.remove(node_id).expect("Unknown node");
|
||||
assert!(
|
||||
node.children.is_empty(),
|
||||
"Tried to remove a node with {} children",
|
||||
node.children.len()
|
||||
);
|
||||
let parent_id = node.parent.expect("Attempted to remove root node");
|
||||
let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
|
||||
|
||||
let node_key = hash(&node.key[..self.block_size]);
|
||||
parent.children.remove(&node_key);
|
||||
self.decref(parent_id)
|
||||
.expect("Failed to decrease parent refcount");
|
||||
node
|
||||
}
|
||||
|
||||
fn update_access_time(&mut self, node_id: NodeId) {
|
||||
// Unwrap here, passing in an unknown id is a programming error.
|
||||
let node = self.nodes.get_mut(node_id).expect("Unknown node");
|
||||
|
||||
// Update the ordered leaves set if the node is a leaf.
|
||||
if self.leaves.remove(&(node.last_accessed, node_id)) {
|
||||
self.leaves.insert((self.time, node_id));
|
||||
}
|
||||
|
||||
node.last_accessed = self.time;
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[doc(hidden)]
|
||||
/// Print debugging output for the trie.
|
||||
///
|
||||
/// In contrast to `Debug`, the output is nicely formatted.
|
||||
pub fn print_debug(&self) {
|
||||
self.print_debug_(self.root, 0);
|
||||
}
|
||||
|
||||
fn print_debug_(&self, node_id: NodeId, indent: usize) {
|
||||
let node = &self.nodes[node_id];
|
||||
eprintln!(
|
||||
"{}{:?}, key: {:?}, blocks: {:?}, ref_count: {}, last_accessed: {}, parent: {:?}, children: {:?}",
|
||||
" ".repeat(indent),
|
||||
node_id,
|
||||
node.key,
|
||||
node.blocks,
|
||||
node.ref_count,
|
||||
node.last_accessed,
|
||||
node.parent,
|
||||
node.children
|
||||
);
|
||||
for child_id in self.nodes[node_id].children.values() {
|
||||
self.print_debug_(*child_id, indent + 2);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn root_id(&self) -> DefaultKey {
|
||||
self.root
|
||||
}
|
||||
}
|
||||
|
||||
/// Trie node.
|
||||
#[derive(Debug)]
|
||||
struct TrieNode {
|
||||
blocks: Vec<u32>,
|
||||
children: HashMap<u64, NodeId>,
|
||||
key: Vec<u32>,
|
||||
last_accessed: u64,
|
||||
parent: Option<NodeId>,
|
||||
ref_count: usize,
|
||||
}
|
||||
|
||||
impl TrieNode {
|
||||
fn new(key: Vec<u32>, blocks: Vec<u32>, last_accessed: u64, parent: Option<NodeId>) -> Self {
|
||||
TrieNode {
|
||||
children: HashMap::new(),
|
||||
key,
|
||||
blocks,
|
||||
last_accessed,
|
||||
parent,
|
||||
ref_count: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn shared_prefix(left: &[u32], right: &[u32], block_size: usize) -> usize {
|
||||
let full = left.iter().zip(right).take_while(|(a, b)| a == b).count();
|
||||
// NOTE: this is the case because the child node was chosen based on
|
||||
// matching the hash of the first block of the key/prefix.
|
||||
assert!(full > 0, "Prefixes must at least share 1 token");
|
||||
(full / block_size) * block_size
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn allocator_block_size() {
|
||||
let mut cache = RadixAllocator::new(2, 12, None);
|
||||
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
|
||||
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
|
||||
assert_eq!(allocation.prefix_len, 0);
|
||||
cache.free(allocation.blocks.clone(), allocation.allocation_id);
|
||||
|
||||
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
|
||||
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
|
||||
assert_eq!(allocation.prefix_len, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allocator_block_size_non_aligned() {
|
||||
let mut cache = RadixAllocator::new(2, 12, None);
|
||||
let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
|
||||
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
|
||||
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
|
||||
assert_eq!(allocation.prefix_len, 0);
|
||||
cache.free(allocation.blocks.clone(), allocation.allocation_id);
|
||||
|
||||
let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
|
||||
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
|
||||
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
|
||||
assert_eq!(allocation.prefix_len, 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allocator_reuses_prefixes() {
|
||||
let mut cache = RadixAllocator::new(1, 12, None);
|
||||
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
|
||||
assert_eq!(allocation.blocks, allocation.slots);
|
||||
assert_eq!(allocation.prefix_len, 0);
|
||||
cache.free(allocation.blocks.clone(), allocation.allocation_id);
|
||||
|
||||
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
|
||||
assert_eq!(allocation.prefix_len, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allocator_collects_older_prefixes_first() {
|
||||
let mut cache = RadixAllocator::new(1, 7, None);
|
||||
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
assert_eq!(allocation1.blocks, vec![3, 4, 5, 6]);
|
||||
assert_eq!(allocation1.prefix_len, 0);
|
||||
|
||||
let allocation2 = cache.allocate(2, Some(Arc::new(vec![4, 5]))).unwrap();
|
||||
assert_eq!(allocation2.blocks, vec![1, 2]);
|
||||
assert_eq!(allocation2.prefix_len, 0);
|
||||
|
||||
cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
|
||||
cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
|
||||
|
||||
// We should get the blocks of the first allocation, since its prefix is the oldest and is evicted first.
|
||||
let allocation3 = cache.allocate(4, Some(Arc::new(vec![6, 7, 8, 9]))).unwrap();
|
||||
assert_eq!(allocation3.blocks, vec![3, 4, 5, 6]);
|
||||
assert_eq!(allocation3.prefix_len, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allocator_frees_fully_overlapping_prefills() {
|
||||
let mut cache = RadixAllocator::new(1, 10, None);
|
||||
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
|
||||
cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
|
||||
cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
|
||||
|
||||
let allocation3 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
assert_eq!(allocation3.prefix_len, 4);
|
||||
|
||||
// 10 blocks, of which 1 reserved for health checks, 4 for the cached blocks.
|
||||
assert_eq!(cache.free_blocks.len(), 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allocator_frees_partially_overlapping_prefills() {
|
||||
let mut cache = RadixAllocator::new(1, 20, None);
|
||||
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1]))).unwrap();
|
||||
assert_eq!(allocation1.blocks, vec![16, 17, 18, 19]);
|
||||
assert_eq!(allocation1.prefix_len, 0);
|
||||
|
||||
cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
|
||||
|
||||
let allocation2 = cache
|
||||
.allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 4, 5])))
|
||||
.unwrap();
|
||||
assert_eq!(allocation2.blocks, vec![16, 17, 12, 13, 14, 15, 18, 19]);
|
||||
assert_eq!(allocation2.prefix_len, 2);
|
||||
|
||||
let allocation3 = cache
|
||||
.allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 6, 7])))
|
||||
.unwrap();
|
||||
assert_eq!(allocation3.blocks, vec![16, 17, 6, 7, 8, 9, 10, 11]);
|
||||
assert_eq!(allocation3.prefix_len, 2);
|
||||
|
||||
cache.free(allocation3.blocks.clone(), allocation3.allocation_id);
|
||||
cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
|
||||
|
||||
// 20 blocks, of which 1 reserved for health checks, 6 for allocation3, 2 for allocation2.
|
||||
assert_eq!(cache.free_blocks.len(), 11);
|
||||
|
||||
let allocation4 = cache
|
||||
.allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 4, 5])))
|
||||
.unwrap();
|
||||
assert_eq!(allocation4.blocks, vec![16, 17, 6, 7, 14, 15]);
|
||||
assert_eq!(allocation4.prefix_len, 6);
|
||||
assert_eq!(cache.free_blocks.len(), 11);
|
||||
|
||||
let allocation5 = cache
|
||||
.allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 6, 7])))
|
||||
.unwrap();
|
||||
assert_eq!(allocation5.blocks, vec![16, 17, 6, 7, 8, 9]);
|
||||
assert_eq!(allocation5.prefix_len, 6);
|
||||
assert_eq!(cache.free_blocks.len(), 11);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn trie_insertions_have_correct_prefix_len() {
|
||||
let mut trie = RadixTrie::new(1);
|
||||
|
||||
assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 0);
|
||||
|
||||
// Already exists.
|
||||
assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 3);
|
||||
|
||||
// Completely new at root-level
|
||||
assert_eq!(trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(), 0);
|
||||
|
||||
// Contains full prefix, but longer.
|
||||
assert_eq!(trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(), 3);
|
||||
|
||||
// Shares partial prefix, we need a split.
|
||||
assert_eq!(
|
||||
trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
|
||||
.unwrap(),
|
||||
4
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn trie_insertions_block_size() {
|
||||
let mut trie = RadixTrie::new(2);
|
||||
|
||||
assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 0);
|
||||
|
||||
// Already exists.
|
||||
// But needs to be block_size aligned
|
||||
assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 4);
|
||||
|
||||
// Completely new at root-level
|
||||
assert_eq!(trie.insert(&[1, 2, 3, 4], &[1, 2]).unwrap(), 0);
|
||||
|
||||
// Contains full prefix, but longer.
|
||||
assert_eq!(trie.insert(&[0, 1, 2, 3, 4, 5], &[0, 1, 2]).unwrap(), 4);
|
||||
|
||||
// Shares partial prefix, we need a split.
|
||||
assert_eq!(
|
||||
trie.insert(&[0, 1, 3, 4, 5, 6, 7, 8], &[0, 1, 2, 3])
|
||||
.unwrap(),
|
||||
2
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn trie_get_returns_correct_blocks() {
|
||||
let mut trie = RadixTrie::new(1);
|
||||
trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
|
||||
trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();
|
||||
trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();
|
||||
trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
|
||||
.unwrap();
|
||||
|
||||
let mut blocks = Vec::new();
|
||||
trie.find(&[0], &mut blocks);
|
||||
assert_eq!(blocks, vec![0]);
|
||||
|
||||
blocks.clear();
|
||||
trie.find(&[0, 1, 2], &mut blocks);
|
||||
assert_eq!(blocks, vec![0, 1, 2]);
|
||||
|
||||
blocks.clear();
|
||||
trie.find(&[1, 2, 3], &mut blocks);
|
||||
assert_eq!(blocks, vec![1, 2, 3]);
|
||||
|
||||
blocks.clear();
|
||||
trie.find(&[0, 1, 2, 3], &mut blocks);
|
||||
assert_eq!(blocks, vec![0, 1, 2, 3]);
|
||||
|
||||
blocks.clear();
|
||||
trie.find(&[0, 1, 2, 3, 4], &mut blocks);
|
||||
assert_eq!(blocks, vec![0, 1, 2, 3, 4]);
|
||||
|
||||
blocks.clear();
|
||||
trie.find(&[0, 1, 2, 3, 5], &mut blocks);
|
||||
assert_eq!(blocks, vec![0, 1, 2, 3, 5]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn trie_evict_removes_correct_blocks() {
|
||||
let mut trie = RadixTrie::new(1);
|
||||
trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
|
||||
trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
|
||||
.unwrap();
|
||||
trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();
|
||||
trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();
|
||||
|
||||
let mut blocks = Vec::new();
|
||||
|
||||
// Remove fewer blocks than the leaf holds.
|
||||
assert_eq!(trie.evict(1), vec![7]);
|
||||
trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks);
|
||||
assert_eq!(blocks, vec![0, 1, 2, 3, 5, 6]);
|
||||
|
||||
// Refresh other leaf.
|
||||
trie.find(&[0, 1, 2, 3, 4], &mut blocks);
|
||||
trie.find(&[1, 2, 3], &mut blocks);
|
||||
|
||||
// Remove exactly the leaf's blocks.
|
||||
assert_eq!(trie.evict(2), vec![5, 6]);
|
||||
blocks.clear();
|
||||
trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks);
|
||||
assert_eq!(blocks, vec![0, 1, 2, 3]);
|
||||
|
||||
trie.find(&[1, 2, 3], &mut blocks);
|
||||
|
||||
// Remove more blocks than the leaf holds.
|
||||
assert_eq!(trie.evict(3), vec![4, 3, 2]);
|
||||
blocks.clear();
|
||||
trie.find(&[0, 1, 2, 3, 4], &mut blocks);
|
||||
assert_eq!(blocks, vec![0, 1]);
|
||||
|
||||
// Clear out the whole trie.
|
||||
assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]);
|
||||
}
|
||||
}
|
@ -16,16 +16,15 @@ path = "src/main.rs"
|
||||
[dependencies]
|
||||
average = "0.14"
|
||||
clap = { version = "4.4.5", features = ["derive", "env"] }
|
||||
crossterm = "0.27"
|
||||
float-ord = "0.3.2"
|
||||
serde = {version = "1.0.188", features = ["derive"]}
|
||||
serde_json = "1.0"
|
||||
tabled = "0.14.0"
|
||||
text-generation-client = { path = "../router/client" }
|
||||
text-generation-client = { path = "../backends/client" }
|
||||
thiserror = "1.0.48"
|
||||
tokenizers = { workspace = true }
|
||||
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] }
|
||||
tui = {package = "ratatui", version = "0.23", default-features = false, features = ["crossterm"]}
|
||||
ratatui = "0.28.1"
|
||||
tracing = "0.1.37"
|
||||
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|
||||
hf-hub = { workspace = true }
|
||||
|
@ -7,7 +7,7 @@
|
||||
</div>
|
||||
|
||||
A lightweight benchmarking tool inspired by [oha](https://github.com/hatoo/oha)
|
||||
and powered by [tui](https://github.com/tui-rs-revival/ratatui).
|
||||
and powered by [Ratatui](https://github.com/ratatui/ratatui).
|
||||
|
||||
## Install
|
||||
|
||||
|
@ -1,16 +1,15 @@
|
||||
/// Inspired by https://github.com/hatoo/oha/blob/bb989ea3cd77727e7743e7daa60a19894bb5e901/src/monitor.rs
|
||||
use crate::generation::{Decode, Message, Prefill};
|
||||
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
|
||||
use text_generation_client::ClientError;
|
||||
use tokio::sync::mpsc;
|
||||
use tui::backend::Backend;
|
||||
use tui::layout::{Alignment, Constraint, Direction, Layout};
|
||||
use tui::style::{Color, Modifier, Style};
|
||||
use tui::text::{Line, Span};
|
||||
use tui::widgets::{
|
||||
use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
|
||||
use ratatui::layout::{Alignment, Constraint, Direction, Layout};
|
||||
use ratatui::style::{Color, Modifier, Style};
|
||||
use ratatui::text::{Line, Span};
|
||||
use ratatui::widgets::{
|
||||
Axis, BarChart, Block, Borders, Chart, Dataset, Gauge, GraphType, Paragraph, Tabs,
|
||||
};
|
||||
use tui::{symbols, Frame};
|
||||
use ratatui::{symbols, Frame};
|
||||
use text_generation_client::ClientError;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
/// TUI powered App
|
||||
pub(crate) struct App {
|
||||
@ -153,7 +152,7 @@ impl App {
|
||||
}
|
||||
|
||||
/// Render frame
|
||||
pub fn render<B: Backend>(&mut self, f: &mut Frame<'_, B>) {
|
||||
pub fn render(&mut self, f: &mut Frame) {
|
||||
let batch_progress =
|
||||
(self.completed_batch as f64 / self.data.batch_size.len() as f64).clamp(0.0, 1.0);
|
||||
let run_progress =
|
||||
@ -172,7 +171,7 @@ impl App {
|
||||
]
|
||||
.as_ref(),
|
||||
)
|
||||
.split(f.size());
|
||||
.split(f.area());
|
||||
|
||||
// Top row horizontal layout
|
||||
let top = Layout::default()
|
||||
@ -239,7 +238,7 @@ impl App {
|
||||
f.render_widget(helper, row5[0]);
|
||||
|
||||
// Batch tabs
|
||||
let titles = self
|
||||
let titles: Vec<Line> = self
|
||||
.data
|
||||
.batch_size
|
||||
.iter()
|
||||
@ -497,7 +496,7 @@ fn statis_spans<'a>(data: &[f64], unit: &'static str) -> Vec<Line<'a>> {
|
||||
"Lowest: {:.2} {unit}",
|
||||
data.iter()
|
||||
.min_by(|a, b| a.total_cmp(b))
|
||||
.unwrap_or(&std::f64::NAN)
|
||||
.unwrap_or(&f64::NAN)
|
||||
),
|
||||
Style::default().fg(Color::Reset),
|
||||
)]),
|
||||
@ -506,7 +505,7 @@ fn statis_spans<'a>(data: &[f64], unit: &'static str) -> Vec<Line<'a>> {
|
||||
"Highest: {:.2} {unit}",
|
||||
data.iter()
|
||||
.max_by(|a, b| a.total_cmp(b))
|
||||
.unwrap_or(&std::f64::NAN)
|
||||
.unwrap_or(&f64::NAN)
|
||||
),
|
||||
Style::default().fg(Color::Reset),
|
||||
)]),
|
||||
@ -555,17 +554,17 @@ fn latency_throughput_chart<'a>(
|
||||
let min_latency: f64 = *latency_iter
|
||||
.clone()
|
||||
.min_by(|a, b| a.total_cmp(b))
|
||||
.unwrap_or(&std::f64::NAN);
|
||||
.unwrap_or(&f64::NAN);
|
||||
let max_latency: f64 = *latency_iter
|
||||
.max_by(|a, b| a.total_cmp(b))
|
||||
.unwrap_or(&std::f64::NAN);
|
||||
.unwrap_or(&f64::NAN);
|
||||
let min_throughput: f64 = *throughput_iter
|
||||
.clone()
|
||||
.min_by(|a, b| a.total_cmp(b))
|
||||
.unwrap_or(&std::f64::NAN);
|
||||
.unwrap_or(&f64::NAN);
|
||||
let max_throughput: f64 = *throughput_iter
|
||||
.max_by(|a, b| a.total_cmp(b))
|
||||
.unwrap_or(&std::f64::NAN);
|
||||
.unwrap_or(&f64::NAN);
|
||||
|
||||
// Char min max values
|
||||
let min_x = if zoom {
|
||||
|
@ -1,5 +1,5 @@
|
||||
/// Inspired by https://github.com/orhun/rust-tui-template/blob/472aa515119d4c94903eac12d9784417281dc7f5/src/event.rs
|
||||
use crossterm::event;
|
||||
use ratatui::crossterm::event;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::{broadcast, mpsc};
|
||||
|
||||
|
@ -1,8 +1,9 @@
|
||||
use std::time::{Duration, Instant};
|
||||
use text_generation_client::{
|
||||
Batch, CachedBatch, ClientError, NextTokenChooserParameters, Request, ShardedClient,
|
||||
use text_generation_client::v3::{
|
||||
Batch, CachedBatch, NextTokenChooserParameters, Request, ShardedClient,
|
||||
StoppingCriteriaParameters,
|
||||
};
|
||||
use text_generation_client::{Chunk, ClientError, Input};
|
||||
use tokenizers::{Tokenizer, TruncationDirection};
|
||||
use tokio::sync::{broadcast, mpsc};
|
||||
|
||||
@ -142,8 +143,12 @@ async fn prefill(
|
||||
.map(|id| Request {
|
||||
id: id.into(),
|
||||
prefill_logprobs: false,
|
||||
input_chunks: Some(Input {
|
||||
chunks: vec![Chunk::Text(sequence.clone()).into()],
|
||||
}),
|
||||
inputs: sequence.clone(),
|
||||
truncate: sequence_length,
|
||||
add_special_tokens: true,
|
||||
parameters: Some(parameters.clone()),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||
max_new_tokens: decode_length,
|
||||
@ -151,6 +156,10 @@ async fn prefill(
|
||||
ignore_eos_token: true, // Will not stop even if a eos token is generated
|
||||
}),
|
||||
top_n_tokens: top_n_tokens.unwrap_or(0),
|
||||
blocks: vec![],
|
||||
slots: vec![],
|
||||
prefix_len: 0,
|
||||
adapter_id: None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
@ -159,15 +168,13 @@ async fn prefill(
|
||||
requests,
|
||||
size: batch_size,
|
||||
max_tokens: batch_size * (sequence_length + decode_length),
|
||||
max_blocks: 0,
|
||||
};
|
||||
|
||||
// Run prefill
|
||||
let start_time = Instant::now();
|
||||
|
||||
let (_, decode_batch, _) = client.prefill(batch.clone()).await?;
|
||||
|
||||
let (_, decode_batch, _) = client.decode(vec![decode_batch.clone().unwrap()]).await?;
|
||||
|
||||
// Get latency
|
||||
let latency = start_time.elapsed();
|
||||
|
||||
@ -183,12 +190,11 @@ async fn prefill(
|
||||
};
|
||||
|
||||
Ok((step, decode_batch))
|
||||
|
||||
}
|
||||
|
||||
/// Run a full decode
|
||||
async fn decode(batch: CachedBatch, client: &mut ShardedClient) -> Result<Decode, ClientError> {
|
||||
let mut decode_length = 1; // 1 decode step was already scheduled in prefill with speculative scheduling
|
||||
let mut decode_length = 0;
|
||||
let batch_size = batch.size;
|
||||
|
||||
let start_time = Instant::now();
|
||||
|
@ -6,13 +6,13 @@ mod utils;
|
||||
|
||||
use crate::app::App;
|
||||
use crate::event::Event;
|
||||
use crossterm::ExecutableCommand;
|
||||
use ratatui::backend::CrosstermBackend;
|
||||
use ratatui::crossterm::ExecutableCommand;
|
||||
use ratatui::Terminal;
|
||||
use std::io;
|
||||
use text_generation_client::{GrammarType, NextTokenChooserParameters, ShardedClient};
|
||||
use text_generation_client::v3::{GrammarType, NextTokenChooserParameters, ShardedClient};
|
||||
use tokenizers::Tokenizer;
|
||||
use tokio::sync::{broadcast, mpsc};
|
||||
use tui::backend::CrosstermBackend;
|
||||
use tui::Terminal;
|
||||
|
||||
/// Run benchmarking app
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
@ -50,9 +50,9 @@ pub async fn run(
|
||||
};
|
||||
|
||||
// Initialize terminal properties
|
||||
crossterm::terminal::enable_raw_mode()?;
|
||||
io::stdout().execute(crossterm::terminal::EnterAlternateScreen)?;
|
||||
io::stdout().execute(crossterm::cursor::Hide)?;
|
||||
ratatui::crossterm::terminal::enable_raw_mode()?;
|
||||
io::stdout().execute(ratatui::crossterm::terminal::EnterAlternateScreen)?;
|
||||
io::stdout().execute(ratatui::crossterm::cursor::Hide)?;
|
||||
|
||||
// Initialize terminal
|
||||
let mut terminal = {
|
||||
@ -128,9 +128,9 @@ pub async fn run(
|
||||
let _ = shutdown_guard_receiver.recv().await;
|
||||
|
||||
// Revert terminal to original view
|
||||
io::stdout().execute(crossterm::terminal::LeaveAlternateScreen)?;
|
||||
crossterm::terminal::disable_raw_mode()?;
|
||||
io::stdout().execute(crossterm::cursor::Show)?;
|
||||
io::stdout().execute(ratatui::crossterm::terminal::LeaveAlternateScreen)?;
|
||||
ratatui::crossterm::terminal::disable_raw_mode()?;
|
||||
io::stdout().execute(ratatui::crossterm::cursor::Show)?;
|
||||
|
||||
let parameters_table = table::parameters_table(
|
||||
tokenizer_name,
|
||||
|
@ -4,7 +4,7 @@
|
||||
/// and: https://github.com/orhun/rust-tui-template
|
||||
use clap::Parser;
|
||||
use std::path::Path;
|
||||
use text_generation_client::ShardedClient;
|
||||
use text_generation_client::v3::ShardedClient;
|
||||
use tokenizers::{FromPretrainedParameters, Tokenizer};
|
||||
use tracing_subscriber::layer::SubscriberExt;
|
||||
use tracing_subscriber::util::SubscriberInitExt;
|
||||
@ -51,7 +51,7 @@ struct Args {
|
||||
runs: usize,
|
||||
|
||||
/// Number of warmup cycles
|
||||
#[clap(default_value = "3", short, long, env)]
|
||||
#[clap(default_value = "1", short, long, env)]
|
||||
warmups: usize,
|
||||
|
||||
/// The location of the grpc socket. This benchmark tool bypasses the router
|
||||
@ -155,7 +155,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// We need to download it outside of the Tokio runtime
|
||||
let params = FromPretrainedParameters {
|
||||
revision,
|
||||
token: auth_token,
|
||||
auth_token,
|
||||
..Default::default()
|
||||
};
|
||||
Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).unwrap()
|
||||
|
@ -156,17 +156,17 @@ fn avg_min_max(data: &[f64]) -> (f64, f64, f64) {
|
||||
let min = data
|
||||
.iter()
|
||||
.min_by(|a, b| a.total_cmp(b))
|
||||
.unwrap_or(&std::f64::NAN);
|
||||
.unwrap_or(&f64::NAN);
|
||||
let max = data
|
||||
.iter()
|
||||
.max_by(|a, b| a.total_cmp(b))
|
||||
.unwrap_or(&std::f64::NAN);
|
||||
.unwrap_or(&f64::NAN);
|
||||
(average, *min, *max)
|
||||
}
|
||||
|
||||
fn px(data: &[f64], p: u32) -> f64 {
|
||||
let i = (f64::from(p) / 100.0 * data.len() as f64) as usize;
|
||||
*data.get(i).unwrap_or(&std::f64::NAN)
|
||||
*data.get(i).unwrap_or(&f64::NAN)
|
||||
}
|
||||
|
||||
fn format_value(value: f64, unit: &'static str) -> String {
|
||||
|
@ -37,7 +37,7 @@ pub(crate) fn percentiles(values: &[f64], pecents: &[i32]) -> BTreeMap<String, f
|
||||
.iter()
|
||||
.map(|&p| {
|
||||
let i = (f64::from(p) / 100.0 * values.len() as f64) as usize;
|
||||
(format!("p{p}"), *values.get(i).unwrap_or(&std::f64::NAN))
|
||||
(format!("p{p}"), *values.get(i).unwrap_or(&f64::NAN))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
@ -1,3 +1,6 @@
|
||||
# Legacy warning ⚠️
|
||||
The inference clients from [huggingface_hub](https://huggingface.co/docs/huggingface_hub/guides/inference) are recommended over `text_generation`.
|
||||
|
||||
# Text Generation
|
||||
|
||||
The Hugging Face Text Generation Python library provides a convenient way of interfacing with a
|
||||
|
@ -12,12 +12,22 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
__version__ = "0.6.0"
|
||||
__version__ = "0.7.0"
|
||||
|
||||
DEPRECATION_WARNING = (
|
||||
"`text_generation` clients are deprecated and will be removed in the near future. "
|
||||
"Please use the `InferenceClient` from the `huggingface_hub` package instead."
|
||||
)
|
||||
|
||||
from text_generation.client import Client, AsyncClient
|
||||
from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient
|
||||
from text_generation.client import Client, AsyncClient # noqa E402
|
||||
from text_generation.inference_api import ( # noqa E402
|
||||
InferenceAPIClient,
|
||||
InferenceAPIAsyncClient,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Client",
|
||||
"AsyncClient",
|
||||
"InferenceAPIClient",
|
||||
"InferenceAPIAsyncClient",
|
||||
]
|
||||
|
@ -757,7 +757,12 @@ class AsyncClient:
|
||||
continue
|
||||
payload = byte_payload.decode("utf-8")
|
||||
if payload.startswith("data:"):
|
||||
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
|
||||
payload_data = (
|
||||
payload.lstrip("data:").rstrip("\n").removeprefix(" ")
|
||||
)
|
||||
if payload_data == "[DONE]":
|
||||
break
|
||||
json_payload = json.loads(payload_data)
|
||||
try:
|
||||
response = ChatCompletionChunk(**json_payload)
|
||||
yield response
|
||||
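For illustration only, the server-sent-events handling added above can be sketched as a standalone helper (a hypothetical function mirroring the diff, not part of the client):

```python
import json
from typing import Optional, Union


def parse_sse_payload(byte_payload: bytes) -> Optional[Union[dict, str]]:
    """Decode one SSE line: return a JSON chunk, the "[DONE]" sentinel, or None."""
    payload = byte_payload.decode("utf-8")
    if not payload.startswith("data:"):
        return None  # comment or keep-alive lines are ignored
    # Strip the "data:" prefix, the trailing newline, and an optional leading space.
    payload_data = payload.lstrip("data:").rstrip("\n").removeprefix(" ")
    if payload_data == "[DONE]":
        return "[DONE]"  # end-of-stream sentinel sent by the server
    return json.loads(payload_data)
```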
|
@ -21,7 +21,7 @@ def deployed_models(headers: Optional[Dict] = None) -> List[DeployedModel]:
|
||||
List[DeployedModel]: list of all currently deployed models
|
||||
"""
|
||||
resp = requests.get(
|
||||
f"https://api-inference.huggingface.co/framework/text-generation-inference",
|
||||
"https://api-inference.huggingface.co/framework/text-generation-inference",
|
||||
headers=headers,
|
||||
timeout=5,
|
||||
)
|
||||
|
@ -1,5 +1,5 @@
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic import BaseModel, field_validator, ConfigDict
|
||||
from typing import Optional, List, Union, Any
|
||||
|
||||
from text_generation.errors import ValidationError
|
||||
@ -28,11 +28,17 @@ class ToolCall(BaseModel):
|
||||
function: dict
|
||||
|
||||
|
||||
class Chunk(BaseModel):
|
||||
type: str
|
||||
text: Optional[str] = None
|
||||
image_url: Any = None
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
# Role of the message sender
|
||||
role: str
|
||||
# Content of the message
|
||||
content: Optional[str] = None
|
||||
content: Optional[Union[str, List[Chunk]]] = None
|
||||
# Optional name of the message sender
|
||||
name: Optional[str] = None
|
||||
# Tool calls associated with the chat completion
|
||||
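As a brief aside, the new `Chunk` model and the widened `Message.content` union allow multimodal messages; a hypothetical construction (assuming these models are exposed as `text_generation.types`) might look like:

```python
from text_generation.types import Chunk, Message

# Plain-text content still works as before.
text_message = Message(role="user", content="What is deep learning?")

# Content can now also be a list of typed chunks, e.g. text plus an image URL.
multimodal_message = Message(
    role="user",
    content=[
        Chunk(type="text", text="Describe this image:"),
        # image_url is typed as Any; a dict with a "url" key is one possible shape.
        Chunk(type="image_url", image_url={"url": "https://example.com/cat.png"}),
    ],
)
```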
@ -61,7 +67,7 @@ class ChoiceDeltaToolCall(BaseModel):
|
||||
class ChoiceDelta(BaseModel):
|
||||
role: str
|
||||
content: Optional[str] = None
|
||||
tool_calls: Optional[ChoiceDeltaToolCall]
|
||||
tool_calls: Optional[ChoiceDeltaToolCall] = None
|
||||
|
||||
|
||||
class Choice(BaseModel):
|
||||
@ -168,7 +174,7 @@ class ChatCompletionComplete(BaseModel):
|
||||
# Log probabilities for the chat completion
|
||||
logprobs: Optional[Any]
|
||||
# Reason for completion
|
||||
finish_reason: str
|
||||
finish_reason: Optional[str]
|
||||
# Usage details of the chat completion
|
||||
usage: Optional[Any] = None
|
||||
|
||||
@ -191,6 +197,7 @@ class ChatCompletionChunk(BaseModel):
|
||||
model: str
|
||||
system_fingerprint: str
|
||||
choices: List[Choice]
|
||||
usage: Optional[Any] = None
|
||||
|
||||
|
||||
class Parameters(BaseModel):
|
||||
@ -452,5 +459,9 @@ class StreamResponse(BaseModel):
|
||||
|
||||
# Inference API currently deployed model
|
||||
class DeployedModel(BaseModel):
|
||||
# Disable warning for use of `model_` prefix in `model_id`. Be mindful about adding members
|
||||
# with model_ prefixes, since this disables guardrails for colliding fields:
|
||||
# https://github.com/pydantic/pydantic/issues/9177
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
model_id: str
|
||||
sha: str
|
||||
|
10
docs/README.md
Normal file
10
docs/README.md
Normal file
@ -0,0 +1,10 @@
|
||||
Documentation available at: https://huggingface.co/docs/text-generation-inference
|
||||
|
||||
## Release
|
||||
|
||||
When making a release, please update the latest version in the documentation with:
|
||||
```
|
||||
export OLD_VERSION="2\.0\.3"
|
||||
export NEW_VERSION="2\.0\.4"
|
||||
find . -name '*.md' -exec sed -i -e "s/$OLD_VERSION/$NEW_VERSION/g" {} \;
|
||||
```
|
File diff suppressed because it is too large
Load Diff
@ -11,12 +11,16 @@
|
||||
title: Using TGI with Intel Gaudi
|
||||
- local: installation_inferentia
|
||||
title: Using TGI with AWS Inferentia
|
||||
- local: installation_intel
|
||||
title: Using TGI with Intel GPUs
|
||||
- local: installation
|
||||
title: Installation from source
|
||||
- local: supported_models
|
||||
title: Supported Models and Hardware
|
||||
- local: messages_api
|
||||
title: Messages API
|
||||
- local: architecture
|
||||
title: Internal Architecture
|
||||
- local: usage_statistics
|
||||
title: Usage Statistics
|
||||
title: Getting started
|
||||
- sections:
|
||||
- local: basic_tutorials/consuming_tgi
|
||||
@ -27,8 +31,6 @@
|
||||
title: Serving Private & Gated Models
|
||||
- local: basic_tutorials/using_cli
|
||||
title: Using TGI CLI
|
||||
- local: basic_tutorials/launcher
|
||||
title: All TGI CLI options
|
||||
- local: basic_tutorials/non_core_models
|
||||
title: Non-core Model Serving
|
||||
- local: basic_tutorials/safety
|
||||
@ -42,6 +44,14 @@
|
||||
- local: basic_tutorials/train_medusa
|
||||
title: Train Medusa
|
||||
title: Tutorials
|
||||
- sections:
|
||||
- local: reference/launcher
|
||||
title: All TGI CLI options
|
||||
- local: reference/metrics
|
||||
title: Exported Metrics
|
||||
- local: reference/api_reference
|
||||
title: API Reference
|
||||
title: Reference
|
||||
- sections:
|
||||
- local: conceptual/streaming
|
||||
title: Streaming
|
||||
@ -59,5 +69,10 @@
|
||||
title: Speculation (Medusa, ngram)
|
||||
- local: conceptual/guidance
|
||||
title: How Guidance Works (via outlines)
|
||||
- local: conceptual/lora
|
||||
title: LoRA (Low-Rank Adaptation)
|
||||
- local: conceptual/external
|
||||
title: External Resources
|
||||
|
||||
|
||||
title: Conceptual Guides
|
||||
|
232
docs/source/architecture.md
Normal file
232
docs/source/architecture.md
Normal file
@ -0,0 +1,232 @@
|
||||
# Text Generation Inference Architecture
|
||||
|
||||
This document describes the architecture of Text Generation Inference (TGI) by walking through the call flow between its separate components.
|
||||
|
||||
A high-level architecture diagram can be seen here:
|
||||
|
||||

|
||||
|
||||
As the diagram shows, TGI is made of these separate components:
|
||||
|
||||
- **The router**, also named `webserver`, receives the client requests, buffers them, creates batches, and prepares gRPC calls to a model server.
|
||||
- **The model server**, responsible for receiving the gRPC requests and for running inference on the model. If the model is sharded across multiple accelerators (e.g. multiple GPUs), the model server shards might be synchronized via NCCL or an equivalent library.
|
||||
- **The launcher** is a helper that can launch one or several model servers (if the model is sharded), and it launches the router with compatible arguments.
|
||||
|
||||
The router and the model server can run on two different machines; they do not need to be deployed together.
|
||||
|
||||
## The Router
|
||||
|
||||
This component is a Rust web server binary that accepts HTTP requests using the custom [HTTP API](https://huggingface.github.io/text-generation-inference/), as well as OpenAI's [Messages API](https://huggingface.co/docs/text-generation-inference/messages_api).
|
||||
The router receives the API calls and handles the batching logic (an introduction to batching can be found [here](https://github.com/huggingface/text-generation-inference/blob/main/router/README.md)).
|
||||
It uses different strategies to reduce latency between requests and responses, especially around decoding latency. It relies on queues, schedulers, and block allocators to achieve this and to produce batched requests that it then sends to the model server.
|
||||
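As a quick illustration, here is a minimal sketch of a client request to the router's `/generate` route (a hypothetical example that assumes the router is reachable on its default port 3000 and that the `requests` package is installed; adjust the host, port, and payload to your deployment):

```python
import requests

# Hypothetical endpoint: the router listens on 0.0.0.0:3000 by default.
response = requests.post(
    "http://127.0.0.1:3000/generate",
    json={
        "inputs": "What is Deep Learning?",
        "parameters": {"max_new_tokens": 20},
    },
    headers={"Content-Type": "application/json"},
    timeout=60,
)
# The non-streaming /generate route returns the generated text as JSON.
print(response.json()["generated_text"])
```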
|
||||
### Router's command line
|
||||
|
||||
Parameters are passed to the router on the command line (it does not rely on a configuration file):
|
||||
|
||||
```
|
||||
Text Generation Webserver
|
||||
|
||||
Usage: text-generation-router [OPTIONS]
|
||||
|
||||
Options:
|
||||
--max-concurrent-requests <MAX_CONCURRENT_REQUESTS>
|
||||
[env: MAX_CONCURRENT_REQUESTS=] [default: 128]
|
||||
--max-best-of <MAX_BEST_OF>
|
||||
[env: MAX_BEST_OF=] [default: 2]
|
||||
--max-stop-sequences <MAX_STOP_SEQUENCES>
|
||||
[env: MAX_STOP_SEQUENCES=] [default: 4]
|
||||
--max-top-n-tokens <MAX_TOP_N_TOKENS>
|
||||
[env: MAX_TOP_N_TOKENS=] [default: 5]
|
||||
--max-input-tokens <MAX_INPUT_TOKENS>
|
||||
[env: MAX_INPUT_TOKENS=] [default: 1024]
|
||||
--max-total-tokens <MAX_TOTAL_TOKENS>
|
||||
[env: MAX_TOTAL_TOKENS=] [default: 2048]
|
||||
--waiting-served-ratio <WAITING_SERVED_RATIO>
|
||||
[env: WAITING_SERVED_RATIO=] [default: 1.2]
|
||||
--max-batch-prefill-tokens <MAX_BATCH_PREFILL_TOKENS>
|
||||
[env: MAX_BATCH_PREFILL_TOKENS=] [default: 4096]
|
||||
--max-batch-total-tokens <MAX_BATCH_TOTAL_TOKENS>
|
||||
[env: MAX_BATCH_TOTAL_TOKENS=]
|
||||
--max-waiting-tokens <MAX_WAITING_TOKENS>
|
||||
[env: MAX_WAITING_TOKENS=] [default: 20]
|
||||
--max-batch-size <MAX_BATCH_SIZE>
|
||||
[env: MAX_BATCH_SIZE=]
|
||||
--hostname <HOSTNAME>
|
||||
[env: HOSTNAME=] [default: 0.0.0.0]
|
||||
-p, --port <PORT>
|
||||
[env: PORT=] [default: 3000]
|
||||
--master-shard-uds-path <MASTER_SHARD_UDS_PATH>
|
||||
[env: MASTER_SHARD_UDS_PATH=] [default: /tmp/text-generation-server-0]
|
||||
--tokenizer-name <TOKENIZER_NAME>
|
||||
[env: TOKENIZER_NAME=] [default: bigscience/bloom]
|
||||
--tokenizer-config-path <TOKENIZER_CONFIG_PATH>
|
||||
[env: TOKENIZER_CONFIG_PATH=]
|
||||
--revision <REVISION>
|
||||
[env: REVISION=]
|
||||
--validation-workers <VALIDATION_WORKERS>
|
||||
[env: VALIDATION_WORKERS=] [default: 2]
|
||||
--json-output
|
||||
[env: JSON_OUTPUT=]
|
||||
--otlp-endpoint <OTLP_ENDPOINT>
|
||||
[env: OTLP_ENDPOINT=]
|
||||
--otlp-service-name <OTLP_SERVICE_NAME>
|
||||
[env: OTLP_SERVICE_NAME=]
|
||||
--cors-allow-origin <CORS_ALLOW_ORIGIN>
|
||||
[env: CORS_ALLOW_ORIGIN=]
|
||||
--ngrok
|
||||
[env: NGROK=]
|
||||
--ngrok-authtoken <NGROK_AUTHTOKEN>
|
||||
[env: NGROK_AUTHTOKEN=]
|
||||
--ngrok-edge <NGROK_EDGE>
|
||||
[env: NGROK_EDGE=]
|
||||
--messages-api-enabled
|
||||
[env: MESSAGES_API_ENABLED=]
|
||||
--disable-grammar-support
|
||||
[env: DISABLE_GRAMMAR_SUPPORT=]
|
||||
--max-client-batch-size <MAX_CLIENT_BATCH_SIZE>
|
||||
[env: MAX_CLIENT_BATCH_SIZE=] [default: 4]
|
||||
-h, --help
|
||||
Print help
|
||||
-V, --version
|
||||
Print version
|
||||
```
|
||||
|
||||
## The Model Server
|
||||
|
||||
The model server is a Python server that waits for gRPC requests, loads a given model, performs sharding to provide [tensor parallelism](https://huggingface.co/docs/text-generation-inference/conceptual/tensor_parallelism), and stays alive while waiting for new requests.
|
||||
The model server supports models instantiated using PyTorch and optimized for inference, mainly on CUDA/ROCm.
|
||||
|
||||
### Model Server Variants
|
||||
|
||||
Several variants of the model server exist that are actively supported by Hugging Face:
|
||||
|
||||
- By default, the model server will attempt building [a server optimized for Nvidia GPUs with CUDA](https://huggingface.co/docs/text-generation-inference/installation_nvidia). The code for this version is hosted in the [main TGI repository](https://github.com/huggingface/text-generation-inference).
|
||||
- A [version optimized for AMD with ROCm](https://huggingface.co/docs/text-generation-inference/installation_amd) is hosted in the main TGI repository. Some model features differ.
|
||||
- A [version optimized for Intel GPUs](https://huggingface.co/docs/text-generation-inference/installation_intel) is hosted in the main TGI repository. Some model features differ.
|
||||
- The [version for Intel Gaudi](https://huggingface.co/docs/text-generation-inference/installation_gaudi) is maintained on a forked repository, often resynchronized with the main [TGI repository](https://github.com/huggingface/tgi-gaudi).
|
||||
- A [version for Neuron (AWS Inferentia2)](https://huggingface.co/docs/text-generation-inference/installation_inferentia) is maintained as part of [Optimum Neuron](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference).
|
||||
- A version for Google TPUs is maintained as part of [Optimum TPU](https://github.com/huggingface/optimum-tpu/tree/main/text-generation-inference).
|
||||
|
||||
Not all variants provide the same features, as hardware and middleware capabilities do not provide the same optimizations.
|
||||
|
||||
### Command Line Interface
|
||||
|
||||
The official command line interface (CLI) for the server supports three subcommands, `download-weights`, `quantize` and `serve`:
|
||||
|
||||
- `download-weights` will download weights from the hub and, in some variants, convert them to a format adapted to the given implementation;
|
||||
- `quantize` allows quantizing a model using GPTQ. This feature is not available or supported on all variants;
|
||||
- `serve` will start the server that loads a model (or a model shard), receives gRPC calls from the router, performs inference, and provides a formatted response to the given request.
|
||||
|
||||
Serve's command line parameters on the TGI repository are these:
|
||||
|
||||
```
|
||||
Usage: cli.py serve [OPTIONS] MODEL_ID
|
||||
|
||||
╭─ Arguments ──────────────────────────────────────────────────────────────────────────────────────────────╮
|
||||
│ * model_id TEXT [default: None] [required] │
|
||||
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────╯
|
||||
╭─ Options ────────────────────────────────────────────────────────────────────────────────────────────────╮
|
||||
│ --revision TEXT [default: None] │
|
||||
│ --sharded --no-sharded [default: no-sharded] │
|
||||
│ --quantize [bitsandbytes|bitsandbytes [default: None] │
|
||||
│ -nf4|bitsandbytes-fp4|gptq │
|
||||
│ |awq|eetq|exl2|fp8] │
|
||||
│ --speculate INTEGER [default: None] │
|
||||
│ --dtype [float16|bfloat16] [default: None] │
|
||||
│ --trust-remote-code --no-trust-remote-code [default: │
|
||||
│ no-trust-remote-code] │
|
||||
│ --uds-path PATH [default: │
|
||||
│ /tmp/text-generation-serve… │
|
||||
│ --logger-level TEXT [default: INFO] │
|
||||
│ --json-output --no-json-output [default: no-json-output] │
|
||||
│ --otlp-endpoint TEXT [default: None] │
|
||||
│ --otlp-service-name TEXT [default: │
|
||||
│ text-generation-inference...│
|
||||
│ --help Show this message and exit. │
|
||||
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────╯
|
||||
```
|
||||
|
||||
Note that some variants might support different parameters, and they may accept more options that can be passed in using environment variables.
|
||||
|
||||
## Call Flow
|
||||
|
||||
Once both components are initialized, the weights are downloaded, and the model server is up and running, the router and the model server exchange data and info through gRPC calls. There are currently two supported schemas, [v2](https://github.com/huggingface/text-generation-inference/blob/main/proto/generate.proto) and [v3](https://github.com/huggingface/text-generation-inference/blob/main/proto/v3/generate.proto). These two versions are almost identical, except for:
|
||||
|
||||
- input chunks support, for text and image data,
|
||||
- paged attention support
|
||||
|
||||
Here's a diagram that displays the exchanges that follow the router and model server startup.
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
|
||||
Router->>Model Server: service discovery
|
||||
Model Server-->>Router: urls for other shards
|
||||
|
||||
Router->>Model Server: get model info
|
||||
Model Server-->>Router: shard info
|
||||
|
||||
Router->>Model Server: health check
|
||||
Model Server-->>Router: health OK
|
||||
|
||||
Router->>Model Server: warmup(max_input_tokens, max_batch_prefill_tokens, max_total_tokens, max_batch_size)
|
||||
Model Server-->>Router: warmup result
|
||||
```
|
||||
|
||||
After these are done, the router is ready to receive generate calls from multiple clients. Here's an example.
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant Client 1
|
||||
participant Client 2
|
||||
participant Client 3
|
||||
participant Router
|
||||
participant Model Server
|
||||
|
||||
Client 1->>Router: generate_stream
|
||||
Router->>Model Server: prefill(batch1)
|
||||
Model Server-->>Router: generations, cached_batch1, timings
|
||||
Router-->>Client 1: token 1
|
||||
|
||||
Router->>Model Server: decode(cached_batch1)
|
||||
Model Server-->>Router: generations, cached_batch1, timings
|
||||
Router-->>Client 1: token 2
|
||||
|
||||
Router->>Model Server: decode(cached_batch1)
|
||||
Model Server-->>Router: generations, cached_batch1, timings
|
||||
Router-->>Client 1: token 3
|
||||
|
||||
Client 2->>Router: generate_stream
|
||||
Router->>Model Server: prefill(batch2)
|
||||
Note right of Model Server: This stops previous batch, that is restarted
|
||||
Model Server-->>Router: generations, cached_batch2, timings
|
||||
Router-->>Client 2: token 1'
|
||||
|
||||
Router->>Model Server: decode(cached_batch1, cached_batch2)
|
||||
Model Server-->>Router: generations, cached_batch1, timings
|
||||
Router-->>Client 1: token 4
|
||||
Router-->>Client 2: token 2'
|
||||
|
||||
Note left of Client 1: Client 1 leaves
|
||||
Router->>Model Server: filter_batch(cached_batch1, request_ids_to_keep=batch2)
|
||||
Model Server-->>Router: filtered batch
|
||||
|
||||
Router->>Model Server: decode(cached_batch2)
|
||||
Model Server-->>Router: generations, cached_batch2, timings
|
||||
Router-->>Client 2: token 3'
|
||||
|
||||
Client 3->>Router: generate_stream
|
||||
Note right of Model Server: This stops previous batch, that is restarted
|
||||
Router->>Model Server: prefill(batch3)
|
||||
Note left of Client 1: Client 3 leaves without receiving any batch
|
||||
Router->>Model Server: clear_cache(batch3)
|
||||
Note right of Model Server: This stops previous batch, that is restarted
|
||||
|
||||
Router->>Model Server: decode(cached_batch3)
|
||||
Note right of Model Server: Last token (stopping criteria)
|
||||
Model Server-->>Router: generations, cached_batch3, timings
|
||||
Router-->>Client 2: token 4'
|
||||
|
||||
|
||||
```
|
@ -1,81 +1,125 @@
|
||||
# Consuming Text Generation Inference
|
||||
|
||||
There are many ways you can consume Text Generation Inference server in your applications. After launching, you can use the `/generate` route and make a `POST` request to get results from the server. You can also use the `/generate_stream` route if you want TGI to return a stream of tokens. You can make the requests using the tool of your preference, such as curl, Python or TypeScrpt. For a final end-to-end experience, we also open-sourced ChatUI, a chat interface for open-source models.
|
||||
There are many ways to consume Text Generation Inference (TGI) server in your applications. After launching the server, you can use the [Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) `/v1/chat/completions` route and make a `POST` request to get results from the server. You can also pass `"stream": true` to the call if you want TGI to return a stream of tokens.
|
||||
|
||||
For more information on the API, consult the OpenAPI documentation of `text-generation-inference` available [here](https://huggingface.github.io/text-generation-inference).
|
||||
|
||||
You can make the requests using any tool of your preference, such as curl, Python, or TypeScript. For an end-to-end experience, we've open-sourced [ChatUI](https://github.com/huggingface/chat-ui), a chat interface for open-access models.
|
||||
|
||||
## curl
|
||||
|
||||
After the launch, you can query the model using either the `/generate` or `/generate_stream` routes:
|
||||
After a successful server launch, you can query the model using the `v1/chat/completions` route, to get responses that are compliant with the OpenAI Chat Completion spec:
|
||||
|
||||
```bash
|
||||
curl localhost:8080/v1/chat/completions \
|
||||
-X POST \
|
||||
-d '{
|
||||
"model": "tgi",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is deep learning?"
|
||||
}
|
||||
],
|
||||
"stream": true,
|
||||
"max_tokens": 20
|
||||
}' \
|
||||
-H 'Content-Type: application/json'
|
||||
```
|
||||
|
||||
For non-chat use-cases, you can also use the `/generate` and `/generate_stream` routes.
|
||||
|
||||
```bash
|
||||
curl 127.0.0.1:8080/generate \
|
||||
-X POST \
|
||||
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
|
||||
-d '{
|
||||
"inputs":"What is Deep Learning?",
|
||||
"parameters":{
|
||||
"max_new_tokens":20
|
||||
}
|
||||
}' \
|
||||
-H 'Content-Type: application/json'
|
||||
```
|
||||
|
||||
## Python
|
||||
|
||||
## Inference Client
|
||||
### Inference Client
|
||||
|
||||
[`huggingface-hub`](https://huggingface.co/docs/huggingface_hub/main/en/index) is a Python library to interact with the Hugging Face Hub, including its endpoints. It provides a nice high-level class, [`~huggingface_hub.InferenceClient`], which makes it easy to make calls to a TGI endpoint. `InferenceClient` also takes care of parameter validation and provides a simple to-use interface.
|
||||
You can simply install `huggingface-hub` package with pip.
|
||||
[`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/main/en/index) is a Python library to interact with the Hugging Face Hub, including its endpoints. It provides a high-level class, [`huggingface_hub.InferenceClient`](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient), which makes it easy to make calls to TGI's Messages API. `InferenceClient` also takes care of parameter validation and provides a simple-to-use interface.
|
||||
|
||||
Install `huggingface_hub` package via pip.
|
||||
|
||||
```bash
|
||||
pip install huggingface-hub
|
||||
pip install huggingface_hub
|
||||
```
|
||||
|
||||
Once you start the TGI server, instantiate `InferenceClient()` with the URL to the endpoint serving the model. You can then call `text_generation()` to hit the endpoint through Python.
|
||||
You can now use `InferenceClient` in the exact same way you would use the `OpenAI` client in Python:
|
||||
|
||||
```python
|
||||
from huggingface_hub import InferenceClient
|
||||
|
||||
client = InferenceClient(model="http://127.0.0.1:8080")
|
||||
client.text_generation(prompt="Write a code for snake game")
|
||||
client = InferenceClient(
|
||||
base_url="http://localhost:8080/v1/",
|
||||
)
|
||||
|
||||
output = client.chat.completions.create(
|
||||
model="tgi",
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Count to 10"},
|
||||
],
|
||||
stream=True,
|
||||
max_tokens=1024,
|
||||
)
|
||||
|
||||
for chunk in output:
|
||||
print(chunk.choices[0].delta.content)
|
||||
```
|
||||
|
||||
You can do streaming with `InferenceClient` by passing `stream=True`. Streaming will return tokens as they are being generated in the server. To use streaming, you can do as follows:
|
||||
You can check out more details about OpenAI compatibility [here](https://huggingface.co/docs/huggingface_hub/en/guides/inference#openai-compatibility).
|
||||
|
||||
There is also an async version of the client, `AsyncInferenceClient`, based on `asyncio` and `aiohttp`. You can find the docs for it [here](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.AsyncInferenceClient).
|
||||
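A minimal sketch of the async client (assuming the same local endpoint as above and that `huggingface_hub`'s `AsyncInferenceClient` exposes the `chat_completion` method):

```python
import asyncio
from huggingface_hub import AsyncInferenceClient


async def main():
    # Assumed local TGI endpoint, as in the examples above.
    client = AsyncInferenceClient(base_url="http://localhost:8080/v1/")
    output = await client.chat_completion(
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Count to 10"},
        ],
        max_tokens=1024,
    )
    print(output.choices[0].message.content)


asyncio.run(main())
```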
|
||||
### OpenAI Client
|
||||
|
||||
You can directly use the OpenAI [Python](https://github.com/openai/openai-python) or [JS](https://github.com/openai/openai-node) clients to interact with TGI.
|
||||
|
||||
Install the OpenAI Python package via pip.
|
||||
|
||||
```bash
|
||||
pip install openai
|
||||
```
|
||||
|
||||
```python
|
||||
for token in client.text_generation("How do you make cheese?", max_new_tokens=12, stream=True):
|
||||
print(token)
|
||||
from openai import OpenAI
|
||||
|
||||
# init the client but point it to TGI
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:8080/v1/",
|
||||
api_key="-"
|
||||
)
|
||||
|
||||
chat_completion = client.chat.completions.create(
|
||||
model="tgi",
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant." },
|
||||
{"role": "user", "content": "What is deep learning?"}
|
||||
],
|
||||
stream=True
|
||||
)
|
||||
|
||||
# iterate and print stream
|
||||
for message in chat_completion:
|
||||
print(message)
|
||||
```
|
||||
|
||||
Another parameter you can use with TGI backend is `details`. You can get more details on generation (tokens, probabilities, etc.) by setting `details` to `True`. When it's specified, TGI will return a `TextGenerationResponse` or `TextGenerationStreamResponse` rather than a string or stream.
|
||||
## UI
|
||||
|
||||
```python
|
||||
output = client.text_generation(prompt="Meaning of life is", details=True)
|
||||
print(output)
|
||||
|
||||
# TextGenerationResponse(generated_text=' a complex concept that is not always clear to the individual. It is a concept that is not always', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=20, seed=None, prefill=[], tokens=[Token(id=267, text=' a', logprob=-2.0723474, special=False), Token(id=11235, text=' complex', logprob=-3.1272552, special=False), Token(id=17908, text=' concept', logprob=-1.3632495, special=False),..))
|
||||
```
|
||||
|
||||
You can see how to stream below.
|
||||
|
||||
```python
|
||||
output = client.text_generation(prompt="Meaning of life is", stream=True, details=True)
|
||||
print(next(iter(output)))
|
||||
|
||||
# TextGenerationStreamResponse(token=Token(id=267, text=' a', logprob=-2.0723474, special=False), generated_text=None, details=None)
|
||||
```
|
||||
|
||||
You can check out the details of the function [here](https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation). There is also an async version of the client, `AsyncInferenceClient`, based on `asyncio` and `aiohttp`. You can find docs for it [here](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.AsyncInferenceClient)
|
||||
|
||||
|
||||
## ChatUI
|
||||
|
||||
ChatUI is an open-source interface built for LLM serving. It offers many customization options, such as web search with SERP API and more. ChatUI can automatically consume the TGI server and even provides an option to switch between different TGI endpoints. You can try it out at [Hugging Chat](https://huggingface.co/chat/), or use the [ChatUI Docker Space](https://huggingface.co/new-space?template=huggingchat/chat-ui-template) to deploy your own Hugging Chat to Spaces.
|
||||
|
||||
To serve both ChatUI and TGI in same environment, simply add your own endpoints to the `MODELS` variable in `.env.local` file inside the `chat-ui` repository. Provide the endpoints pointing to where TGI is served.
|
||||
|
||||
```
|
||||
{
|
||||
// rest of the model config here
|
||||
"endpoints": [{"url": "https://HOST:PORT/generate_stream"}]
|
||||
}
|
||||
```
|
||||
|
||||

|
||||
|
||||
## Gradio
|
||||
### Gradio
|
||||
|
||||
Gradio is a Python library that helps you build web applications for your machine learning models with a few lines of code. It has a `ChatInterface` wrapper that helps create neat UIs for chatbots. Let's take a look at how to create a chatbot with streaming mode using TGI and Gradio. Let's install Gradio and Hub Python library first.
|
||||
|
||||
@ -89,19 +133,28 @@ Assume you are serving your model on port 8080, we will query through [Inference
|
||||
import gradio as gr
from huggingface_hub import InferenceClient

client = InferenceClient(model="http://127.0.0.1:8080")
client = InferenceClient(base_url="http://127.0.0.1:8080")

def inference(message, history):
    partial_message = ""
    for token in client.text_generation(message, max_new_tokens=20, stream=True):
        partial_message += token
    output = client.chat.completions.create(
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": message},
        ],
        stream=True,
        max_tokens=1024,
    )

    for chunk in output:
        partial_message += chunk.choices[0].delta.content
        yield partial_message

gr.ChatInterface(
    inference,
    chatbot=gr.Chatbot(height=300),
    textbox=gr.Textbox(placeholder="Chat with me!", container=False, scale=7),
    description="This is the demo for Gradio UI consuming TGI endpoint with LLaMA 7B-Chat model.",
    description="This is the demo for Gradio UI consuming TGI endpoint.",
    title="Gradio 🤝 TGI",
    examples=["Are tomatoes vegetables?"],
    retry_btn="Retry",
@ -110,20 +163,7 @@ gr.ChatInterface(
).queue().launch()
```

The UI looks like this 👇

<div class="flex justify-center">
    <img
        class="block dark:hidden"
        src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/gradio-tgi.png"
    />
    <img
        class="hidden dark:block"
        src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/gradio-tgi-dark.png"
    />
</div>

You can try the demo directly here 👇
You can check out the UI and try the demo directly here 👇

<div class="block dark:hidden">
    <iframe
@ -141,15 +181,19 @@ You can try the demo directly here 👇
</div>

You can disable streaming mode using `return` instead of `yield` in your inference function, like below.

```python
def inference(message, history):
    return client.text_generation(message, max_new_tokens=20)
```

You can read more about how to customize a `ChatInterface` [here](https://www.gradio.app/guides/creating-a-chatbot-fast).

## API documentation
### ChatUI

You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route. The Swagger UI is also available [here](https://huggingface.github.io/text-generation-inference).
[ChatUI](https://github.com/huggingface/chat-ui) is an open-source interface built for consuming LLMs. It offers many customization options, such as web search with SERP API and more. ChatUI can automatically consume the TGI server and even provides an option to switch between different TGI endpoints. You can try it out at [Hugging Chat](https://huggingface.co/chat/), or use the [ChatUI Docker Space](https://huggingface.co/new-space?template=huggingchat/chat-ui-template) to deploy your own Hugging Chat to Spaces.

To serve both ChatUI and TGI in same environment, simply add your own endpoints to the `MODELS` variable in `.env.local` file inside the `chat-ui` repository. Provide the endpoints pointing to where TGI is served.

```
{
// rest of the model config here
"endpoints": [{"url": "https://HOST:PORT/generate_stream"}]
}
```

![ChatUI](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/chatui_screen.png)

@ -19,6 +19,6 @@ docker run --gpus all \
    --shm-size 1g \
    -e HF_TOKEN=$token \
    -p 8080:80 \
    -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0.3 \
    -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0.4 \
    --model-id $model
```

@ -4,7 +4,7 @@ Text Generation Inference improves the model in several aspects.

## Quantization

TGI supports [bits-and-bytes](https://github.com/TimDettmers/bitsandbytes#bitsandbytes), [GPT-Q](https://arxiv.org/abs/2210.17323) and [AWQ](https://arxiv.org/abs/2306.00978) quantization. To speed up inference with quantization, simply set `quantize` flag to `bitsandbytes`, `gptq` or `awq` depending on the quantization technique you wish to use. When using GPT-Q quantization, you need to point to one of the models [here](https://huggingface.co/models?search=gptq) when using AWQ quantization, you need to point to one of the models [here](https://huggingface.co/models?search=awq). To get more information about quantization, please refer to [quantization guide](./../conceptual/quantization)
TGI supports [bits-and-bytes](https://github.com/TimDettmers/bitsandbytes#bitsandbytes), [GPT-Q](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [Marlin](https://github.com/IST-DASLab/marlin), [EETQ](https://github.com/NetEase-FuXi/EETQ), [EXL2](https://github.com/turboderp/exllamav2), and [fp8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) quantization. To speed up inference with quantization, simply set the `quantize` flag to `bitsandbytes`, `gptq`, `awq`, `marlin`, `exl2`, `eetq` or `fp8` depending on the quantization technique you wish to use. When using GPT-Q quantization, you need to point to one of the models [here](https://huggingface.co/models?search=gptq). Similarly, when using AWQ quantization, you need to point to one of [these models](https://huggingface.co/models?search=awq). To get more information about quantization, please refer to the [quantization guide](./../conceptual/quantization).

## RoPE Scaling

@ -1,6 +1,6 @@
# Train Medusa

This tutorial will show you how to train a Medusa model on a dataset of your choice. Please check out the [speculation documentation](../conceptual/speculation.md) for more information on how Medusa works and speculation in general.
This tutorial will show you how to train a Medusa model on a dataset of your choice. Please check out the [speculation documentation](../conceptual/speculation) for more information on how Medusa works and speculation in general.

## What are the benefits of training a Medusa model?

@ -4,7 +4,7 @@ Text Generation Inference (TGI) now supports [JSON and regex grammars](#grammar-

These features are available starting from version `1.4.3`. They are accessible via the [`huggingface_hub`](https://pypi.org/project/huggingface-hub/) library. The tool support is compatible with OpenAI's client libraries. The following guide will walk you through the new features and how to use them!

_note: guidance is supported as grammar in the `/generate` endpoint and as tools in the `/chat/completions` endpoint._
_note: guidance is supported as grammar in the `/generate` endpoint and as tools in the `v1/chat/completions` endpoint._

## How it works

@ -157,7 +157,12 @@ from huggingface_hub import InferenceClient

client = InferenceClient("http://localhost:3000")

regexp = "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)"
section_regex = "(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)"
regexp = f"HELLO\.{section_regex}\.WORLD\.{section_regex}"

# This is a more realistic example of an ip address regex
# regexp = f"{section_regex}\.{section_regex}\.{section_regex}\.{section_regex}"

resp = client.text_generation(
    f"Whats Googles DNS? Please use the following regex: {regexp}",
@ -170,7 +175,7 @@ resp = client.text_generation(

print(resp)
# 7.1.1.1
# HELLO.255.WORLD.255

```

@ -306,11 +311,13 @@ print(chat.choices[0].message.tool_calls)

```

### OpenAI integration
### OpenAI Integration

TGI exposes an OpenAI-compatible API, which means you can use OpenAI's client libraries to interact with TGI's Messages API and Tool functions.
Text Generation Inference (TGI) offers seamless integration with OpenAI's client libraries, allowing developers to interact with TGI's Messages API and Tool functions in a familiar way. This compatibility simplifies the implementation of advanced features, such as tools and grammar, within your applications using OpenAI's client.

However there are some minor differences in the API, for example `tool_choice="auto"` will ALWAYS choose the tool for you. This is different from OpenAI's API where `tool_choice="auto"` will choose a tool if the model thinks it's necessary.
Previously, TGI handled tool selection differently than OpenAI's API: `tool_choice="auto"` would always pick a tool for you. However, as of the latest version, TGI now mimics OpenAI's behavior more closely: `tool_choice="auto"` selects a tool only when the model deems it necessary, aligning with how OpenAI's API works. This enhancement ensures a smoother and more predictable integration experience.

Additionally, error notifications like `notify_error`, which previously indicated that no tool was chosen, are no longer returned. Instead, TGI will proceed with generating a response as if no tool was selected, further improving consistency with OpenAI's API.

```python
from openai import OpenAI

@ -84,7 +84,7 @@ print(chat)

```

or with OpenAi's library:
or with OpenAI's [client library](https://github.com/openai/openai-python):

```python
from openai import OpenAI

4
docs/source/conceptual/external.md
Normal file
@ -0,0 +1,4 @@
# External Resources

- Adyen wrote a detailed article about the interplay between TGI's main components: router and server.
  [LLM inference at scale with TGI (Martin Iglesias Goyanes - Adyen, 2024)](https://www.adyen.com/knowledge-hub/llm-inference-at-scale-with-tgi)
@ -2,11 +2,11 @@

## What is Guidance?

Guidance is a feature that allows users to constrain the generation of a large language model with a specified grammar. This feature is particularly useful when you want to generate text that follows a specific structure or uses a specific set of words or produce output in a specific format.
Guidance is a feature that allows users to constrain the generation of a large language model with a specified grammar. This feature is particularly useful when you want to generate text that follows a specific structure or uses a specific set of words or produce output in a specific format. A prominent example is JSON grammar, where the model is forced to output valid JSON.
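
To make this concrete, here is a minimal sketch of JSON-constrained generation through `huggingface_hub`, assuming a TGI server is running locally on port 3000; the prompt, schema, and example output are purely illustrative:

```python
from huggingface_hub import InferenceClient

client = InferenceClient("http://localhost:3000")

# An illustrative JSON Schema the output must conform to.
schema = {
    "properties": {
        "name": {"type": "string"},
        "activity": {"type": "string"},
    },
    "required": ["name", "activity"],
}

result = client.text_generation(
    "Tell me about a person named Alice who likes hiking. Please answer in JSON.",
    max_new_tokens=100,
    grammar={"type": "json", "value": schema},
)
print(result)  # e.g. {"name": "Alice", "activity": "hiking"}
```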

## How is it used?

Guidance can be in many ways and the community is always finding new ways to use it. Here are some examples of how you can use guidance:
Guidance can be implemented in many ways and the community is always finding new ways to use it. Here are some examples of how you can use guidance:

Technically, guidance can be used to generate:

99
docs/source/conceptual/lora.md
Normal file
@ -0,0 +1,99 @@
# LoRA (Low-Rank Adaptation)

## What is LoRA?

LoRA is a technique that allows for efficient fine-tuning of a model while only updating a small portion of the model's weights. This is useful when you have a large model that has been pre-trained on a large dataset, but you want to fine-tune it on a smaller dataset or for a specific task.

LoRA works by adding a small number of additional weights to the model, which are used to adapt the model to the new dataset or task. These additional weights are learned during the fine-tuning process, while the rest of the model's weights are kept fixed.

## How is it used?

LoRA can be used in many ways and the community is always finding new ways to use it. Here are some examples of how you can use LoRA:

Technically, LoRA can be used to fine-tune a large language model on a small dataset. However, these use cases can span a wide range of applications, such as:

- fine-tuning a language model on a small dataset
- fine-tuning a language model on a domain-specific dataset
- fine-tuning a language model on a dataset with limited labels

## Optimizing Inference with LoRA

LoRA adapters can be used during inference by multiplying the adapter weights with the model weights at each specified layer. This process can be computationally expensive, but thanks to the work by [punica-ai](https://github.com/punica-ai/punica) and the [lorax](https://github.com/predibase/lorax) team, optimized kernels and frameworks have been developed to make this process more efficient. TGI leverages these optimizations to provide fast and efficient inference with multiple LoRA models.

## Serving multiple LoRA adapters with TGI

Once a LoRA model has been trained, it can be used to generate text or perform other tasks just like a regular language model. However, because the model has been fine-tuned on a specific dataset, it may perform better on that dataset than a model that has not been fine-tuned.

In practice, it's often useful to have multiple LoRA models, each fine-tuned on a different dataset or for a different task. This allows you to use the model that is best suited for a particular task or dataset.

Text Generation Inference (TGI) now supports loading multiple LoRA models at startup that can be used in generation requests. This feature is available starting from version `~2.0.6` and is compatible with LoRA models trained using the `peft` library.

### Specifying LoRA models

To use LoRA in TGI, when starting the server, you can specify the list of LoRA models to load using the `LORA_ADAPTERS` environment variable. For example:

```bash
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia
```

To specify a model revision, use `adapter_id@revision`, as follows:

```bash
LORA_ADAPTERS=predibase/customer_support@main,predibase/dbpedia@rev2
```

To use a locally stored LoRA adapter, use `adapter-name=/path/to/adapter`, as seen below. When you want to use this adapter, set `"parameters": {"adapter_id": "adapter-name"}`.

```bash
LORA_ADAPTERS=myadapter=/some/path/to/adapter,myadapter2=/another/path/to/adapter
```

Note that it's possible to mix adapter IDs with `adapter_id=adapter_path` entries, e.g.:

```bash
LORA_ADAPTERS=predibase/dbpedia,myadapter=/path/to/dir/
```

In the server logs, you will see the following message:

```txt
Loading adapter weights into model: predibase/customer_support
Loading adapter weights into model: predibase/dbpedia
```

## Generate text

You can then use these models in generation requests by specifying the `adapter_id` parameter in the request payload. For example:

```bash
curl 127.0.0.1:3000/generate \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{
        "inputs": "Hello who are you?",
        "parameters": {
            "max_new_tokens": 40,
            "adapter_id": "predibase/customer_support"
        }
    }'
```

If you are using a LoRA adapter stored locally that was set in the following manner: `LORA_ADAPTERS=myadapter=/some/path/to/adapter`, here is an example payload:

```bash
curl 127.0.0.1:3000/generate \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{
        "inputs": "Hello who are you?",
        "parameters": {
            "max_new_tokens": 40,
            "adapter_id": "myadapter"
        }
    }'
```
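
The same request can also be issued from Python. This is a minimal sketch using the `requests` library; the server address and adapter name simply mirror the curl example above and are assumptions about your setup:

```python
import requests

# Payload mirrors the curl example above: adapter_id selects the LoRA adapter.
payload = {
    "inputs": "Hello who are you?",
    "parameters": {
        "max_new_tokens": 40,
        "adapter_id": "predibase/customer_support",
    },
}

response = requests.post("http://127.0.0.1:3000/generate", json=payload)
response.raise_for_status()
print(response.json()["generated_text"])
```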

> **Note:** The LoRA feature is new and still being improved. If you encounter any issues or have any feedback, please let us know by opening an issue on the [GitHub repository](https://github.com/huggingface/text-generation-inference/issues/new/choose). Additionally, documentation and an improved client library will be published soon.

An updated tutorial with detailed examples will be published soon. Stay tuned!
@ -1,6 +1,40 @@
# Quantization

TGI offers GPTQ and bits-and-bytes quantization to quantize large language models.
TGI offers many quantization schemes to run LLMs effectively and fast based on your use case. TGI supports GPTQ, AWQ, bits-and-bytes, EETQ, Marlin, EXL2 and fp8 quantization.

To leverage GPTQ, AWQ, Marlin and EXL2 quants, you must provide pre-quantized weights, whereas for bits-and-bytes, EETQ and fp8, weights are quantized by TGI on the fly.

We recommend using the official quantization scripts for creating your quants:
1. [AWQ](https://github.com/casper-hansen/AutoAWQ/blob/main/examples/quantize.py)
2. [GPTQ/Marlin](https://github.com/AutoGPTQ/AutoGPTQ/blob/main/examples/quantization/basic_usage.py)
3. [EXL2](https://github.com/turboderp/exllamav2/blob/master/doc/convert.md)

For on-the-fly quantization you simply need to pass one of the supported quantization types and TGI takes care of the rest.

## Quantization with bitsandbytes, EETQ & fp8

bitsandbytes is a library used to apply 8-bit and 4-bit quantization to models. Unlike GPTQ quantization, bitsandbytes doesn't require a calibration dataset or any post-processing – weights are automatically quantized on load. However, inference with bitsandbytes is slower than GPTQ or FP16 precision.

8-bit quantization enables multi-billion parameter scale models to fit in smaller hardware without degrading performance too much.
In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below 👇

```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize bitsandbytes
```

4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load.

In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below 👇

```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize bitsandbytes-nf4
```

You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes).

Similarly, you can pass `--quantize eetq` or `--quantize fp8` for the respective quantization schemes.

In addition to this, TGI allows creating GPTQ quants directly by passing the model weights and a calibration dataset.

## Quantization with GPTQ

@ -36,24 +70,3 @@ You can learn more about the quantization options by running `text-generation-se

If you wish to do more with GPTQ models (e.g. train an adapter on top), you can read about transformers GPTQ integration [here](https://huggingface.co/blog/gptq-integration).
You can learn more about GPTQ from the [paper](https://arxiv.org/pdf/2210.17323.pdf).

## Quantization with bitsandbytes

bitsandbytes is a library used to apply 8-bit and 4-bit quantization to models. Unlike GPTQ quantization, bitsandbytes doesn't require a calibration dataset or any post-processing – weights are automatically quantized on load. However, inference with bitsandbytes is slower than GPTQ or FP16 precision.

8-bit quantization enables multi-billion parameter scale models to fit in smaller hardware without degrading performance too much.
In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below 👇

```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize bitsandbytes
```

4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load.

In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below 👇

```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize bitsandbytes-nf4
```

You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes).

@ -1,5 +1,6 @@
# Streaming

## What is Streaming?

Token streaming is the mode in which the server returns the tokens one by one as the model generates them. This enables showing progressive generations to the user rather than waiting for the whole generation. Streaming is an essential aspect of the end-user experience as it reduces latency, one of the most critical aspects of a smooth experience.
@ -48,34 +49,29 @@ To stream tokens with `InferenceClient`, simply pass `stream=True` and iterate o
```python
from huggingface_hub import InferenceClient

client = InferenceClient("http://127.0.0.1:8080")
for token in client.text_generation("How do you make cheese?", max_new_tokens=12, stream=True):
    print(token)
client = InferenceClient(base_url="http://127.0.0.1:8080")
output = client.chat.completions.create(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Count to 10"},
    ],
    stream=True,
    max_tokens=1024,
)

# To
# make
# cheese
#,
# you
# need
# to
# start
# with
# milk
#.
```
for chunk in output:
    print(chunk.choices[0].delta.content)

If you want additional details, you can add `details=True`. In this case, you get a `TextGenerationStreamResponse` which contains additional information such as the probabilities and the tokens. For the final response in the stream, it also returns the full generated text.

```python
for details in client.text_generation("How do you make cheese?", max_new_tokens=12, details=True, stream=True):
    print(details)

#TextGenerationStreamResponse(token=Token(id=193, text='\n', logprob=-0.007358551, special=False), generated_text=None, details=None)
#TextGenerationStreamResponse(token=Token(id=2044, text='To', logprob=-1.1357422, special=False), generated_text=None, details=None)
#TextGenerationStreamResponse(token=Token(id=717, text=' make', logprob=-0.009841919, special=False), generated_text=None, details=None)
#...
#TextGenerationStreamResponse(token=Token(id=25, text='.', logprob=-1.3408203, special=False), generated_text='\nTo make cheese, you need to start with milk.', details=StreamDetails(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=12, seed=None))
# 1
# 2
# 3
# 4
# 5
# 6
# 7
# 8
# 9
# 10
```

The `huggingface_hub` library also comes with an `AsyncInferenceClient` in case you need to handle the requests concurrently.
@ -83,31 +79,46 @@ The `huggingface_hub` library also comes with an `AsyncInferenceClient` in case
```python
from huggingface_hub import AsyncInferenceClient

client = AsyncInferenceClient("http://127.0.0.1:8080")
async for token in await client.text_generation("How do you make cheese?", stream=True):
    print(token)
client = AsyncInferenceClient(base_url="http://127.0.0.1:8080")
async def main():
    stream = await client.chat.completions.create(
        messages=[{"role": "user", "content": "Say this is a test"}],
        stream=True,
    )
    async for chunk in stream:
        print(chunk.choices[0].delta.content or "", end="")

# To
# make
# cheese
#,
# you
# need
# to
# start
# with
# milk
asyncio.run(main())

# This
# is
# a
# test
#.
```

### Streaming with cURL

To use the `generate_stream` endpoint with curl, you can add the `-N` flag, which disables curl default buffering and shows data as it arrives from the server
To use the OpenAI Chat Completions compatible Messages API `v1/chat/completions` endpoint with curl, you can add the `-N` flag, which disables curl default buffering and shows data as it arrives from the server.

```curl
curl -N 127.0.0.1:8080/generate_stream \
curl localhost:8080/v1/chat/completions \
    -X POST \
    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
    -d '{
        "model": "tgi",
        "messages": [
            {
                "role": "system",
                "content": "You are a helpful assistant."
            },
            {
                "role": "user",
                "content": "What is deep learning?"
            }
        ],
        "stream": true,
        "max_tokens": 20
    }' \
    -H 'Content-Type: application/json'
```

@ -11,7 +11,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
    --device=/dev/kfd --device=/dev/dri --group-add video \
    --ipc=host --shm-size 256g --net host -v $volume:/data \
    ghcr.io/huggingface/text-generation-inference:2.0.3-rocm \
    ghcr.io/huggingface/text-generation-inference:2.3.1-rocm \
    --model-id $model
```

@ -27,10 +27,16 @@ TunableOp is enabled by default, the warmup may take 1-2 minutes. In case you wo

## Flash attention implementation

Two implementations of Flash Attention are available for ROCm, the first is [ROCm/flash-attention](https://github.com/ROCm/flash-attention) based on a [Composable Kernel](https://github.com/ROCm/composable_kernel) (CK) implementation, and the second is a [Triton implementation](https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/utils/flash_attn_triton.py).
Two implementations of Flash Attention are available for ROCm, the first is [ROCm/flash-attention](https://github.com/ROCm/flash-attention) based on a [Composable Kernel](https://github.com/ROCm/composable_kernel) (CK) implementation, and the second is a [Triton implementation](https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/layers/attention/flash_attn_triton.py).

By default, the Composable Kernel implementation is used. However, the Triton implementation has slightly lower latency on MI250 and MI300, but requires a warmup which can be prohibitive as it needs to be done again for each new prompt length. If needed, the FA Triton implementation can be enabled with `--env ROCM_USE_FLASH_ATTN_V2_TRITON="0"` when launching TGI's docker container.

## Custom PagedAttention

For better performance on ROCm, a custom Paged Attention kernel is available and is enabled by default. To disable it and fall back to the PagedAttention v2 kernel, set the environment variable `ROCM_USE_CUSTOM_PAGED_ATTN=0`.

The custom kernel supports bf16 and fp16 data types, block size of 16, head size of 128, a maximum context length of 16k, and GQA ratios between 1 and 16. For other configurations, we use the PagedAttention v2 kernel.

## Unsupported features

The following features are currently not supported in the ROCm version of TGI, and the support may be extended in the future:

36
docs/source/installation_intel.md
Normal file
@ -0,0 +1,36 @@
# Using TGI with Intel GPUs

TGI optimized models are supported on Intel Data Center GPU [Max1100](https://www.intel.com/content/www/us/en/products/sku/232876/intel-data-center-gpu-max-1100/specifications.html) and [Max1550](https://www.intel.com/content/www/us/en/products/sku/232873/intel-data-center-gpu-max-1550/specifications.html); the recommended usage is through Docker.

On a server powered by Intel GPUs, TGI can be launched with the following command:

```bash
model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run

docker run --rm --privileged --cap-add=sys_nice \
    --device=/dev/dri \
    --ipc=host --shm-size 1g --net host -v $volume:/data \
    ghcr.io/huggingface/text-generation-inference:2.3.1-intel-xpu \
    --model-id $model --cuda-graphs 0
```

# Using TGI with Intel CPUs

Intel® Extension for PyTorch (IPEX) also provides further optimizations for Intel CPUs. IPEX provides optimized operations such as flash attention, paged attention, Add + LayerNorm, RoPE and more.

On a server powered by Intel CPU, TGI can be launched with the following command:

```bash
model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run

docker run --rm --privileged --cap-add=sys_nice \
    --device=/dev/dri \
    --ipc=host --shm-size 1g --net host -v $volume:/data \
    ghcr.io/huggingface/text-generation-inference:2.3.1-intel-cpu \
    --model-id $model --cuda-graphs 0
```

The launched TGI server can then be queried from clients; make sure to check out the [Consuming TGI](./basic_tutorials/consuming_tgi) guide.

Some files were not shown because too many files have changed in this diff.