Merge branch 'main' into fp8_kvcache

This commit is contained in:
Mohit Sharma 2024-06-24 07:53:33 +00:00
commit 084de9907c
179 changed files with 15553 additions and 4079 deletions

View File

@ -18,51 +18,29 @@ on:
- "Cargo.lock" - "Cargo.lock"
- "rust-toolchain.toml" - "rust-toolchain.toml"
- "Dockerfile" - "Dockerfile"
- "Dockerfile_amd"
- "Dockerfile_intel"
branches: branches:
- 'main' - 'main'
jobs: jobs:
start-runner:
name: Start self-hosted EC2 runner
runs-on: ubuntu-latest
env:
AWS_REGION: us-east-1
EC2_AMI_ID: ami-0789b6925c11b1fb2
EC2_INSTANCE_TYPE: g5.12xlarge
EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc
EC2_SECURITY_GROUP: sg-030175c435ac141d6
outputs:
label: ${{ steps.start-ec2-runner.outputs.label }}
ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}
steps:
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v1
with:
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
aws-region: ${{ env.AWS_REGION }}
- name: Start EC2 runner
id: start-ec2-runner
uses: philschmid/philschmid-ec2-github-runner@main
with:
mode: start
github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
ec2-image-id: ${{ env.EC2_AMI_ID }}
ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }}
subnet-id: ${{ env.EC2_SUBNET_ID }}
security-group-id: ${{ env.EC2_SECURITY_GROUP }}
aws-resource-tags: > # optional, requires additional permissions
[
{"Key": "Name", "Value": "ec2-tgi-github-runner"},
{"Key": "GitHubRepository", "Value": "${{ github.repository }}"}
]
build-and-push-image: build-and-push-image:
concurrency: concurrency:
group: ${{ github.workflow }}-build-and-push-image-${{ github.head_ref || github.run_id }} group: ${{ github.workflow }}-build-and-push-image-${{ matrix.name }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true cancel-in-progress: true
needs: start-runner # required to start the main job when the runner is ready runs-on: [self-hosted, nvidia-gpu , multi-gpu, 4-a10, ci]
runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner strategy:
matrix:
include:
- name: "cuda"
label: ""
dockerfile: "Dockerfile"
- name: "amd"
label: "-rocm"
dockerfile: "Dockerfile_amd"
- name: "intel"
label: "-intel"
dockerfile: "Dockerfile_intel"
permissions: permissions:
contents: write contents: write
packages: write packages: write
@ -73,16 +51,19 @@ jobs:
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v3 uses: actions/checkout@v3
- name: Inject slug/short variables
uses: rlespinasse/github-slug-action@v4.4.1
- name: Tailscale
uses: huggingface/tailscale-action@main
with:
authkey: ${{ secrets.TAILSCALE_AUTHKEY }}
slackChannel: ${{ secrets.SLACK_CIFEEDBACK_CHANNEL }}
slackToken: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
- name: Initialize Docker Buildx - name: Initialize Docker Buildx
uses: docker/setup-buildx-action@v2.0.0 uses: docker/setup-buildx-action@v2.0.0
with: with:
install: true install: true
- name: Inject slug/short variables
uses: rlespinasse/github-slug-action@v4.4.1
- name: Tailscale
uses: tailscale/github-action@7bd8039bf25c23c4ab1b8d6e2cc2da2280601966
with:
authkey: ${{ secrets.TAILSCALE_AUTHKEY }}
- name: Login to GitHub Container Registry - name: Login to GitHub Container Registry
if: github.event_name != 'pull_request' if: github.event_name != 'pull_request'
uses: docker/login-action@v2 uses: docker/login-action@v2
@ -112,7 +93,7 @@ jobs:
images: | images: |
registry.internal.huggingface.tech/api-inference/community/text-generation-inference registry.internal.huggingface.tech/api-inference/community/text-generation-inference
tags: | tags: |
type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }} type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ matrix.label }}
# If main, release or tag # If main, release or tag
- name: Extract metadata (tags, labels) for Docker - name: Extract metadata (tags, labels) for Docker
if: ${{ github.event_name != 'pull_request' }} if: ${{ github.event_name != 'pull_request' }}
@ -126,308 +107,44 @@ jobs:
ghcr.io/huggingface/text-generation-inference ghcr.io/huggingface/text-generation-inference
db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
tags: | tags: |
type=semver,pattern={{version}} type=semver,pattern={{version}}${{ matrix.label }}
type=semver,pattern={{major}}.{{minor}} type=semver,pattern={{major}}.{{minor}}${{ matrix.label }}
type=raw,value=latest,enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }} type=raw,value=latest${{ matrix.label }},enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}
type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }} type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ matrix.label }}
- name: Build and push Docker image - name: Build and push Docker image
id: build-and-push id: build-and-push
uses: docker/build-push-action@v4 uses: docker/build-push-action@v4
with: with:
context: . context: .
file: Dockerfile file: ${{ matrix.dockerfile }}
push: true push: true
platforms: 'linux/amd64' platforms: 'linux/amd64'
build-args: | build-args: |
GIT_SHA=${{ env.GITHUB_SHA }} GIT_SHA=${{ env.GITHUB_SHA }}
DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }} DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ matrix.label }}
tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }} tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }} labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=min network: host
cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=min cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache${{ matrix.label }},mode=min
cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache${{ matrix.label }},mode=min
integration-tests:
concurrency:
group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
needs:
- start-runner
- build-and-push-image # Wait for the docker image to be built
runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
env:
DOCKER_VOLUME: /cache
steps:
- uses: actions/checkout@v2
- name: Inject slug/short variables
uses: rlespinasse/github-slug-action@v4.4.1
- name: Set up Python - name: Set up Python
if: matrix.name == 'cuda'
uses: actions/setup-python@v4 uses: actions/setup-python@v4
with: with:
python-version: 3.9 python-version: 3.9
- name: Tailscale
uses: tailscale/github-action@7bd8039bf25c23c4ab1b8d6e2cc2da2280601966
with:
authkey: ${{ secrets.TAILSCALE_AUTHKEY }}
- name: Prepare disks
run: |
sudo mkfs -t ext4 /dev/nvme1n1
sudo mkdir ${{ env.DOCKER_VOLUME }}
sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }}
- name: Install - name: Install
if: matrix.name == 'cuda'
run: | run: |
make install-integration-tests make install-integration-tests
- name: Run tests - name: Run tests
if: matrix.name == 'cuda'
run: | run: |
export DOCKER_VOLUME=/mnt/cache
export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }} export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}
export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }} export HUGGING_FACE_HUB_TOKEN=${{ secrets.HF_TOKEN }}
pytest -s -vv integration-tests pytest -s -vv integration-tests
- name: Tailscale Wait
build-and-push-image-rocm: if: ${{ failure() || runner.debug == '1' }}
concurrency: uses: huggingface/tailscale-action@main
group: ${{ github.workflow }}-build-and-push-image-rocm-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
needs:
- start-runner
- build-and-push-image # Wait for the main docker image to be built
- integration-tests # Wait for the main integration-tests
runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
permissions:
contents: write
packages: write
# This is used to complete the identity challenge
# with sigstore/fulcio when running outside of PRs.
id-token: write
security-events: write
steps:
- name: Checkout repository
uses: actions/checkout@v3
- name: Initialize Docker Buildx
uses: docker/setup-buildx-action@v2.0.0
with: with:
install: true waitForSSH: true
- name: Inject slug/short variables
uses: rlespinasse/github-slug-action@v4.4.1
- name: Tailscale
uses: tailscale/github-action@7bd8039bf25c23c4ab1b8d6e2cc2da2280601966
with:
authkey: ${{ secrets.TAILSCALE_AUTHKEY }}
- name: Login to GitHub Container Registry
if: github.event_name != 'pull_request'
uses: docker/login-action@v2
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Login to internal Container Registry
uses: docker/login-action@v2.1.0
with:
username: ${{ secrets.TAILSCALE_DOCKER_USERNAME }}
password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }}
registry: registry.internal.huggingface.tech
- name: Login to Azure Container Registry
if: github.event_name != 'pull_request'
uses: docker/login-action@v2.1.0
with:
username: ${{ secrets.AZURE_DOCKER_USERNAME }}
password: ${{ secrets.AZURE_DOCKER_PASSWORD }}
registry: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io
# If pull request
- name: Extract metadata (tags, labels) for Docker
if: ${{ github.event_name == 'pull_request' }}
id: meta-pr
uses: docker/metadata-action@v4.3.0
with:
images: |
registry.internal.huggingface.tech/api-inference/community/text-generation-inference
tags: |
type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}-rocm
# If main, release or tag
- name: Extract metadata (tags, labels) for Docker
if: ${{ github.event_name != 'pull_request' }}
id: meta
uses: docker/metadata-action@v4.3.0
with:
flavor: |
latest=false
images: |
registry.internal.huggingface.tech/api-inference/community/text-generation-inference
ghcr.io/huggingface/text-generation-inference
db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
tags: |
type=semver,pattern={{version}}-rocm
type=semver,pattern={{major}}.{{minor}}-rocm
type=raw,value=latest-rocm,enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}
type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}-rocm
- name: Build and push Docker image
id: build-and-push
uses: docker/build-push-action@v4
with:
context: .
file: Dockerfile_amd
push: true
platforms: 'linux/amd64'
build-args: |
GIT_SHA=${{ env.GITHUB_SHA }}
DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}-rocm
tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-rocm,mode=min
cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-rocm,mode=min
build-and-push-image-intel:
concurrency:
group: ${{ github.workflow }}-build-and-push-image-intel-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
needs:
- start-runner
- build-and-push-image # Wait for the main docker image to be built
- integration-tests # Wait for the main integration-tests
runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
permissions:
contents: write
packages: write
# This is used to complete the identity challenge
# with sigstore/fulcio when running outside of PRs.
id-token: write
security-events: write
outputs:
# env is not available in the later `container:`, but previous job outputs are.
short_sha: ${{ env.GITHUB_SHA_SHORT }}
steps:
- name: Checkout repository
uses: actions/checkout@v3
- name: Initialize Docker Buildx
uses: docker/setup-buildx-action@v2.0.0
with:
install: true
- name: Inject slug/short variables
uses: rlespinasse/github-slug-action@v4.4.1
- name: Tailscale
uses: tailscale/github-action@7bd8039bf25c23c4ab1b8d6e2cc2da2280601966
with:
authkey: ${{ secrets.TAILSCALE_AUTHKEY }}
- name: Login to GitHub Container Registry
if: github.event_name != 'pull_request'
uses: docker/login-action@v2
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Login to internal Container Registry
uses: docker/login-action@v2.1.0
with:
username: ${{ secrets.TAILSCALE_DOCKER_USERNAME }}
password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }}
registry: registry.internal.huggingface.tech
- name: Login to Azure Container Registry
if: github.event_name != 'pull_request'
uses: docker/login-action@v2.1.0
with:
username: ${{ secrets.AZURE_DOCKER_USERNAME }}
password: ${{ secrets.AZURE_DOCKER_PASSWORD }}
registry: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io
# If pull request
- name: Extract metadata (tags, labels) for Docker
if: ${{ github.event_name == 'pull_request' }}
id: meta-pr
uses: docker/metadata-action@v4.3.0
with:
images: |
registry.internal.huggingface.tech/api-inference/community/text-generation-inference
tags: |
type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}-intel
# If main, release or tag
- name: Extract metadata (tags, labels) for Docker
if: ${{ github.event_name != 'pull_request' }}
id: meta
uses: docker/metadata-action@v4.3.0
with:
flavor: |
latest=false
images: |
registry.internal.huggingface.tech/api-inference/community/text-generation-inference
ghcr.io/huggingface/text-generation-inference
db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
tags: |
type=semver,pattern={{version}}-intel
type=semver,pattern={{major}}.{{minor}}-intel
type=raw,value=latest-intel,enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}
type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}-intel
- name: Build and push Docker image
id: build-and-push
uses: docker/build-push-action@v4
with:
context: .
file: Dockerfile_intel
push: true
platforms: 'linux/amd64'
build-args: |
GIT_SHA=${{ env.GITHUB_SHA }}
DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}-intel
tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-intel,mode=min
cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-intel,mode=min
stop-runner:
name: Stop self-hosted EC2 runner
needs:
- start-runner
- build-and-push-image
- build-and-push-image-rocm
- build-and-push-image-intel
- integration-tests
runs-on: ubuntu-latest
env:
AWS_REGION: us-east-1
if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs
steps:
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v1
with:
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
aws-region: ${{ env.AWS_REGION }}
- name: Stop EC2 runner
uses: philschmid/philschmid-ec2-github-runner@main
with:
mode: stop
github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
label: ${{ needs.start-runner.outputs.label }}
ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}
# TODO: Move this to `build_amd.yml` (and `build_nvidia.yml`)
# integration-tests-rocm:
# concurrency:
# group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
# cancel-in-progress: true
# needs:
# - start-runner
# - build-and-push-image
# - integration-tests
# - build-and-push-image-rocm
# - stop-runner
# runs-on: [self-hosted, amd-gpu, multi-gpu, mi300]
# container:
# image: registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ needs.build-and-push-image-rocm.outputs.short_sha }}-rocm
# options: --device /dev/kfd --device /dev/dri --env ROCR_VISIBLE_DEVICES --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/cache
# env:
# DOCKER_VOLUME: /cache
# steps:
# - name: ROCM-SMI
# run: |
# rocm-smi
# - name: ROCM-INFO
# run: |
# rocminfo | grep "Agent" -A 14
# - name: Show ROCR environment
# run: |
# echo "ROCR: $ROCR_VISIBLE_DEVICES"
# - name: Install
# run: |
# make install-integration-tests
# - name: Run tests
# run: |
# export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
# pytest -s -vv integration-tests

View File

@ -33,9 +33,9 @@ jobs:
- name: Install Rust - name: Install Rust
uses: actions-rs/toolchain@v1 uses: actions-rs/toolchain@v1
with: with:
# Released on: 02 May, 2024 # Released on: June 13, 2024
# https://releases.rs/docs/1.78.0/ # https://releases.rs/docs/1.79.0/
toolchain: 1.78.0 toolchain: 1.79.0
override: true override: true
components: rustfmt, clippy components: rustfmt, clippy
- name: Install Protoc - name: Install Protoc
@ -68,11 +68,11 @@ jobs:
~/.cargo/git ~/.cargo/git
- name: Install - name: Install
run: | run: |
make install make install-cpu
- name: Run server tests - name: Run server tests
run: | run: |
pip install pytest pip install pytest
export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }} export HUGGING_FACE_HUB_TOKEN=${{ secrets.HF_TOKEN }}
pytest -s -vv server/tests pytest -s -vv server/tests
- name: Pre-commit checks - name: Pre-commit checks
run: | run: |

18
.github/workflows/trufflehog.yml vendored Normal file
View File

@ -0,0 +1,18 @@
on:
push:
name: Secret Leaks
permissions:
contents: read
jobs:
trufflehog:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Secret Scanning
uses: trufflesecurity/trufflehog@main

133
CODE_OF_CONDUCT.md Normal file
View File

@ -0,0 +1,133 @@
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, caste, color, religion, or sexual
identity and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the overall
community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or advances of
any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email address,
without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
feedback@huggingface.co.
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series of
actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or permanent
ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within the
community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.1, available at
[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
Community Impact Guidelines were inspired by
[Mozilla's code of conduct enforcement ladder][Mozilla CoC].
For answers to common questions about this code of conduct, see the FAQ at
[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
[https://www.contributor-covenant.org/translations][translations].
[homepage]: https://www.contributor-covenant.org
[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
[Mozilla CoC]: https://github.com/mozilla/diversity
[FAQ]: https://www.contributor-covenant.org/faq
[translations]: https://www.contributor-covenant.org/translations

120
CONTRIBUTING.md Normal file
View File

@ -0,0 +1,120 @@
<!---
Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Contribute to text-generation-inference
Everyone is welcome to contribute, and we value everybody's contribution. Code
contributions are not the only way to help the community. Answering questions, helping
others, and improving the documentation are also immensely valuable.
It also helps us if you spread the word! Reference the library in blog posts
about the awesome projects it made possible, shout out on Twitter every time it has
helped you, or simply ⭐️ the repository to say thank you.
However you choose to contribute, please be mindful and respect our
[code of conduct](https://github.com/huggingface/text-generation-inference/blob/main/CODE_OF_CONDUCT.md).
**This guide was heavily inspired by the awesome [scikit-learn guide to contributing](https://github.com/scikit-learn/scikit-learn/blob/main/CONTRIBUTING.md).**
## Ways to contribute
There are several ways you can contribute to text-generation-inference.
* Fix outstanding issues with the existing code.
* Submit issues related to bugs or desired new features.
* Contribute to the examples or to the documentation.
> All contributions are equally valuable to the community. 🥰
## Fixing outstanding issues
If you notice an issue with the existing code and have a fix in mind, feel free to [start contributing](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request) and open
a Pull Request!
## Submitting a bug-related issue or feature request
Do your best to follow these guidelines when submitting a bug-related issue or a feature
request. It will make it easier for us to come back to you quickly and with good
feedback.
### Did you find a bug?
The text-generation-inference library is robust and reliable thanks to users who report the problems they encounter.
Before you report an issue, we would really appreciate it if you could **make sure the bug was not
already reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the
library itself, and not your code.
Once you've confirmed the bug hasn't already been reported, please include the following information in your issue so
we can quickly resolve it:
* Your **OS type and version**, as well as your environment versions (versions of rust, python, and dependencies).
* A short, self-contained, code snippet that allows us to reproduce the bug.
* The *full* traceback if an exception is raised.
* Attach any other additional information, like screenshots, you think may help.
To get the OS and software versions automatically, you can re-run the launcher with the `--env` flag:
```bash
text-generation-launcher --env
```
This will precede the launch of the model with the information relative to your environment. We recommend pasting
that in your issue report.
### Do you want a new feature?
If there is a new feature you'd like to see in text-generation-inference, please open an issue and describe:
1. What is the *motivation* behind this feature? Is it related to a problem or frustration with the library? Is it
a feature related to something you need for a project? Is it something you worked on and think it could benefit
the community?
Whatever it is, we'd love to hear about it!
2. Describe your requested feature in as much detail as possible. The more you can tell us about it, the better
we'll be able to help you.
3. Provide a *code snippet* that demonstrates the feature's usage.
4. If the feature is related to a paper, please include a link.
If your issue is well written we're already 80% of the way there by the time you create it.
We have added [templates](https://github.com/huggingface/text-generation-inference/tree/main/.github/ISSUE_TEMPLATE)
to help you get started with your issue.
## Do you want to implement a new model?
New models are constantly released and if you want to implement a new model, please provide the following information:
* A short description of the model and a link to the paper.
* Link to the implementation if it is open-sourced.
* Link to the model weights if they are available.
If you are willing to contribute the model yourself, let us know so we can help you add it to text-generation-inference!
## Do you want to add documentation?
We're always looking for improvements to the documentation that make it more clear and accurate. Please let us know
how the documentation can be improved such as typos and any content that is missing, unclear or inaccurate. We'll be
happy to make the changes or help you make a contribution if you're interested!
## I want to become a maintainer of the project. How do I get there?
TGI is a project led and managed by Hugging Face as it powers our internal services. However, we are happy to have
motivated individuals from other organizations join us as maintainers with the goal of making TGI the best inference
service.
If you are such an individual (or organization), please reach out to us and let's collaborate.

781
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -9,19 +9,29 @@ members = [
resolver = "2" resolver = "2"
[workspace.package] [workspace.package]
version = "2.0.2" version = "2.0.5-dev0"
edition = "2021" edition = "2021"
authors = ["Olivier Dehaene"] authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference" homepage = "https://github.com/huggingface/text-generation-inference"
[workspace.dependencies] [workspace.dependencies]
base64 = "0.22.0"
tokenizers = { version = "0.19.1", features = ["http"] } tokenizers = { version = "0.19.1", features = ["http"] }
hf-hub = { version = "0.3.1", features = ["tokio"] } hf-hub = { version = "0.3.1", features = ["tokio"] }
[profile.release] [profile.release]
incremental = true
[profile.release-binary]
inherits = "release"
debug = 1 debug = 1
incremental = true incremental = true
panic = "abort"
[profile.release-opt]
inherits = "release"
debug = 0
incremental = false
lto = "fat" lto = "fat"
opt-level = 3 opt-level = 3
codegen-units = 1 codegen-units = 1
panic = "abort"

View File

@ -1,5 +1,5 @@
# Rust builder # Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.78 AS chef FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
WORKDIR /usr/src WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@ -15,9 +15,6 @@ RUN cargo chef prepare --recipe-path recipe.json
FROM chef AS builder FROM chef AS builder
ARG GIT_SHA
ARG DOCKER_LABEL
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
@ -25,7 +22,10 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
rm -f $PROTOC_ZIP rm -f $PROTOC_ZIP
COPY --from=planner /usr/src/recipe.json recipe.json COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --release --recipe-path recipe.json RUN cargo chef cook --profile release-opt --recipe-path recipe.json
ARG GIT_SHA
ARG DOCKER_LABEL
COPY Cargo.toml Cargo.toml COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml COPY rust-toolchain.toml rust-toolchain.toml
@ -33,7 +33,7 @@ COPY proto proto
COPY benchmark benchmark COPY benchmark benchmark
COPY router router COPY router router
COPY launcher launcher COPY launcher launcher
RUN cargo build --release RUN cargo build --profile release-opt
# Python builder # Python builder
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile # Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
@ -137,6 +137,13 @@ COPY server/Makefile-eetq Makefile
# Build specific version of transformers # Build specific version of transformers
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-eetq RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-eetq
# Build marlin kernels
FROM kernel-builder as marlin-kernels-builder
WORKDIR /usr/src
COPY server/marlin/ .
# Build specific version of transformers
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
# Build Transformers CUDA kernels # Build Transformers CUDA kernels
FROM kernel-builder as custom-kernels-builder FROM kernel-builder as custom-kernels-builder
WORKDIR /usr/src WORKDIR /usr/src
@ -193,7 +200,7 @@ COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from flash attention v2 builder # Copy build artifacts from flash attention v2 builder
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages COPY --from=flash-att-v2-builder /opt/conda/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from custom kernels builder # Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
@ -205,6 +212,8 @@ COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-31
COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from eetq kernels builder # Copy build artifacts from eetq kernels builder
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from marlin kernels builder
COPY --from=marlin-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy builds artifacts from vllm builder # Copy builds artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
@ -225,18 +234,22 @@ RUN cd server && \
pip install -r requirements_cuda.txt && \ pip install -r requirements_cuda.txt && \
pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir
# Install benchmarker # Deps before the binaries
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark # The binaries change on every build given we burn the SHA into them
# Install router # The deps change less often.
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \ build-essential \
g++ \ g++ \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
# AWS Sagemaker compatible image # AWS Sagemaker compatible image
FROM base as sagemaker FROM base as sagemaker

View File

@ -1,5 +1,5 @@
# Rust builder # Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.78 AS chef FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
WORKDIR /usr/src WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@ -15,9 +15,6 @@ RUN cargo chef prepare --recipe-path recipe.json
FROM chef AS builder FROM chef AS builder
ARG GIT_SHA
ARG DOCKER_LABEL
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
@ -25,7 +22,10 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
rm -f $PROTOC_ZIP rm -f $PROTOC_ZIP
COPY --from=planner /usr/src/recipe.json recipe.json COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --release --recipe-path recipe.json RUN cargo chef cook --profile release-opt --recipe-path recipe.json
ARG GIT_SHA
ARG DOCKER_LABEL
COPY Cargo.toml Cargo.toml COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml COPY rust-toolchain.toml rust-toolchain.toml
@ -33,7 +33,7 @@ COPY proto proto
COPY benchmark benchmark COPY benchmark benchmark
COPY router router COPY router router
COPY launcher launcher COPY launcher launcher
RUN cargo build --release RUN cargo build --profile release-opt
# Text Generation Inference base image for RoCm # Text Generation Inference base image for RoCm
FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update as base FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update as base
@ -193,11 +193,11 @@ RUN cd server && \
pip install ".[accelerate, peft, outlines]" --no-cache-dir pip install ".[accelerate, peft, outlines]" --no-cache-dir
# Install benchmarker # Install benchmarker
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router # Install router
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
# Install launcher # Install launcher
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
# AWS Sagemaker compatible image # AWS Sagemaker compatible image
FROM base as sagemaker FROM base as sagemaker

View File

@ -1,4 +1,4 @@
FROM lukemathwalker/cargo-chef:latest-rust-1.78 AS chef FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
WORKDIR /usr/src WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@ -14,9 +14,6 @@ RUN cargo chef prepare --recipe-path recipe.json
FROM chef AS builder FROM chef AS builder
ARG GIT_SHA
ARG DOCKER_LABEL
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
@ -24,7 +21,10 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
rm -f $PROTOC_ZIP rm -f $PROTOC_ZIP
COPY --from=planner /usr/src/recipe.json recipe.json COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --release --recipe-path recipe.json RUN cargo chef cook --profile release-opt --recipe-path recipe.json
ARG GIT_SHA
ARG DOCKER_LABEL
COPY Cargo.toml Cargo.toml COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml COPY rust-toolchain.toml rust-toolchain.toml
@ -32,7 +32,7 @@ COPY proto proto
COPY benchmark benchmark COPY benchmark benchmark
COPY router router COPY router router
COPY launcher launcher COPY launcher launcher
RUN cargo build --release RUN cargo build --profile release-opt
# Text Generation Inference base image for Intel # Text Generation Inference base image for Intel
@ -43,11 +43,12 @@ USER root
RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \ RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb
RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null
RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
| gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list
RUN apt-get update && apt install -y intel-basekit xpu-smi RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build
# Text Generation Inference base env # Text Generation Inference base env
ENV HUGGINGFACE_HUB_CACHE=/data \ ENV HUGGINGFACE_HUB_CACHE=/data \
@ -56,8 +57,8 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
WORKDIR /usr/src WORKDIR /usr/src
RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.1.30a0-cp310-cp310-linux_x86_64.whl RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl
RUN pip install intel_extension_for_pytorch-2.1.30a0-cp310-cp310-linux_x86_64.whl RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b group_rope origin/dev/gqa_rope
# Install server # Install server
COPY proto proto COPY proto proto
@ -65,7 +66,7 @@ COPY server server
COPY server/Makefile server/Makefile COPY server/Makefile server/Makefile
RUN cd server && \ RUN cd server && \
make gen-server && \ make gen-server && \
pip install -r requirements_cuda.txt && \ pip install -r requirements_intel.txt && \
pip install ".[accelerate, peft, outlines]" --no-cache-dir pip install ".[accelerate, peft, outlines]" --no-cache-dir
ENV CCL_ROOT=/opt/intel/oneapi/ccl/latest ENV CCL_ROOT=/opt/intel/oneapi/ccl/latest
@ -75,13 +76,17 @@ ENV LIBRARY_PATH=/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/ccl/latest/l
ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64: ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:
ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
ENV CCL_ZE_IPC_EXCHANGE=sockets ENV CCL_ZE_IPC_EXCHANGE=sockets
ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest
ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include
RUN pip uninstall -y intel-extension-for-pytorch && cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch
# Install benchmarker # Install benchmarker
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router # Install router
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
# Install launcher # Install launcher
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
# Final image # Final image
FROM base FROM base

View File

@ -1,12 +1,8 @@
install-server: install-server:
cd server && make install cd server && make install
install-custom-kernels: install-server-cpu:
if [ "$$BUILD_EXTENSIONS" = "True" ]; then cd server/custom_kernels && python setup.py install; else echo "Custom kernels are disabled, you need to set the BUILD_EXTENSIONS environment variable to 'True' in order to build them. (Please read the docs, kernels might not work on all hardware)"; fi cd server && make install-server
install-integration-tests:
cd integration-tests && pip install -r requirements.txt
cd clients/python && pip install .
install-router: install-router:
cd router && cargo install --path . cd router && cargo install --path .
@ -17,7 +13,10 @@ install-launcher:
install-benchmark: install-benchmark:
cd benchmark && cargo install --path . cd benchmark && cargo install --path .
install: install-server install-router install-launcher install-custom-kernels install: install-server install-router install-launcher
install-cpu: install-server-cpu install-router install-launcher
server-dev: server-dev:
cd server && make run-dev cd server && make run-dev
@ -28,6 +27,10 @@ router-dev:
rust-tests: install-router install-launcher rust-tests: install-router install-launcher
cargo test cargo test
install-integration-tests:
cd integration-tests && pip install -r requirements.txt
cd clients/python && pip install .
integration-tests: install-integration-tests integration-tests: install-integration-tests
pytest -s -vv -m "not private" integration-tests pytest -s -vv -m "not private" integration-tests

View File

@ -497,7 +497,7 @@ fn statis_spans<'a>(data: &[f64], unit: &'static str) -> Vec<Line<'a>> {
"Lowest: {:.2} {unit}", "Lowest: {:.2} {unit}",
data.iter() data.iter()
.min_by(|a, b| a.total_cmp(b)) .min_by(|a, b| a.total_cmp(b))
.unwrap_or(&std::f64::NAN) .unwrap_or(&f64::NAN)
), ),
Style::default().fg(Color::Reset), Style::default().fg(Color::Reset),
)]), )]),
@ -506,7 +506,7 @@ fn statis_spans<'a>(data: &[f64], unit: &'static str) -> Vec<Line<'a>> {
"Highest: {:.2} {unit}", "Highest: {:.2} {unit}",
data.iter() data.iter()
.max_by(|a, b| a.total_cmp(b)) .max_by(|a, b| a.total_cmp(b))
.unwrap_or(&std::f64::NAN) .unwrap_or(&f64::NAN)
), ),
Style::default().fg(Color::Reset), Style::default().fg(Color::Reset),
)]), )]),
@ -555,17 +555,17 @@ fn latency_throughput_chart<'a>(
let min_latency: f64 = *latency_iter let min_latency: f64 = *latency_iter
.clone() .clone()
.min_by(|a, b| a.total_cmp(b)) .min_by(|a, b| a.total_cmp(b))
.unwrap_or(&std::f64::NAN); .unwrap_or(&f64::NAN);
let max_latency: f64 = *latency_iter let max_latency: f64 = *latency_iter
.max_by(|a, b| a.total_cmp(b)) .max_by(|a, b| a.total_cmp(b))
.unwrap_or(&std::f64::NAN); .unwrap_or(&f64::NAN);
let min_throughput: f64 = *throughput_iter let min_throughput: f64 = *throughput_iter
.clone() .clone()
.min_by(|a, b| a.total_cmp(b)) .min_by(|a, b| a.total_cmp(b))
.unwrap_or(&std::f64::NAN); .unwrap_or(&f64::NAN);
let max_throughput: f64 = *throughput_iter let max_throughput: f64 = *throughput_iter
.max_by(|a, b| a.total_cmp(b)) .max_by(|a, b| a.total_cmp(b))
.unwrap_or(&std::f64::NAN); .unwrap_or(&f64::NAN);
// Char min max values // Char min max values
let min_x = if zoom { let min_x = if zoom {

View File

@ -1,8 +1,9 @@
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use text_generation_client::{ use text_generation_client::v3::{
Batch, CachedBatch, ClientError, NextTokenChooserParameters, Request, ShardedClient, Batch, CachedBatch, NextTokenChooserParameters, Request, ShardedClient,
StoppingCriteriaParameters, StoppingCriteriaParameters,
}; };
use text_generation_client::{Chunk, ClientError, Input};
use tokenizers::{Tokenizer, TruncationDirection}; use tokenizers::{Tokenizer, TruncationDirection};
use tokio::sync::{broadcast, mpsc}; use tokio::sync::{broadcast, mpsc};
@ -142,6 +143,9 @@ async fn prefill(
.map(|id| Request { .map(|id| Request {
id: id.into(), id: id.into(),
prefill_logprobs: false, prefill_logprobs: false,
input_chunks: Some(Input {
chunks: vec![Chunk::Text(sequence.clone()).into()],
}),
inputs: sequence.clone(), inputs: sequence.clone(),
truncate: sequence_length, truncate: sequence_length,
parameters: Some(parameters.clone()), parameters: Some(parameters.clone()),
@ -151,6 +155,8 @@ async fn prefill(
ignore_eos_token: true, // Will not stop even if a eos token is generated ignore_eos_token: true, // Will not stop even if a eos token is generated
}), }),
top_n_tokens: top_n_tokens.unwrap_or(0), top_n_tokens: top_n_tokens.unwrap_or(0),
blocks: vec![],
slots: vec![],
}) })
.collect(); .collect();
@ -159,6 +165,7 @@ async fn prefill(
requests, requests,
size: batch_size, size: batch_size,
max_tokens: batch_size * (sequence_length + decode_length), max_tokens: batch_size * (sequence_length + decode_length),
max_blocks: 0,
}; };
// Run prefill // Run prefill

View File

@ -8,7 +8,7 @@ use crate::app::App;
use crate::event::Event; use crate::event::Event;
use crossterm::ExecutableCommand; use crossterm::ExecutableCommand;
use std::io; use std::io;
use text_generation_client::{GrammarType, NextTokenChooserParameters, ShardedClient}; use text_generation_client::v3::{GrammarType, NextTokenChooserParameters, ShardedClient};
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use tokio::sync::{broadcast, mpsc}; use tokio::sync::{broadcast, mpsc};
use tui::backend::CrosstermBackend; use tui::backend::CrosstermBackend;

View File

@ -4,7 +4,7 @@
/// and: https://github.com/orhun/rust-tui-template /// and: https://github.com/orhun/rust-tui-template
use clap::Parser; use clap::Parser;
use std::path::Path; use std::path::Path;
use text_generation_client::ShardedClient; use text_generation_client::v3::ShardedClient;
use tokenizers::{FromPretrainedParameters, Tokenizer}; use tokenizers::{FromPretrainedParameters, Tokenizer};
use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::util::SubscriberInitExt;

View File

@ -156,17 +156,17 @@ fn avg_min_max(data: &[f64]) -> (f64, f64, f64) {
let min = data let min = data
.iter() .iter()
.min_by(|a, b| a.total_cmp(b)) .min_by(|a, b| a.total_cmp(b))
.unwrap_or(&std::f64::NAN); .unwrap_or(&f64::NAN);
let max = data let max = data
.iter() .iter()
.max_by(|a, b| a.total_cmp(b)) .max_by(|a, b| a.total_cmp(b))
.unwrap_or(&std::f64::NAN); .unwrap_or(&f64::NAN);
(average, *min, *max) (average, *min, *max)
} }
fn px(data: &[f64], p: u32) -> f64 { fn px(data: &[f64], p: u32) -> f64 {
let i = (f64::from(p) / 100.0 * data.len() as f64) as usize; let i = (f64::from(p) / 100.0 * data.len() as f64) as usize;
*data.get(i).unwrap_or(&std::f64::NAN) *data.get(i).unwrap_or(&f64::NAN)
} }
fn format_value(value: f64, unit: &'static str) -> String { fn format_value(value: f64, unit: &'static str) -> String {

View File

@ -37,7 +37,7 @@ pub(crate) fn percentiles(values: &[f64], pecents: &[i32]) -> BTreeMap<String, f
.iter() .iter()
.map(|&p| { .map(|&p| {
let i = (f64::from(p) / 100.0 * values.len() as f64) as usize; let i = (f64::from(p) / 100.0 * values.len() as f64) as usize;
(format!("p{p}"), *values.get(i).unwrap_or(&std::f64::NAN)) (format!("p{p}"), *values.get(i).unwrap_or(&f64::NAN))
}) })
.collect() .collect()
} }

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
__version__ = "0.6.0" __version__ = "0.7.0"
DEPRECATION_WARNING = ( DEPRECATION_WARNING = (
"`text_generation` clients are deprecated and will be removed in the near future. " "`text_generation` clients are deprecated and will be removed in the near future. "

View File

@ -13,6 +13,9 @@ from text_generation.types import (
Request, Request,
Parameters, Parameters,
Grammar, Grammar,
CompletionRequest,
Completion,
CompletionComplete,
ChatRequest, ChatRequest,
ChatCompletionChunk, ChatCompletionChunk,
ChatComplete, ChatComplete,
@ -70,6 +73,94 @@ class Client:
self.cookies = cookies self.cookies = cookies
self.timeout = timeout self.timeout = timeout
def completion(
self,
prompt: str,
frequency_penalty: Optional[float] = None,
max_tokens: Optional[int] = None,
repetition_penalty: Optional[float] = None,
seed: Optional[int] = None,
stream: bool = False,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
stop: Optional[List[str]] = None,
):
"""
Given a prompt, generate a response synchronously
Args:
prompt (`str`):
Prompt
frequency_penalty (`float`):
The parameter for frequency penalty. 0.0 means no penalty
Penalize new tokens based on their existing frequency in the text so far,
decreasing the model's likelihood to repeat the same line verbatim.
max_tokens (`int`):
Maximum number of generated tokens
repetition_penalty (`float`):
The parameter for frequency penalty. 0.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
seed (`int`):
Random sampling seed
stream (`bool`):
Stream the response
temperature (`float`):
The value used to module the logits distribution.
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation
stop (`List[str]`):
Stop generating tokens if a member of `stop` is generated
"""
request = CompletionRequest(
model="tgi",
prompt=prompt,
frequency_penalty=frequency_penalty,
max_tokens=max_tokens,
repetition_penalty=repetition_penalty,
seed=seed,
stream=stream,
temperature=temperature,
top_p=top_p,
stop=stop,
)
if not stream:
resp = requests.post(
f"{self.base_url}/v1/completions",
json=request.dict(),
headers=self.headers,
cookies=self.cookies,
timeout=self.timeout,
)
payload = resp.json()
if resp.status_code != 200:
raise parse_error(resp.status_code, payload)
return Completion(**payload)
else:
return self._completion_stream_response(request)
def _completion_stream_response(self, request):
resp = requests.post(
f"{self.base_url}/v1/completions",
json=request.dict(),
headers=self.headers,
cookies=self.cookies,
timeout=self.timeout,
stream=True,
)
# iterate and print stream
for byte_payload in resp.iter_lines():
if byte_payload == b"\n":
continue
payload = byte_payload.decode("utf-8")
if payload.startswith("data:"):
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
try:
response = CompletionComplete(**json_payload)
yield response
except ValidationError:
raise parse_error(resp.status, json_payload)
def chat( def chat(
self, self,
messages: List[Message], messages: List[Message],
@ -88,6 +179,7 @@ class Client:
tools: Optional[List[Tool]] = None, tools: Optional[List[Tool]] = None,
tool_prompt: Optional[str] = None, tool_prompt: Optional[str] = None,
tool_choice: Optional[str] = None, tool_choice: Optional[str] = None,
stop: Optional[List[str]] = None,
): ):
""" """
Given a list of messages, generate a response asynchronously Given a list of messages, generate a response asynchronously
@ -130,6 +222,8 @@ class Client:
A prompt to be appended before the tools A prompt to be appended before the tools
tool_choice (`str`): tool_choice (`str`):
The tool to use The tool to use
stop (`List[str]`):
Stop generating tokens if a member of `stop` is generated
""" """
request = ChatRequest( request = ChatRequest(
@ -150,6 +244,7 @@ class Client:
tools=tools, tools=tools,
tool_prompt=tool_prompt, tool_prompt=tool_prompt,
tool_choice=tool_choice, tool_choice=tool_choice,
stop=stop,
) )
if not stream: if not stream:
resp = requests.post( resp = requests.post(
@ -461,6 +556,93 @@ class AsyncClient:
self.cookies = cookies self.cookies = cookies
self.timeout = ClientTimeout(timeout) self.timeout = ClientTimeout(timeout)
async def completion(
self,
prompt: str,
frequency_penalty: Optional[float] = None,
max_tokens: Optional[int] = None,
repetition_penalty: Optional[float] = None,
seed: Optional[int] = None,
stream: bool = False,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
stop: Optional[List[str]] = None,
) -> Union[Completion, AsyncIterator[CompletionComplete]]:
"""
Given a prompt, generate a response asynchronously
Args:
prompt (`str`):
Prompt
frequency_penalty (`float`):
The parameter for frequency penalty. 0.0 means no penalty
Penalize new tokens based on their existing frequency in the text so far,
decreasing the model's likelihood to repeat the same line verbatim.
max_tokens (`int`):
Maximum number of generated tokens
repetition_penalty (`float`):
The parameter for frequency penalty. 0.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
seed (`int`):
Random sampling seed
stream (`bool`):
Stream the response
temperature (`float`):
The value used to module the logits distribution.
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation
stop (`List[str]`):
Stop generating tokens if a member of `stop` is generated
"""
request = CompletionRequest(
model="tgi",
prompt=prompt,
frequency_penalty=frequency_penalty,
max_tokens=max_tokens,
repetition_penalty=repetition_penalty,
seed=seed,
stream=stream,
temperature=temperature,
top_p=top_p,
stop=stop,
)
if not stream:
return await self._completion_single_response(request)
else:
return self._completion_stream_response(request)
async def _completion_single_response(self, request):
async with ClientSession(
headers=self.headers, cookies=self.cookies, timeout=self.timeout
) as session:
async with session.post(
f"{self.base_url}/v1/completions", json=request.dict()
) as resp:
payload = await resp.json()
if resp.status != 200:
raise parse_error(resp.status, payload)
return Completion(**payload)
async def _completion_stream_response(self, request):
async with ClientSession(
headers=self.headers, cookies=self.cookies, timeout=self.timeout
) as session:
async with session.post(
f"{self.base_url}/v1/completions", json=request.dict()
) as resp:
async for byte_payload in resp.content:
if byte_payload == b"\n":
continue
payload = byte_payload.decode("utf-8")
if payload.startswith("data:"):
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
try:
response = CompletionComplete(**json_payload)
yield response
except ValidationError:
raise parse_error(resp.status, json_payload)
async def chat( async def chat(
self, self,
messages: List[Message], messages: List[Message],
@ -479,6 +661,7 @@ class AsyncClient:
tools: Optional[List[Tool]] = None, tools: Optional[List[Tool]] = None,
tool_prompt: Optional[str] = None, tool_prompt: Optional[str] = None,
tool_choice: Optional[str] = None, tool_choice: Optional[str] = None,
stop: Optional[List[str]] = None,
) -> Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]: ) -> Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]:
""" """
Given a list of messages, generate a response asynchronously Given a list of messages, generate a response asynchronously
@ -521,6 +704,8 @@ class AsyncClient:
A prompt to be appended before the tools A prompt to be appended before the tools
tool_choice (`str`): tool_choice (`str`):
The tool to use The tool to use
stop (`List[str]`):
Stop generating tokens if a member of `stop` is generated
""" """
request = ChatRequest( request = ChatRequest(
@ -541,6 +726,7 @@ class AsyncClient:
tools=tools, tools=tools,
tool_prompt=tool_prompt, tool_prompt=tool_prompt,
tool_choice=tool_choice, tool_choice=tool_choice,
stop=stop,
) )
if not stream: if not stream:
return await self._chat_single_response(request) return await self._chat_single_response(request)

View File

@ -46,30 +46,6 @@ class Tool(BaseModel):
function: dict function: dict
class ChatCompletionComplete(BaseModel):
# Index of the chat completion
index: int
# Message associated with the chat completion
message: Message
# Log probabilities for the chat completion
logprobs: Optional[Any]
# Reason for completion
finish_reason: str
# Usage details of the chat completion
usage: Optional[Any] = None
class CompletionComplete(BaseModel):
# Index of the chat completion
index: int
# Message associated with the chat completion
text: str
# Log probabilities for the chat completion
logprobs: Optional[Any]
# Reason for completion
finish_reason: str
class Function(BaseModel): class Function(BaseModel):
name: Optional[str] name: Optional[str]
arguments: str arguments: str
@ -95,24 +71,41 @@ class Choice(BaseModel):
finish_reason: Optional[str] = None finish_reason: Optional[str] = None
class ChatCompletionChunk(BaseModel): class CompletionRequest(BaseModel):
id: str # Model identifier
object: str
created: int
model: str model: str
system_fingerprint: str # Prompt
choices: List[Choice] prompt: str
# The parameter for repetition penalty. 1.0 means no penalty.
# See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
repetition_penalty: Optional[float] = None
# The parameter for frequency penalty. 1.0 means no penalty
# Penalize new tokens based on their existing frequency in the text so far,
# decreasing the model's likelihood to repeat the same line verbatim.
frequency_penalty: Optional[float] = None
# Maximum number of tokens to generate
max_tokens: Optional[int] = None
# Flag to indicate streaming response
stream: bool = False
# Random sampling seed
seed: Optional[int] = None
# Sampling temperature
temperature: Optional[float] = None
# Top-p value for nucleus sampling
top_p: Optional[float] = None
# Stop generating tokens if a member of `stop` is generated
stop: Optional[List[str]] = None
class ChatComplete(BaseModel): class CompletionComplete(BaseModel):
# Chat completion details # Index of the chat completion
id: str index: int
object: str # Message associated with the chat completion
created: int text: str
model: str # Log probabilities for the chat completion
system_fingerprint: str logprobs: Optional[Any]
choices: List[ChatCompletionComplete] # Reason for completion
usage: Any finish_reason: str
class Completion(BaseModel): class Completion(BaseModel):
@ -163,6 +156,41 @@ class ChatRequest(BaseModel):
tool_prompt: Optional[str] = None tool_prompt: Optional[str] = None
# Choice of tool to be used # Choice of tool to be used
tool_choice: Optional[str] = None tool_choice: Optional[str] = None
# Stop generating tokens if a member of `stop` is generated
stop: Optional[List[str]] = None
class ChatCompletionComplete(BaseModel):
# Index of the chat completion
index: int
# Message associated with the chat completion
message: Message
# Log probabilities for the chat completion
logprobs: Optional[Any]
# Reason for completion
finish_reason: str
# Usage details of the chat completion
usage: Optional[Any] = None
class ChatComplete(BaseModel):
# Chat completion details
id: str
object: str
created: int
model: str
system_fingerprint: str
choices: List[ChatCompletionComplete]
usage: Any
class ChatCompletionChunk(BaseModel):
id: str
object: str
created: int
model: str
system_fingerprint: str
choices: List[Choice]
class Parameters(BaseModel): class Parameters(BaseModel):

10
docs/README.md Normal file
View File

@ -0,0 +1,10 @@
Documentation available at: https://huggingface.co/docs/text-generation-inference
## Release
When making a release, please update the latest version in the documentation with:
```
export OLD_VERSION="2\.0\.3"
export NEW_VERSION="2\.0\.4"
find . -name '*.md' -exec sed -i -e "s/$OLD_VERSION/$NEW_VERSION/g" {} \;
```

View File

@ -1121,6 +1121,15 @@
"description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.", "description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.",
"example": 0.95, "example": 0.95,
"nullable": true "nullable": true
},
"stop": {
"type": "array",
"items": {
"type": "string"
},
"description": "Up to 4 sequences where the API will stop generating further tokens.",
"example": "null",
"nullable": true
} }
} }
}, },

View File

@ -17,6 +17,8 @@
title: Supported Models and Hardware title: Supported Models and Hardware
- local: messages_api - local: messages_api
title: Messages API title: Messages API
- local: architecture
title: Internal Architecture
title: Getting started title: Getting started
- sections: - sections:
- local: basic_tutorials/consuming_tgi - local: basic_tutorials/consuming_tgi
@ -39,6 +41,8 @@
title: Visual Language Models title: Visual Language Models
- local: basic_tutorials/monitoring - local: basic_tutorials/monitoring
title: Monitoring TGI with Prometheus and Grafana title: Monitoring TGI with Prometheus and Grafana
- local: basic_tutorials/train_medusa
title: Train Medusa
title: Tutorials title: Tutorials
- sections: - sections:
- local: conceptual/streaming - local: conceptual/streaming

227
docs/source/architecture.md Normal file
View File

@ -0,0 +1,227 @@
# Text Generation Inference Architecture
This document aims at describing the architecture of Text Generation Inference (TGI), by describing the call flow between the separate components.
A high-level architecture diagram can be seen here:
![TGI architecture](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/TGI.png)
This diagram shows well there are these separate components:
- **The router**, also named `webserver`, that receives the client requests, buffers them, creates some batches, and prepares gRPC calls to a model server.
- **The model server**, responsible of receiving the gRPC requests and to process the inference on the model. If the model is sharded across multiple accelerators (e.g.: multiple GPUs), the model server shards might be synchronized via NCCL or equivalent.
- **The launcher** is a helper thar will be able to launch one or several model servers (if model is sharded), and it launches the router with the compatible arguments.
The router and the model server can be two different machines, they do not need to be deployed together.
## The Router
This component is a rust web server binary that accepts HTTP requests using the custom [HTTP API](https://huggingface.github.io/text-generation-inference/), as well as OpenAI's [Messages API](https://huggingface.co/docs/text-generation-inference/messages_api).
The router receives the API calls and handles the "baches" logic (and introduction to batching can be found [here](https://github.com/huggingface/text-generation-inference/blob/main/router/README.md)).
It uses different strategies to reduce latency between requests and responses, especially oriented to decoding latency. It will use queues, schedulers, and block allocators to achieve that and produce batched requests that it will then be sent to the model server.
### Router's command line
The router command line will be the way to pass parameters to it (it does not rely on configuration file):
```
Text Generation Webserver
Usage: text-generation-router [OPTIONS]
Options:
--max-concurrent-requests <MAX_CONCURRENT_REQUESTS>
[env: MAX_CONCURRENT_REQUESTS=] [default: 128]
--max-best-of <MAX_BEST_OF>
[env: MAX_BEST_OF=] [default: 2]
--max-stop-sequences <MAX_STOP_SEQUENCES>
[env: MAX_STOP_SEQUENCES=] [default: 4]
--max-top-n-tokens <MAX_TOP_N_TOKENS>
[env: MAX_TOP_N_TOKENS=] [default: 5]
--max-input-tokens <MAX_INPUT_TOKENS>
[env: MAX_INPUT_TOKENS=] [default: 1024]
--max-total-tokens <MAX_TOTAL_TOKENS>
[env: MAX_TOTAL_TOKENS=] [default: 2048]
--waiting-served-ratio <WAITING_SERVED_RATIO>
[env: WAITING_SERVED_RATIO=] [default: 1.2]
--max-batch-prefill-tokens <MAX_BATCH_PREFILL_TOKENS>
[env: MAX_BATCH_PREFILL_TOKENS=] [default: 4096]
--max-batch-total-tokens <MAX_BATCH_TOTAL_TOKENS>
[env: MAX_BATCH_TOTAL_TOKENS=]
--max-waiting-tokens <MAX_WAITING_TOKENS>
[env: MAX_WAITING_TOKENS=] [default: 20]
--max-batch-size <MAX_BATCH_SIZE>
[env: MAX_BATCH_SIZE=]
--hostname <HOSTNAME>
[env: HOSTNAME=] [default: 0.0.0.0]
-p, --port <PORT>
[env: PORT=] [default: 3000]
--master-shard-uds-path <MASTER_SHARD_UDS_PATH>
[env: MASTER_SHARD_UDS_PATH=] [default: /tmp/text-generation-server-0]
--tokenizer-name <TOKENIZER_NAME>
[env: TOKENIZER_NAME=] [default: bigscience/bloom]
--tokenizer-config-path <TOKENIZER_CONFIG_PATH>
[env: TOKENIZER_CONFIG_PATH=]
--revision <REVISION>
[env: REVISION=]
--validation-workers <VALIDATION_WORKERS>
[env: VALIDATION_WORKERS=] [default: 2]
--json-output
[env: JSON_OUTPUT=]
--otlp-endpoint <OTLP_ENDPOINT>
[env: OTLP_ENDPOINT=]
--cors-allow-origin <CORS_ALLOW_ORIGIN>
[env: CORS_ALLOW_ORIGIN=]
--ngrok
[env: NGROK=]
--ngrok-authtoken <NGROK_AUTHTOKEN>
[env: NGROK_AUTHTOKEN=]
--ngrok-edge <NGROK_EDGE>
[env: NGROK_EDGE=]
--messages-api-enabled
[env: MESSAGES_API_ENABLED=]
--disable-grammar-support
[env: DISABLE_GRAMMAR_SUPPORT=]
--max-client-batch-size <MAX_CLIENT_BATCH_SIZE>
[env: MAX_CLIENT_BATCH_SIZE=] [default: 4]
-h, --help
Print help
-V, --version
Print version
```
## The Model Server
The model server is a python server, capable of starting a server waiting for gRPC requests, loads a given model, perform sharding to provide [tensor parallelism](https://huggingface.co/docs/text-generation-inference/conceptual/tensor_parallelism), and stays alive while waiting for new requests.
The model server supports models instantiated using Pytorch and optimized for inference mainly on CUDA/ROCM.
### Model Server Variants
Several variants of the model server exist that are actively supported by Hugging Face:
- By default, the model server will attempt building [a server optimized for Nvidia GPUs with CUDA](https://huggingface.co/docs/text-generation-inference/installation_nvidia). The code for this version is hosted in the [main TGI repository](https://github.com/huggingface/text-generation-inference).
- A [version optimized for AMD with ROCm](https://huggingface.co/docs/text-generation-inference/installation_amd) is hosted in the main TGI repository. Some model features differ.
- The [version for Intel Gaudi](https://huggingface.co/docs/text-generation-inference/installation_gaudi) is maintained on a forked repository, often resynchronized with the main [TGI repository](https://github.com/huggingface/tgi-gaudi).
- A [version for Neuron (AWS Inferentia2)](https://huggingface.co/docs/text-generation-inference/installation_inferentia) is maintained as part of [Optimum Neuron](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference).
- A version for Google TPUs is maintained as part of [Optimum TPU](https://github.com/huggingface/optimum-tpu/tree/main/text-generation-inference).
Not all variants provide the same features, as hardware and middleware capabilities do not provide the same optimizations.
### Command Line Interface
The official command line interface (CLI) for the server supports three subcommands, `download-weights`, `quantize` and `serve`:
- `download-weights` will download weights from the hub and, in some variants it will convert weights to a format that is adapted to the given implementation;
- `quantize` will allow to quantize a model using the `qptq` package. This feature is not available nor supported on all variants;
- `serve` will start the server that load a model (or a model shard), receives gRPC calls from the router, performs an inference and provides a formatted response to the given request.
Serve's command line parameters on the TGI repository are these:
```
Usage: cli.py serve [OPTIONS] MODEL_ID
╭─ Arguments ──────────────────────────────────────────────────────────────────────────────────────────────╮
│ * model_id TEXT [default: None] [required] │
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────╯
╭─ Options ────────────────────────────────────────────────────────────────────────────────────────────────╮
│ --revision TEXT [default: None] │
│ --sharded --no-sharded [default: no-sharded] │
│ --quantize [bitsandbytes|bitsandbytes [default: None] │
│ -nf4|bitsandbytes-fp4|gptq │
│ |awq|eetq|exl2|fp8] │
│ --speculate INTEGER [default: None] │
│ --dtype [float16|bfloat16] [default: None] │
│ --trust-remote-code --no-trust-remote-code [default: │
│ no-trust-remote-code] │
│ --uds-path PATH [default: │
│ /tmp/text-generation-serve… │
│ --logger-level TEXT [default: INFO] │
│ --json-output --no-json-output [default: no-json-output] │
│ --otlp-endpoint TEXT [default: None] │
│ --help Show this message and exit. │
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────╯
```
Note that some variants might support different parameters, and they could possibly accept more options that can be passed on using environment variables.
## Call Flow
Once both components are initialized, weights downloaded and model server is up and running, router and model server exchange data and info through the gRPC call. There are currently two supported schemas, [v2](https://github.com/huggingface/text-generation-inference/blob/main/proto/generate.proto) and [v3](https://github.com/huggingface/text-generation-inference/blob/main/proto/v3/generate.proto). These two versions are almost identical, except for:
- input chunks support, for text and image data,
- paged attention support
Here's a diagram that displays the exchanges that follow the router and model server startup.
```mermaid
sequenceDiagram
Router->>Model Server: service discovery
Model Server-->>Router: urls for other shards
Router->>Model Server: get model info
Model Server-->>Router: shard info
Router->>Model Server: health check
Model Server-->>Router: health OK
Router->>Model Server: warmup(max_input_tokens, max_batch_prefill_tokens, max_total_tokens, max_batch_size)
Model Server-->>Router: warmup result
```
After these are done, the router is ready to receive generate calls from multiple clients. Here's an example.
```mermaid
sequenceDiagram
participant Client 1
participant Client 2
participant Client 3
participant Router
participant Model Server
Client 1->>Router: generate_stream
Router->>Model Server: prefill(batch1)
Model Server-->>Router: generations, cached_batch1, timings
Router-->>Client 1: token 1
Router->>Model Server: decode(cached_batch1)
Model Server-->>Router: generations, cached_batch1, timings
Router-->>Client 1: token 2
Router->>Model Server: decode(cached_batch1)
Model Server-->>Router: generations, cached_batch1, timings
Router-->>Client 1: token 3
Client 2->>Router: generate_stream
Router->>Model Server: prefill(batch2)
Note right of Model Server: This stops previous batch, that is restarted
Model Server-->>Router: generations, cached_batch2, timings
Router-->>Client 2: token 1'
Router->>Model Server: decode(cached_batch1, cached_batch2)
Model Server-->>Router: generations, cached_batch1, timings
Router-->>Client 1: token 4
Router-->>Client 2: token 2'
Note left of Client 1: Client 1 leaves
Router->>Model Server: filter_batch(cached_batch1, request_ids_to_keep=batch2)
Model Server-->>Router: filtered batch
Router->>Model Server: decode(cached_batch2)
Model Server-->>Router: generations, cached_batch2, timings
Router-->>Client 2: token 3'
Client 3->>Router: generate_stream
Note right of Model Server: This stops previous batch, that is restarted
Router->>Model Server: prefill(batch3)
Note left of Client 1: Client 3 leaves without receiving any batch
Router->>Model Server: clear_cache(batch3)
Note right of Model Server: This stops previous batch, that is restarted
Router->>Model Server: decode(cached_batch3)
Note right of Model Server: Last token (stopping criteria)
Model Server-->>Router: generations, cached_batch3, timings
Router-->>Client 2: token 4'
```

View File

@ -19,6 +19,6 @@ docker run --gpus all \
--shm-size 1g \ --shm-size 1g \
-e HUGGING_FACE_HUB_TOKEN=$token \ -e HUGGING_FACE_HUB_TOKEN=$token \
-p 8080:80 \ -p 8080:80 \
-v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0.3 \ -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0.4 \
--model-id $model --model-id $model
``` ```

View File

@ -62,7 +62,9 @@ Options:
Possible values: Possible values:
- awq: 4 bit quantization. Requires a specific AWQ quantized model: <https://hf.co/models?search=awq>. Should replace GPTQ models wherever possible because of the better latency - awq: 4 bit quantization. Requires a specific AWQ quantized model: <https://hf.co/models?search=awq>. Should replace GPTQ models wherever possible because of the better latency
- eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from <https://github.com/NetEase-FuXi/EETQ.git> - eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
- exl2: Variable bit quantization. Requires a specific EXL2 quantized model: <https://hf.co/models?search=exl2>. Requires exllama2 kernels and does not support tensor parallelism (num_shard > 1)
- gptq: 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>. text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels - gptq: 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>. text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels
- marlin: 4 bit quantization. Requires a specific Marlin quantized model: <https://hf.co/models?search=marlin>
- bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16 - bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16
- bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16 - bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16
- bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for you model - bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for you model
@ -141,7 +143,7 @@ Options:
## MAX_TOP_N_TOKENS ## MAX_TOP_N_TOKENS
```shell ```shell
--max-top-n-tokens <MAX_TOP_N_TOKENS> --max-top-n-tokens <MAX_TOP_N_TOKENS>
This is the maximum allowed value for clients to set `top_n_tokens`. `top_n_tokens is used to return information about the the `n` most likely tokens at each generation step, instead of just the sampled token. This information can be used for downstream tasks like for classification or ranking This is the maximum allowed value for clients to set `top_n_tokens`. `top_n_tokens` is used to return information about the the `n` most likely tokens at each generation step, instead of just the sampled token. This information can be used for downstream tasks like for classification or ranking
[env: MAX_TOP_N_TOKENS=] [env: MAX_TOP_N_TOKENS=]
[default: 5] [default: 5]

View File

@ -0,0 +1,208 @@
# Train Medusa
This tutorial will show you how to train a Medusa model on a dataset of your choice. Please check out the [speculation documentation](../conceptual/speculation) for more information on how Medusa works and speculation in general.
## What are the benefits of training a Medusa model?
Training Medusa heads can greatly improve the speed of generation. Medusa adds extra "heads" to LLMs to predict multiple future tokens simultaneously. When augmenting a model with Medusa, the original model stays untouched, and only the new heads are fine-tuned during training.
One of the most important things is to have a good dataset (with similar data to what will be used in production) because Medusa has a much higher hit-rate when the generation is in-domain.
If you train Medusa on a dataset that is very different from the one you will use in production then the model will not be able to predict the future tokens accurately and consequently the speedup will be minimal or non-existent.
## Self-distillation (Generating data for training)
There are many methods for preparing data for training, but one of the easiest and most effective ways is to "self-distill" the data. This means that you can use the same model to generate the data that you will use to train the model.
Essentially, you prompt the model with a similar input to what you will use in production and the model will generate the output.
We'll use this output to help train the medusa heads to predict the `n+1`, `n+2`, `n+3`, etc tokens in the sequence.
## Training
The original implementation of Medusa is available at [https://github.com/FasterDecoding/Medusa](https://github.com/FasterDecoding/Medusa) and we'll follow a very similar process to train the model as described on the original repository.
### Getting Started
There are two methods for training the model:
- `torchrun` that is a wrapper around `torch.distributed.launch`
- a forked version of `axlotl` that supports Medusa
In this tutorial we'll use `torchrun` to train the model as it is the most straightforward way to train the model but similar steps can be followed to train the model using `axlotl` if you prefer.
### Training with `torchrun`
```bash
mkdir medusa-training
cd medusa-training
pyenv install 3.10
pyenv local 3.10
uv venv -p 3.10
source .venv/bin/activate
```
Now lets clone the original `Medusa` repository and install the library.
```bash
git clone https://github.com/FasterDecoding/Medusa.git
cd Medusa
pip install -e .
```
Next we'll need some data to train on, we can use the `ShareGPT_Vicuna_unfiltered` dataset that is available on the Hugging Face Hub.
```bash
apt install git-lfs
git lfs install
git clone https://huggingface.co/datasets/Aeala/ShareGPT_Vicuna_unfiltered
```
Currently our directory structure looks like this:
```bash
.
├── assets
├── CITATION.cff
├── create_data.py
├── data_generation
├── deepspeed.json
├── last_run_prepared
├── LICENSE
├── llm_judge
├── medusa
├── medusa_llm.egg-info
├── mistral.json
├── notebooks
├── pyproject.toml
├── README.md
├── ROADMAP.md
├── scripts
├── ShareGPT_Vicuna_unfiltered
│   ├── README.md
│   ├── ShareGPT_2023.05.04v0_Wasteland_Edition.json
│   └── ShareGPT_V4.3_unfiltered_cleaned_split.json
├── simple_gradio_interface.py
├── tiny-llama.json
└── vicuna_7b_qlora_stage1
```
## Start Training
Now the lets generate the data and start training the model. This process will take a while since we are generating data from the model.
First make sure you have an instance of TGI running with the model you want to use for self-distillation.
```bash
model=HuggingFaceH4/zephyr-7b-beta
volume=/home/ubuntu/.cache/huggingface/hub/
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model
```
Now we can generate the data using the `create_data.py` script.
```bash
python create_data.py \
--input-filename ShareGPT_Vicuna_unfiltered/ShareGPT_V4.3_unfiltered_cleaned_split.json \
--output-filename zephyr_self_distill.json
```
At this point our terminal should look like this:
<div class="flex justify-center">
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/medusa-train-large.gif"
width="550"
/>
</div>
> Note: In the screen shot above we are only using a the first 500 examples from the dataset to speed up the process, you should have a much larger dataset for training.
Now we can finally get to the fun part and start training the model!
Using `torchrun` we can easily launch the `medusa` training script with the `zephyr_self_distill.json` configuration file.
> NOTE: If you just self-distilled you may still have the model running, make sure to stop it before starting the training in order to allow all of the resources to be used for training.
```bash
WANDB_MODE=offline torchrun --nproc_per_node=4 medusa/train/train_legacy.py \
--model_name_or_path HuggingFaceH4/zephyr-7b-beta \
--data_path zephyr_self_distill.json \
--bf16 True \
--output_dir zephyr_out \
--num_train_epochs 5 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 4 \
--evaluation_strategy "no" \
--save_strategy "no" \
--learning_rate 1e-3 \
--weight_decay 0.0 \
--warmup_ratio 0.1 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 2048 \
--lazy_preprocess True \
--medusa_num_heads 3 \
--medusa_num_layers 1 \
--deepspeed deepspeed.json
```
<div class="flex justify-center">
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/medusa-train-heads-large.gif"
width="550"
/>
</div>
If successful, you should see the similar output to the one below:
```bash
wandb: Run history:
wandb: train/epoch ▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███
wandb: train/global_step ▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███
wandb: train/learning_rate ▅███▇▇▆▅▅▄▃▂▂▁▁▁
wandb: train/loss ██▆▄▄▃▃▂▂▃▁▁▂▁▁▁
wandb: train/medusa0_loss ▆▆▇▆▆▅▄▅▃▃▃▃▂▂▂▂▂▃▂▂▂▁▁▁▂▁▁▁▁▁█▁▁▁▂▁▁▁▁▁
wandb: train/medusa0_top1 ▁▁▁▁▁▁▁▁▃▂▃▃▄▄▄▃▄▃▄▄▅▅▆▅▆▆▇▅▇▇▄▇█▇▅▇█▆▇▇
wandb: train/medusa1_loss ▇▇█▇▇▆▅▅▃▄▃▃▃▃▃▃▃▃▃▃▂▁▂▂▂▁▁▂▁▁▇▁▁▁▂▁▁▁▁▁
wandb: train/medusa1_top1 ▁▁▁▁▁▁▁▁▃▂▃▃▃▄▄▃▃▂▃▃▅▅▆▄█▆▇▅▇▇▅█▇▇▅▇█▆▆▇
wandb: train/medusa2_loss ▃▃▄▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁█▁▁▁▂▁▁▁▁▁
wandb: train/medusa2_top1 ▁▁▁▂▁▁▁▁▂▂▃▃▃▄▄▃▃▂▃▃▅▆▅▄█▆▆▅▆▆▄█▇▇▄▇█▆▆▇
wandb: train/total_flos ▁
wandb: train/train_loss ▁
wandb: train/train_runtime ▁
wandb: train/train_samples_per_second ▁
wandb: train/train_steps_per_second ▁
wandb:
wandb: Run summary:
wandb: train/epoch 2.0
wandb: train/global_step 16
wandb: train/learning_rate 0.0
wandb: train/loss 14.8906
wandb: train/medusa0_loss 4.25
wandb: train/medusa0_top1 0.28809
wandb: train/medusa1_loss 4.8125
wandb: train/medusa1_top1 0.22727
wandb: train/medusa2_loss 5.5
wandb: train/medusa2_top1 0.17293
wandb: train/total_flos 0.0
wandb: train/train_loss 23.98242
wandb: train/train_runtime 396.9266
wandb: train/train_samples_per_second 2.519
wandb: train/train_steps_per_second 0.04
```
Last but most importantly, don't forget to push this model to the Hugging Face Hub so you can use it in your projects.
```bash
python -m medusa.hf_utils \
--folder zephyr_out_medusa_mlp_zephyr-7b-beta_medusa_3_lr_0.001_layers_1 \
--repo drbh/zephyr_medusa_demo
```
Woo, we've successfully trained a Medusa model and pushed it to the Hugging Face Hub! 🎉

View File

@ -2,11 +2,11 @@
## What is Guidance? ## What is Guidance?
Guidance is a feature that allows users to constrain the generation of a large language model with a specified grammar. This feature is particularly useful when you want to generate text that follows a specific structure or uses a specific set of words or produce output in a specific format. Guidance is a feature that allows users to constrain the generation of a large language model with a specified grammar. This feature is particularly useful when you want to generate text that follows a specific structure or uses a specific set of words or produce output in a specific format. A prominent example is JSON grammar, where the model is forced to output valid JSON.
## How is it used? ## How is it used?
Guidance can be in many ways and the community is always finding new ways to use it. Here are some examples of how you can use guidance: Guidance can be implemented in many ways and the community is always finding new ways to use it. Here are some examples of how you can use guidance:
Technically, guidance can be used to generate: Technically, guidance can be used to generate:

View File

@ -27,7 +27,7 @@ You can check a few existing fine-tunes for popular models:
- [text-generation-inference/Mistral-7B-Instruct-v0.2-medusa](https://huggingface.co/text-generation-inference/Mistral-7B-Instruct-v0.2-medusa) - [text-generation-inference/Mistral-7B-Instruct-v0.2-medusa](https://huggingface.co/text-generation-inference/Mistral-7B-Instruct-v0.2-medusa)
In order to create your own medusa heads for your own finetune, you should check own the original medusa repo. [https://github.com/FasterDecoding/Medusa](https://github.com/FasterDecoding/Medusa) In order to create your own medusa heads for your own finetune, you should check own the original medusa repo. [../basic_tutorials/train_medusa.md](../basic_tutorials/train_medusa.md)
In order to use medusa models in TGI, simply point to a medusa enabled model, and everything will load automatically. In order to use medusa models in TGI, simply point to a medusa enabled model, and everything will load automatically.

View File

@ -11,7 +11,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
--device=/dev/kfd --device=/dev/dri --group-add video \ --device=/dev/kfd --device=/dev/dri --group-add video \
--ipc=host --shm-size 256g --net host -v $volume:/data \ --ipc=host --shm-size 256g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.0.3-rocm \ ghcr.io/huggingface/text-generation-inference:2.0.4-rocm \
--model-id $model --model-id $model
``` ```
@ -27,7 +27,7 @@ TunableOp is enabled by default, the warmup may take 1-2 minutes. In case you wo
## Flash attention implementation ## Flash attention implementation
Two implementations of Flash Attention are available for ROCm, the first is [ROCm/flash-attention](https://github.com/ROCm/flash-attention) based on a [Composable Kernel](https://github.com/ROCm/composable_kernel) (CK) implementation, and the second is a [Triton implementation](https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/utils/flash_attn_triton.py). Two implementations of Flash Attention are available for ROCm, the first is [ROCm/flash-attention](https://github.com/ROCm/flash-attention) based on a [Composable Kernel](https://github.com/ROCm/composable_kernel) (CK) implementation, and the second is a [Triton implementation](https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/layers/attention/flash_attn_triton.py).
By default, the Composable Kernel implementation is used. However, the Triton implementation has slightly lower latency on MI250 and MI300, but requires a warmup which can be prohibitive as it needs to be done again for each new prompt length. If needed, FA Triton impelmentation can be enabled with `--env ROCM_USE_FLASH_ATTN_V2_TRITON="0"` when launching TGI's docker container. By default, the Composable Kernel implementation is used. However, the Triton implementation has slightly lower latency on MI250 and MI300, but requires a warmup which can be prohibitive as it needs to be done again for each new prompt length. If needed, FA Triton impelmentation can be enabled with `--env ROCM_USE_FLASH_ATTN_V2_TRITON="0"` when launching TGI's docker container.

View File

@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \ docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.0.3 \ ghcr.io/huggingface/text-generation-inference:2.0.4 \
--model-id $model --model-id $model
``` ```

View File

@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.0.3 \ ghcr.io/huggingface/text-generation-inference:2.0.4 \
--model-id $model --model-id $model
``` ```
@ -88,7 +88,7 @@ curl 127.0.0.1:8080/generate \
To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more. To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.
```bash ```bash
docker run ghcr.io/huggingface/text-generation-inference:2.0.3 --help docker run ghcr.io/huggingface/text-generation-inference:2.0.4 --help
``` ```
</Tip> </Tip>

View File

@ -20,7 +20,7 @@ Text Generation Inference enables serving optimized models on specific hardware
- [Baichuan](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat) - [Baichuan](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat)
- [Falcon](https://huggingface.co/tiiuae/falcon-7b-instruct) - [Falcon](https://huggingface.co/tiiuae/falcon-7b-instruct)
- [StarCoder 2](https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1) - [StarCoder 2](https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1)
- [Qwen 2](https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1) - [Qwen 2](https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f)
- [Opt](https://huggingface.co/facebook/opt-6.7b) - [Opt](https://huggingface.co/facebook/opt-6.7b)
- [T5](https://huggingface.co/google/flan-t5-xxl) - [T5](https://huggingface.co/google/flan-t5-xxl)
- [Galactica](https://huggingface.co/facebook/galactica-120b) - [Galactica](https://huggingface.co/facebook/galactica-120b)

View File

@ -7,9 +7,10 @@ import os
import docker import docker
import json import json
import math import math
import shutil
import tempfile
import time import time
import random import random
import re
from docker.errors import NotFound from docker.errors import NotFound
from typing import Optional, List, Dict from typing import Optional, List, Dict
@ -37,6 +38,7 @@ DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")
class ResponseComparator(JSONSnapshotExtension): class ResponseComparator(JSONSnapshotExtension):
rtol = 0.2 rtol = 0.2
ignore_logprob = False
def serialize( def serialize(
self, self,
@ -94,7 +96,10 @@ class ResponseComparator(JSONSnapshotExtension):
return ( return (
token.id == other.id token.id == other.id
and token.text == other.text and token.text == other.text
and math.isclose(token.logprob, other.logprob, rel_tol=self.rtol) and (
self.ignore_logprob
or math.isclose(token.logprob, other.logprob, rel_tol=self.rtol)
)
and token.special == other.special and token.special == other.special
) )
@ -104,8 +109,11 @@ class ResponseComparator(JSONSnapshotExtension):
prefill_token.id == other.id prefill_token.id == other.id
and prefill_token.text == other.text and prefill_token.text == other.text
and ( and (
math.isclose( self.ignore_logprob
prefill_token.logprob, other.logprob, rel_tol=self.rtol or math.isclose(
prefill_token.logprob,
other.logprob,
rel_tol=self.rtol,
) )
if prefill_token.logprob is not None if prefill_token.logprob is not None
else prefill_token.logprob == other.logprob else prefill_token.logprob == other.logprob
@ -222,6 +230,10 @@ class GenerousResponseComparator(ResponseComparator):
rtol = 0.75 rtol = 0.75
class IgnoreLogProbResponseComparator(ResponseComparator):
ignore_logprob = True
class LauncherHandle: class LauncherHandle:
def __init__(self, port: int): def __init__(self, port: int):
self.client = AsyncClient(f"http://localhost:{port}") self.client = AsyncClient(f"http://localhost:{port}")
@ -273,6 +285,11 @@ def generous_response_snapshot(snapshot):
return snapshot.use_extension(GenerousResponseComparator) return snapshot.use_extension(GenerousResponseComparator)
@pytest.fixture
def ignore_logprob_response_snapshot(snapshot):
return snapshot.use_extension(IgnoreLogProbResponseComparator)
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def event_loop(): def event_loop():
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -347,19 +364,22 @@ def launcher(event_loop):
if not use_flash_attention: if not use_flash_attention:
env["USE_FLASH_ATTENTION"] = "false" env["USE_FLASH_ATTENTION"] = "false"
with tempfile.TemporaryFile("w+") as tmp:
# We'll output stdout/stderr to a temporary file. Using a pipe
# cause the process to block until stdout is read.
with subprocess.Popen( with subprocess.Popen(
args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env args,
stdout=tmp,
stderr=subprocess.STDOUT,
env=env,
) as process: ) as process:
yield ProcessLauncherHandle(process, port) yield ProcessLauncherHandle(process, port)
process.terminate() process.terminate()
process.wait(60) process.wait(60)
launcher_output = process.stdout.read().decode("utf-8") tmp.seek(0)
print(launcher_output, file=sys.stderr) shutil.copyfileobj(tmp, sys.stderr)
process.stdout.close()
process.stderr.close()
if not use_flash_attention: if not use_flash_attention:
del env["USE_FLASH_ATTENTION"] del env["USE_FLASH_ATTENTION"]

View File

@ -5,7 +5,7 @@
"index": 0, "index": 0,
"logprobs": null, "logprobs": null,
"message": { "message": {
"content": "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally", "content": "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to explore in the middle of urban confines. In fact, typical times for humidity levels in Brooklyn include:\n\n- Early morning: 80-85% humidity, with occas",
"name": null, "name": null,
"role": "assistant", "role": "assistant",
"tool_calls": null "tool_calls": null
@ -13,14 +13,14 @@
"usage": null "usage": null
} }
], ],
"created": 1712874856, "created": 1716553098,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native", "system_fingerprint": "2.0.5-dev0-native",
"usage": { "usage": {
"completion_tokens": 100, "completion_tokens": 100,
"prompt_tokens": 60, "prompt_tokens": 62,
"total_tokens": 160 "total_tokens": 162
} }
} }

View File

@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2,
"logprob": null,
"text": "<bos>"
},
{
"id": 2015,
"logprob": -9.640625,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 604,
"logprob": -2.4296875,
"special": false,
"text": " for"
},
{
"id": 573,
"logprob": -2.4453125,
"special": false,
"text": " the"
},
{
"id": 2412,
"logprob": -2.8632812,
"special": false,
"text": " following"
},
{
"id": 235292,
"logprob": -2.1328125,
"special": false,
"text": ":"
},
{
"id": 109,
"logprob": -0.76660156,
"special": false,
"text": "\n\n"
},
{
"id": 235287,
"logprob": -1.3837891,
"special": false,
"text": "*"
},
{
"id": 235248,
"logprob": -1.9746094,
"special": false,
"text": " "
},
{
"id": 199,
"logprob": -1.4189453,
"special": false,
"text": "<strong>"
},
{
"id": 1232,
"logprob": -4.34375,
"special": false,
"text": "Name"
},
{
"id": 208,
"logprob": -0.8852539,
"special": false,
"text": "</strong>"
}
],
"top_tokens": null
},
"generated_text": " for the following:\n\n* <strong>Name</strong>"
}

View File

@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2,
"logprob": null,
"text": "<bos>"
},
{
"id": 2015,
"logprob": -9.65625,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.3671875,
"text": " request"
}
],
"seed": 0,
"tokens": [
{
"id": 604,
"logprob": -0.36938477,
"special": false,
"text": " for"
},
{
"id": 235248,
"logprob": -1.8046875,
"special": false,
"text": " "
},
{
"id": 235274,
"logprob": -0.46240234,
"special": false,
"text": "1"
},
{
"id": 235284,
"logprob": -1.7460938,
"special": false,
"text": "2"
},
{
"id": 235265,
"logprob": -1.9443359,
"special": false,
"text": "."
},
{
"id": 235284,
"logprob": -1.4550781,
"special": false,
"text": "2"
},
{
"id": 235308,
"logprob": -1.0205078,
"special": false,
"text": "5"
},
{
"id": 235290,
"logprob": -1.0283203,
"special": false,
"text": "-"
},
{
"id": 235274,
"logprob": -1.2783203,
"special": false,
"text": "1"
},
{
"id": 235284,
"logprob": 0.0,
"special": false,
"text": "2"
}
],
"top_tokens": null
},
"generated_text": "Test request for 12.25-12"
}

View File

@ -0,0 +1,358 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2,
"logprob": null,
"text": "<bos>"
},
{
"id": 2015,
"logprob": -9.6484375,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.359375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 604,
"logprob": -2.4277344,
"special": false,
"text": " for"
},
{
"id": 573,
"logprob": -2.4394531,
"special": false,
"text": " the"
},
{
"id": 2412,
"logprob": -2.8613281,
"special": false,
"text": " following"
},
{
"id": 235292,
"logprob": -2.1523438,
"special": false,
"text": ":"
},
{
"id": 109,
"logprob": -0.76220703,
"special": false,
"text": "\n\n"
},
{
"id": 235287,
"logprob": -1.3642578,
"special": false,
"text": "*"
},
{
"id": 235248,
"logprob": -2.0175781,
"special": false,
"text": " "
},
{
"id": 199,
"logprob": -1.4238281,
"special": false,
"text": "<strong>"
},
{
"id": 1232,
"logprob": -4.328125,
"special": false,
"text": "Name"
},
{
"id": 208,
"logprob": -0.8881836,
"special": false,
"text": "</strong>"
}
],
"top_tokens": null
},
"generated_text": " for the following:\n\n* <strong>Name</strong>"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2,
"logprob": null,
"text": "<bos>"
},
{
"id": 2015,
"logprob": -9.6484375,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 604,
"logprob": -2.4238281,
"special": false,
"text": " for"
},
{
"id": 573,
"logprob": -2.4453125,
"special": false,
"text": " the"
},
{
"id": 2412,
"logprob": -2.859375,
"special": false,
"text": " following"
},
{
"id": 235292,
"logprob": -2.1445312,
"special": false,
"text": ":"
},
{
"id": 109,
"logprob": -0.7631836,
"special": false,
"text": "\n\n"
},
{
"id": 235287,
"logprob": -1.3642578,
"special": false,
"text": "*"
},
{
"id": 235248,
"logprob": -1.9960938,
"special": false,
"text": " "
},
{
"id": 199,
"logprob": -1.4179688,
"special": false,
"text": "<strong>"
},
{
"id": 1232,
"logprob": -4.3359375,
"special": false,
"text": "Name"
},
{
"id": 208,
"logprob": -0.8847656,
"special": false,
"text": "</strong>"
}
],
"top_tokens": null
},
"generated_text": " for the following:\n\n* <strong>Name</strong>"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2,
"logprob": null,
"text": "<bos>"
},
{
"id": 2015,
"logprob": -9.640625,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.3671875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 604,
"logprob": -2.4257812,
"special": false,
"text": " for"
},
{
"id": 573,
"logprob": -2.4453125,
"special": false,
"text": " the"
},
{
"id": 2412,
"logprob": -2.8789062,
"special": false,
"text": " following"
},
{
"id": 235292,
"logprob": -2.1367188,
"special": false,
"text": ":"
},
{
"id": 109,
"logprob": -0.76171875,
"special": false,
"text": "\n\n"
},
{
"id": 235287,
"logprob": -1.3515625,
"special": false,
"text": "*"
},
{
"id": 235248,
"logprob": -1.9873047,
"special": false,
"text": " "
},
{
"id": 199,
"logprob": -1.4169922,
"special": false,
"text": "<strong>"
},
{
"id": 1232,
"logprob": -4.3320312,
"special": false,
"text": "Name"
},
{
"id": 208,
"logprob": -0.8930664,
"special": false,
"text": "</strong>"
}
],
"top_tokens": null
},
"generated_text": " for the following:\n\n* <strong>Name</strong>"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2,
"logprob": null,
"text": "<bos>"
},
{
"id": 2015,
"logprob": -9.6484375,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.359375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 604,
"logprob": -2.4179688,
"special": false,
"text": " for"
},
{
"id": 573,
"logprob": -2.4492188,
"special": false,
"text": " the"
},
{
"id": 2412,
"logprob": -2.8574219,
"special": false,
"text": " following"
},
{
"id": 235292,
"logprob": -2.1445312,
"special": false,
"text": ":"
},
{
"id": 109,
"logprob": -0.7519531,
"special": false,
"text": "\n\n"
},
{
"id": 235287,
"logprob": -1.3623047,
"special": false,
"text": "*"
},
{
"id": 235248,
"logprob": -1.9707031,
"special": false,
"text": " "
},
{
"id": 199,
"logprob": -1.4267578,
"special": false,
"text": "<strong>"
},
{
"id": 1232,
"logprob": -4.3359375,
"special": false,
"text": "Name"
},
{
"id": 208,
"logprob": -0.88427734,
"special": false,
"text": "</strong>"
}
],
"top_tokens": null
},
"generated_text": " for the following:\n\n* <strong>Name</strong>"
}
]

View File

@ -0,0 +1,84 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.4375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.9316406,
"special": false,
"text": ":"
},
{
"id": 330,
"logprob": -3.5136719,
"special": false,
"text": " \""
},
{
"id": 489,
"logprob": -0.7783203,
"special": false,
"text": " +"
},
{
"id": 1715,
"logprob": -1.2314453,
"special": false,
"text": " request"
},
{
"id": 489,
"logprob": -2.0019531,
"special": false,
"text": " +"
},
{
"id": 2990,
"logprob": -1.5009766,
"special": false,
"text": " \"\\"
},
{
"id": 77,
"logprob": -0.057434082,
"special": false,
"text": "n"
},
{
"id": 702,
"logprob": -1.4912109,
"special": false,
"text": "\"\n"
},
{
"id": 262,
"logprob": -1.2636719,
"special": false,
"text": " "
},
{
"id": 557,
"logprob": -2.4042969,
"special": false,
"text": " }\n\n"
}
],
"top_tokens": null
},
"generated_text": ": \" + request + \"\\n\"\n }\n\n"
}

View File

@ -0,0 +1,84 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.453125,
"text": " request"
}
],
"seed": 0,
"tokens": [
{
"id": 13,
"logprob": -1.9980469,
"special": false,
"text": "."
},
{
"id": 578,
"logprob": -0.15795898,
"special": false,
"text": " The"
},
{
"id": 3622,
"logprob": -1.0458984,
"special": false,
"text": " server"
},
{
"id": 31680,
"logprob": -1.3623047,
"special": false,
"text": " responds"
},
{
"id": 449,
"logprob": 0.0,
"special": false,
"text": " with"
},
{
"id": 264,
"logprob": 0.0,
"special": false,
"text": " a"
},
{
"id": 330,
"logprob": -0.5678711,
"special": false,
"text": " \""
},
{
"id": 1049,
"logprob": -0.12322998,
"special": false,
"text": "200"
},
{
"id": 10619,
"logprob": 0.0,
"special": false,
"text": " OK"
},
{
"id": 1,
"logprob": 0.0,
"special": false,
"text": "\""
}
],
"top_tokens": null
},
"generated_text": "Test request. The server responds with a \"200 OK\""
}

View File

@ -0,0 +1,338 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.453125,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.9785156,
"special": false,
"text": ":"
},
{
"id": 330,
"logprob": -3.4941406,
"special": false,
"text": " \""
},
{
"id": 489,
"logprob": -0.79345703,
"special": false,
"text": " +"
},
{
"id": 1715,
"logprob": -1.2324219,
"special": false,
"text": " request"
},
{
"id": 489,
"logprob": -1.9794922,
"special": false,
"text": " +"
},
{
"id": 2990,
"logprob": -1.4892578,
"special": false,
"text": " \"\\"
},
{
"id": 77,
"logprob": -0.058258057,
"special": false,
"text": "n"
},
{
"id": 702,
"logprob": -1.4892578,
"special": false,
"text": "\"\n"
},
{
"id": 262,
"logprob": -1.2783203,
"special": false,
"text": " "
},
{
"id": 557,
"logprob": -2.3945312,
"special": false,
"text": " }\n\n"
}
],
"top_tokens": null
},
"generated_text": ": \" + request + \"\\n\"\n }\n\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.40625,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.9433594,
"special": false,
"text": ":"
},
{
"id": 330,
"logprob": -3.4726562,
"special": false,
"text": " \""
},
{
"id": 489,
"logprob": -0.8022461,
"special": false,
"text": " +"
},
{
"id": 1715,
"logprob": -1.2509766,
"special": false,
"text": " request"
},
{
"id": 489,
"logprob": -1.984375,
"special": false,
"text": " +"
},
{
"id": 2990,
"logprob": -1.4677734,
"special": false,
"text": " \"\\"
},
{
"id": 77,
"logprob": -0.059173584,
"special": false,
"text": "n"
},
{
"id": 702,
"logprob": -1.4990234,
"special": false,
"text": "\"\n"
},
{
"id": 262,
"logprob": -1.2822266,
"special": false,
"text": " "
},
{
"id": 557,
"logprob": -2.3867188,
"special": false,
"text": " }\n\n"
}
],
"top_tokens": null
},
"generated_text": ": \" + request + \"\\n\"\n }\n\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.421875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.9511719,
"special": false,
"text": ":"
},
{
"id": 330,
"logprob": -3.46875,
"special": false,
"text": " \""
},
{
"id": 489,
"logprob": -0.77490234,
"special": false,
"text": " +"
},
{
"id": 1715,
"logprob": -1.2558594,
"special": false,
"text": " request"
},
{
"id": 489,
"logprob": -1.984375,
"special": false,
"text": " +"
},
{
"id": 2990,
"logprob": -1.4990234,
"special": false,
"text": " \"\\"
},
{
"id": 77,
"logprob": -0.059143066,
"special": false,
"text": "n"
},
{
"id": 702,
"logprob": -1.4941406,
"special": false,
"text": "\"\n"
},
{
"id": 262,
"logprob": -1.2578125,
"special": false,
"text": " "
},
{
"id": 557,
"logprob": -2.3964844,
"special": false,
"text": " }\n\n"
}
],
"top_tokens": null
},
"generated_text": ": \" + request + \"\\n\"\n }\n\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.4140625,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.9101562,
"special": false,
"text": ":"
},
{
"id": 330,
"logprob": -3.5039062,
"special": false,
"text": " \""
},
{
"id": 489,
"logprob": -0.8076172,
"special": false,
"text": " +"
},
{
"id": 1715,
"logprob": -1.2236328,
"special": false,
"text": " request"
},
{
"id": 489,
"logprob": -1.9853516,
"special": false,
"text": " +"
},
{
"id": 2990,
"logprob": -1.4892578,
"special": false,
"text": " \"\\"
},
{
"id": 77,
"logprob": -0.056671143,
"special": false,
"text": "n"
},
{
"id": 702,
"logprob": -1.5107422,
"special": false,
"text": "\"\n"
},
{
"id": 262,
"logprob": -1.2597656,
"special": false,
"text": " "
},
{
"id": 557,
"logprob": -2.4042969,
"special": false,
"text": " }\n\n"
}
],
"top_tokens": null
},
"generated_text": ": \" + request + \"\\n\"\n }\n\n"
}
]

View File

@ -0,0 +1,84 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -2.5742188,
"special": false,
"text": "\n"
},
{
"id": 262,
"logprob": -1.6230469,
"special": false,
"text": " "
},
{
"id": 3270,
"logprob": -2.046875,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1425781,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.9238281,
"special": false,
"text": " request"
},
{
"id": 13204,
"logprob": -0.076660156,
"special": false,
"text": ".method"
},
{
"id": 624,
"logprob": -0.021987915,
"special": false,
"text": " =="
},
{
"id": 364,
"logprob": -0.39208984,
"special": false,
"text": " '"
},
{
"id": 3019,
"logprob": -0.10821533,
"special": false,
"text": "POST"
}
],
"top_tokens": null
},
"generated_text": "\n \"\"\"\n if request.method == 'POST"
}

View File

@ -0,0 +1,84 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": 0,
"tokens": [
{
"id": 13,
"logprob": -2.2539062,
"special": false,
"text": "."
},
{
"id": 578,
"logprob": -0.15563965,
"special": false,
"text": " The"
},
{
"id": 3622,
"logprob": -0.8203125,
"special": false,
"text": " server"
},
{
"id": 706,
"logprob": 0.0,
"special": false,
"text": " has"
},
{
"id": 539,
"logprob": 0.0,
"special": false,
"text": " not"
},
{
"id": 3686,
"logprob": 0.0,
"special": false,
"text": " yet"
},
{
"id": 3288,
"logprob": 0.0,
"special": false,
"text": " sent"
},
{
"id": 904,
"logprob": 0.0,
"special": false,
"text": " any"
},
{
"id": 828,
"logprob": 0.0,
"special": false,
"text": " data"
},
{
"id": 382,
"logprob": -1.5517578,
"special": false,
"text": ".\n\n"
}
],
"top_tokens": null
},
"generated_text": "Test request. The server has not yet sent any data.\n\n"
}

View File

@ -0,0 +1,338 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -2.5742188,
"special": false,
"text": "\n"
},
{
"id": 262,
"logprob": -1.6220703,
"special": false,
"text": " "
},
{
"id": 3270,
"logprob": -2.0410156,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1445312,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.92333984,
"special": false,
"text": " request"
},
{
"id": 13204,
"logprob": -0.07672119,
"special": false,
"text": ".method"
},
{
"id": 624,
"logprob": -0.021987915,
"special": false,
"text": " =="
},
{
"id": 364,
"logprob": -0.39208984,
"special": false,
"text": " '"
},
{
"id": 3019,
"logprob": -0.10638428,
"special": false,
"text": "POST"
}
],
"top_tokens": null
},
"generated_text": "\n \"\"\"\n if request.method == 'POST"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -2.5742188,
"special": false,
"text": "\n"
},
{
"id": 262,
"logprob": -1.6220703,
"special": false,
"text": " "
},
{
"id": 3270,
"logprob": -2.0410156,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1445312,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.92333984,
"special": false,
"text": " request"
},
{
"id": 13204,
"logprob": -0.07672119,
"special": false,
"text": ".method"
},
{
"id": 624,
"logprob": -0.021987915,
"special": false,
"text": " =="
},
{
"id": 364,
"logprob": -0.39208984,
"special": false,
"text": " '"
},
{
"id": 3019,
"logprob": -0.10638428,
"special": false,
"text": "POST"
}
],
"top_tokens": null
},
"generated_text": "\n \"\"\"\n if request.method == 'POST"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -2.5742188,
"special": false,
"text": "\n"
},
{
"id": 262,
"logprob": -1.6220703,
"special": false,
"text": " "
},
{
"id": 3270,
"logprob": -2.0410156,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1445312,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.92333984,
"special": false,
"text": " request"
},
{
"id": 13204,
"logprob": -0.07672119,
"special": false,
"text": ".method"
},
{
"id": 624,
"logprob": -0.021987915,
"special": false,
"text": " =="
},
{
"id": 364,
"logprob": -0.39208984,
"special": false,
"text": " '"
},
{
"id": 3019,
"logprob": -0.10638428,
"special": false,
"text": "POST"
}
],
"top_tokens": null
},
"generated_text": "\n \"\"\"\n if request.method == 'POST"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -2.5742188,
"special": false,
"text": "\n"
},
{
"id": 262,
"logprob": -1.6220703,
"special": false,
"text": " "
},
{
"id": 3270,
"logprob": -2.0410156,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1445312,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.92333984,
"special": false,
"text": " request"
},
{
"id": 13204,
"logprob": -0.07672119,
"special": false,
"text": ".method"
},
{
"id": 624,
"logprob": -0.021987915,
"special": false,
"text": " =="
},
{
"id": 364,
"logprob": -0.39208984,
"special": false,
"text": " '"
},
{
"id": 3019,
"logprob": -0.10638428,
"special": false,
"text": "POST"
}
],
"top_tokens": null
},
"generated_text": "\n \"\"\"\n if request.method == 'POST"
}
]

View File

@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -12.390625,
"text": "Test"
},
{
"id": 2009,
"logprob": -11.0625,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.0507812,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -2.3007812,
"special": false,
"text": "\n"
},
{
"id": 29902,
"logprob": -2.0449219,
"special": false,
"text": "I"
},
{
"id": 505,
"logprob": -1.3242188,
"special": false,
"text": " have"
},
{
"id": 263,
"logprob": -0.2076416,
"special": false,
"text": " a"
},
{
"id": 1243,
"logprob": -2.0273438,
"special": false,
"text": " test"
},
{
"id": 2009,
"logprob": -0.6845703,
"special": false,
"text": " request"
},
{
"id": 515,
"logprob": -1.1748047,
"special": false,
"text": " from"
},
{
"id": 263,
"logprob": -1.0644531,
"special": false,
"text": " a"
},
{
"id": 1404,
"logprob": -1.5224609,
"special": false,
"text": " user"
}
],
"top_tokens": null
},
"generated_text": "\n\nI have a test request from a user"
}

View File

@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -12.390625,
"text": "Test"
},
{
"id": 2009,
"logprob": -11.0625,
"text": "request"
}
],
"seed": 0,
"tokens": [
{
"id": 5229,
"logprob": -1.2607422,
"special": false,
"text": " failed"
},
{
"id": 29901,
"logprob": 0.0,
"special": false,
"text": ":"
},
{
"id": 6527,
"logprob": -0.11450195,
"special": false,
"text": " Could"
},
{
"id": 451,
"logprob": 0.0,
"special": false,
"text": " not"
},
{
"id": 4511,
"logprob": -0.2286377,
"special": false,
"text": " connect"
},
{
"id": 304,
"logprob": 0.0,
"special": false,
"text": " to"
},
{
"id": 1923,
"logprob": -1.2568359,
"special": false,
"text": " server"
},
{
"id": 13,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.15905762,
"special": false,
"text": "\n"
},
{
"id": 29902,
"logprob": -0.21618652,
"special": false,
"text": "I"
}
],
"top_tokens": null
},
"generated_text": "Test request failed: Could not connect to server\n\nI"
}

View File

@ -0,0 +1,358 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -12.390625,
"text": "Test"
},
{
"id": 2009,
"logprob": -11.0625,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.0507812,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -2.3007812,
"special": false,
"text": "\n"
},
{
"id": 29902,
"logprob": -2.0449219,
"special": false,
"text": "I"
},
{
"id": 505,
"logprob": -1.3242188,
"special": false,
"text": " have"
},
{
"id": 263,
"logprob": -0.2076416,
"special": false,
"text": " a"
},
{
"id": 1243,
"logprob": -2.0273438,
"special": false,
"text": " test"
},
{
"id": 2009,
"logprob": -0.6845703,
"special": false,
"text": " request"
},
{
"id": 515,
"logprob": -1.1748047,
"special": false,
"text": " from"
},
{
"id": 263,
"logprob": -1.0595703,
"special": false,
"text": " a"
},
{
"id": 1404,
"logprob": -1.5224609,
"special": false,
"text": " user"
}
],
"top_tokens": null
},
"generated_text": "\n\nI have a test request from a user"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -12.390625,
"text": "Test"
},
{
"id": 2009,
"logprob": -11.0625,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.0507812,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -2.3007812,
"special": false,
"text": "\n"
},
{
"id": 29902,
"logprob": -2.0449219,
"special": false,
"text": "I"
},
{
"id": 505,
"logprob": -1.3242188,
"special": false,
"text": " have"
},
{
"id": 263,
"logprob": -0.2076416,
"special": false,
"text": " a"
},
{
"id": 1243,
"logprob": -2.0273438,
"special": false,
"text": " test"
},
{
"id": 2009,
"logprob": -0.6845703,
"special": false,
"text": " request"
},
{
"id": 515,
"logprob": -1.1748047,
"special": false,
"text": " from"
},
{
"id": 263,
"logprob": -1.0595703,
"special": false,
"text": " a"
},
{
"id": 1404,
"logprob": -1.5224609,
"special": false,
"text": " user"
}
],
"top_tokens": null
},
"generated_text": "\n\nI have a test request from a user"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -12.390625,
"text": "Test"
},
{
"id": 2009,
"logprob": -11.0625,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.0507812,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -2.3007812,
"special": false,
"text": "\n"
},
{
"id": 29902,
"logprob": -2.0449219,
"special": false,
"text": "I"
},
{
"id": 505,
"logprob": -1.3242188,
"special": false,
"text": " have"
},
{
"id": 263,
"logprob": -0.2076416,
"special": false,
"text": " a"
},
{
"id": 1243,
"logprob": -2.0273438,
"special": false,
"text": " test"
},
{
"id": 2009,
"logprob": -0.6845703,
"special": false,
"text": " request"
},
{
"id": 515,
"logprob": -1.1748047,
"special": false,
"text": " from"
},
{
"id": 263,
"logprob": -1.0595703,
"special": false,
"text": " a"
},
{
"id": 1404,
"logprob": -1.5224609,
"special": false,
"text": " user"
}
],
"top_tokens": null
},
"generated_text": "\n\nI have a test request from a user"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -12.390625,
"text": "Test"
},
{
"id": 2009,
"logprob": -11.0625,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.0507812,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -2.3007812,
"special": false,
"text": "\n"
},
{
"id": 29902,
"logprob": -2.0449219,
"special": false,
"text": "I"
},
{
"id": 505,
"logprob": -1.3242188,
"special": false,
"text": " have"
},
{
"id": 263,
"logprob": -0.2076416,
"special": false,
"text": " a"
},
{
"id": 1243,
"logprob": -2.0273438,
"special": false,
"text": " test"
},
{
"id": 2009,
"logprob": -0.6845703,
"special": false,
"text": " request"
},
{
"id": 515,
"logprob": -1.1748047,
"special": false,
"text": " from"
},
{
"id": 263,
"logprob": -1.0595703,
"special": false,
"text": " a"
},
{
"id": 1404,
"logprob": -1.5224609,
"special": false,
"text": " user"
}
],
"top_tokens": null
},
"generated_text": "\n\nI have a test request from a user"
}
]

View File

@ -0,0 +1,61 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 8,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 2502,
"logprob": -1.734375,
"special": false,
"text": "image"
},
{
"id": 2196,
"logprob": -0.5756836,
"special": false,
"text": " result"
},
{
"id": 604,
"logprob": -0.007843018,
"special": false,
"text": " for"
},
{
"id": 12254,
"logprob": -1.7167969,
"special": false,
"text": " chicken"
},
{
"id": 611,
"logprob": -0.17053223,
"special": false,
"text": " on"
},
{
"id": 573,
"logprob": -0.7626953,
"special": false,
"text": " the"
},
{
"id": 8318,
"logprob": -0.02709961,
"special": false,
"text": " beach"
},
{
"id": 1,
"logprob": -0.20739746,
"special": true,
"text": "<eos>"
}
],
"top_tokens": null
},
"generated_text": "image result for chicken on the beach"
}

View File

@ -0,0 +1,23 @@
{
"choices": [
{
"finish_reason": "eos_token",
"index": 0,
"logprobs": null,
"message": {
"content": "{\n \"temperature\": [\n 35,\n 34,\n 36\n ],\n \"unit\": \"°c\"\n}",
"role": "assistant"
}
}
],
"created": 1718044128,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.5-dev0-native",
"usage": {
"completion_tokens": 39,
"prompt_tokens": 136,
"total_tokens": 175
}
}

View File

@ -0,0 +1,85 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 12,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 450,
"logprob": -0.26342773,
"special": false,
"text": " The"
},
{
"id": 21282,
"logprob": -0.01838684,
"special": false,
"text": " cow"
},
{
"id": 322,
"logprob": -0.18041992,
"special": false,
"text": " and"
},
{
"id": 521,
"logprob": -0.62841797,
"special": false,
"text": " ch"
},
{
"id": 21475,
"logprob": -0.0037956238,
"special": false,
"text": "icken"
},
{
"id": 526,
"logprob": -0.018737793,
"special": false,
"text": " are"
},
{
"id": 373,
"logprob": -1.0820312,
"special": false,
"text": " on"
},
{
"id": 263,
"logprob": -0.5083008,
"special": false,
"text": " a"
},
{
"id": 25695,
"logprob": -0.07128906,
"special": false,
"text": " beach"
},
{
"id": 29889,
"logprob": -0.12573242,
"special": false,
"text": "."
},
{
"id": 32002,
"logprob": -0.0029792786,
"special": true,
"text": "<end_of_utterance>"
},
{
"id": 2,
"logprob": -0.00024962425,
"special": true,
"text": "</s>"
}
],
"top_tokens": null
},
"generated_text": " The cow and chicken are on a beach."
}

View File

@ -0,0 +1,133 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 20,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 415,
"logprob": -0.04421997,
"special": false,
"text": " The"
},
{
"id": 12072,
"logprob": -0.13500977,
"special": false,
"text": " cow"
},
{
"id": 349,
"logprob": -0.06750488,
"special": false,
"text": " is"
},
{
"id": 6328,
"logprob": -0.6352539,
"special": false,
"text": " standing"
},
{
"id": 356,
"logprob": -0.16186523,
"special": false,
"text": " on"
},
{
"id": 272,
"logprob": -0.5078125,
"special": false,
"text": " the"
},
{
"id": 10305,
"logprob": -0.017913818,
"special": false,
"text": " beach"
},
{
"id": 304,
"logprob": -1.5205078,
"special": false,
"text": " and"
},
{
"id": 272,
"logprob": -0.029174805,
"special": false,
"text": " the"
},
{
"id": 13088,
"logprob": -0.003479004,
"special": false,
"text": " chicken"
},
{
"id": 349,
"logprob": -0.0035095215,
"special": false,
"text": " is"
},
{
"id": 6398,
"logprob": -0.3088379,
"special": false,
"text": " sitting"
},
{
"id": 356,
"logprob": -0.027755737,
"special": false,
"text": " on"
},
{
"id": 264,
"logprob": -0.31884766,
"special": false,
"text": " a"
},
{
"id": 17972,
"logprob": -0.047943115,
"special": false,
"text": " pile"
},
{
"id": 302,
"logprob": -0.0002925396,
"special": false,
"text": " of"
},
{
"id": 2445,
"logprob": -0.02935791,
"special": false,
"text": " money"
},
{
"id": 28723,
"logprob": -0.031219482,
"special": false,
"text": "."
},
{
"id": 32002,
"logprob": -0.00034475327,
"special": true,
"text": "<end_of_utterance>"
},
{
"id": 2,
"logprob": -1.1920929e-07,
"special": true,
"text": "</s>"
}
],
"top_tokens": null
},
"generated_text": " The cow is standing on the beach and the chicken is sitting on a pile of money."
}

View File

@ -35,8 +35,9 @@ async def test_flash_llama_simple(flash_llama_chat, response_snapshot):
], ],
) )
print(repr(response.choices[0].message.content))
assert ( assert (
response.choices[0].message.content response.choices[0].message.content
== "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally" == "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to explore in the middle of urban confines. In fact, typical times for humidity levels in Brooklyn include:\n\n- Early morning: 80-85% humidity, with occas"
) )
assert response == response_snapshot assert response == response_snapshot

View File

@ -3,7 +3,7 @@ import pytest
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def flash_gemma_handle(launcher): def flash_gemma_handle(launcher):
with launcher("gg-hf/gemma-2b", num_shard=1) as handle: with launcher("google/gemma-2b", num_shard=1) as handle:
yield handle yield handle
@ -13,7 +13,6 @@ async def flash_gemma(flash_gemma_handle):
return flash_gemma_handle.client return flash_gemma_handle.client
@pytest.mark.skip
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_gemma(flash_gemma, response_snapshot): async def test_flash_gemma(flash_gemma, response_snapshot):
@ -25,7 +24,6 @@ async def test_flash_gemma(flash_gemma, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.skip
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_gemma_all_params(flash_gemma, response_snapshot): async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
@ -49,7 +47,6 @@ async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.skip
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_gemma_load(flash_gemma, generate_load, response_snapshot): async def test_flash_gemma_load(flash_gemma, generate_load, response_snapshot):

View File

@ -0,0 +1,64 @@
import pytest
@pytest.fixture(scope="module")
def flash_gemma_gptq_handle(launcher):
with launcher("TechxGenus/gemma-2b-GPTQ", num_shard=1, quantize="gptq") as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_gemma_gptq(flash_gemma_gptq_handle):
await flash_gemma_gptq_handle.health(300)
return flash_gemma_gptq_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapshot):
response = await flash_gemma_gptq.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == ignore_logprob_response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_gptq_all_params(
flash_gemma_gptq, ignore_logprob_response_snapshot
):
response = await flash_gemma_gptq.generate(
"Test request",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["test"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 10
assert response == ignore_logprob_response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_gptq_load(
flash_gemma_gptq, generate_load, ignore_logprob_response_snapshot
):
responses = await generate_load(
flash_gemma_gptq, "Test request", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == ignore_logprob_response_snapshot

View File

@ -0,0 +1,73 @@
import pytest
@pytest.fixture(scope="module")
def flash_llama_exl2_handle(launcher):
with launcher(
"turboderp/Llama-3-8B-Instruct-exl2",
revision="2.5bpw",
# Set max input length to avoid OOM due to extremely large
# scratch buffer.
max_input_length=1024,
num_shard=1,
quantize="exl2",
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_llama_exl2(flash_llama_exl2_handle):
await flash_llama_exl2_handle.health(300)
return flash_llama_exl2_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot):
response = await flash_llama_exl2.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == ignore_logprob_response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2_all_params(
flash_llama_exl2, ignore_logprob_response_snapshot
):
response = await flash_llama_exl2.generate(
"Test request",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert (
response.generated_text == 'Test request. The server responds with a "200 OK"'
)
assert response == ignore_logprob_response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2_load(
flash_llama_exl2, generate_load, ignore_logprob_response_snapshot
):
responses = await generate_load(
flash_llama_exl2, "Test request", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == ignore_logprob_response_snapshot

View File

@ -0,0 +1,65 @@
import pytest
@pytest.fixture(scope="module")
def flash_llama_gptq_marlin_handle(launcher):
with launcher(
"astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit", num_shard=2, quantize="marlin"
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_llama_gptq_marlin(flash_llama_gptq_marlin_handle):
await flash_llama_gptq_marlin_handle.health(300)
return flash_llama_gptq_marlin_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapshot):
response = await flash_llama_gptq_marlin.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_marlin_all_params(
flash_llama_gptq_marlin, response_snapshot
):
response = await flash_llama_gptq_marlin.generate(
"Test request",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_marlin_load(
flash_llama_gptq_marlin, generate_load, response_snapshot
):
responses = await generate_load(
flash_llama_gptq_marlin, "Test request", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == response_snapshot

View File

@ -0,0 +1,63 @@
import pytest
@pytest.fixture(scope="module")
def flash_llama_marlin_handle(launcher):
with launcher(
"neuralmagic/llama-2-7b-chat-marlin", num_shard=2, quantize="marlin"
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_llama_marlin(flash_llama_marlin_handle):
await flash_llama_marlin_handle.health(300)
return flash_llama_marlin_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):
response = await flash_llama_marlin.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin_all_params(flash_llama_marlin, response_snapshot):
response = await flash_llama_marlin.generate(
"Test request",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin_load(
flash_llama_marlin, generate_load, response_snapshot
):
responses = await generate_load(
flash_llama_marlin, "Test request", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == response_snapshot

View File

@ -22,6 +22,12 @@ async def flash_pali_gemma(flash_pali_gemma_handle):
return flash_pali_gemma_handle.client return flash_pali_gemma_handle.client
def get_chicken():
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
def get_cow_beach(): def get_cow_beach():
with open("integration-tests/images/cow_beach.png", "rb") as image_file: with open("integration-tests/images/cow_beach.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read()) encoded_string = base64.b64encode(image_file.read())
@ -37,3 +43,20 @@ async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
assert response.generated_text == "beach" assert response.generated_text == "beach"
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_pali_gemma_two_images(flash_pali_gemma, response_snapshot):
chicken = get_chicken()
cow_beach = get_cow_beach()
response = await flash_pali_gemma.generate(
f"caption![]({chicken})![]({cow_beach})\n",
max_new_tokens=20,
)
# Is PaliGemma not able to handle two separate images? At least we
# get output showing that both images are used.
assert (
response.generated_text == "image result for chicken on the beach"
), f"{repr(response.generated_text)}"
assert response == response_snapshot

View File

@ -0,0 +1,101 @@
import pytest
import requests
from pydantic import BaseModel
from typing import List
@pytest.fixture(scope="module")
def llama_grammar_handle(launcher):
with launcher(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
num_shard=1,
disable_grammar_support=False,
use_flash_attention=False,
max_batch_prefill_tokens=3000,
) as handle:
yield handle
@pytest.fixture(scope="module")
async def llama_grammar(llama_grammar_handle):
await llama_grammar_handle.health(300)
return llama_grammar_handle.client
@pytest.mark.asyncio
async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot):
class Weather(BaseModel):
unit: str
temperature: List[int]
# send the request
response = requests.post(
f"{llama_grammar.base_url}/v1/chat/completions",
headers=llama_grammar.headers,
json={
"model": "tgi",
"messages": [
{
"role": "system",
"content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}",
},
{
"role": "user",
"content": "What's the weather like the next 3 days in San Francisco, CA?",
},
],
"seed": 42,
"max_tokens": 500,
"response_format": {"type": "json_object", "value": Weather.schema()},
},
)
chat_completion = response.json()
called = chat_completion["choices"][0]["message"]["content"]
assert response.status_code == 200
assert (
called
== '{\n "temperature": [\n 35,\n 34,\n 36\n ],\n "unit": "°c"\n}'
)
assert chat_completion == response_snapshot
@pytest.mark.asyncio
async def test_grammar_response_format_llama_error_if_tools_not_installed(
llama_grammar,
):
class Weather(BaseModel):
unit: str
temperature: List[int]
# send the request
response = requests.post(
f"{llama_grammar.base_url}/v1/chat/completions",
headers=llama_grammar.headers,
json={
"model": "tgi",
"messages": [
{
"role": "system",
"content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}",
},
{
"role": "user",
"content": "What's the weather like the next 3 days in San Francisco, CA?",
},
],
"seed": 42,
"max_tokens": 500,
"tools": [],
"response_format": {"type": "json_object", "value": Weather.schema()},
},
)
# 422 means the server was unable to process the request because it contains invalid data.
assert response.status_code == 422
assert response.json() == {
"error": "Grammar and tools are mutually exclusive",
"error_type": "grammar and tools",
}

View File

@ -23,6 +23,12 @@ def get_chicken():
return f"data:image/png;base64,{encoded_string.decode('utf-8')}" return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
def get_cow_beach():
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_idefics(idefics, response_snapshot): async def test_idefics(idefics, response_snapshot):
chicken = get_chicken() chicken = get_chicken()
@ -39,6 +45,21 @@ async def test_idefics(idefics, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_idefics_two_images(idefics, response_snapshot):
chicken = get_chicken()
cow_beach = get_cow_beach()
response = await idefics.generate(
f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \nAssistant:",
max_new_tokens=20,
)
assert (
response.generated_text == " The cow and chicken are on a beach."
), f"{repr(response.generated_text)}"
assert response == response_snapshot
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_idefics_load(idefics, generate_load, response_snapshot): async def test_idefics_load(idefics, generate_load, response_snapshot):
chicken = get_chicken() chicken = get_chicken()

View File

@ -9,6 +9,12 @@ def get_chicken():
return f"data:image/png;base64,{encoded_string.decode('utf-8')}" return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
def get_cow_beach():
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def flash_idefics2_next_handle(launcher): def flash_idefics2_next_handle(launcher):
with launcher( with launcher(
@ -38,6 +44,23 @@ async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_idefics2_two_images(flash_idefics2_next, response_snapshot):
chicken = get_chicken()
cow_beach = get_cow_beach()
response = await flash_idefics2_next.generate(
f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \nAssistant:",
max_new_tokens=20,
)
assert (
response.generated_text
== " The cow is standing on the beach and the chicken is sitting on a pile of money."
), f"{repr(response.generated_text)}"
assert response.details.generated_tokens == 20
assert response == response_snapshot
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_idefics2_next_all_params(flash_idefics2_next, response_snapshot): async def test_flash_idefics2_next_all_params(flash_idefics2_next, response_snapshot):

View File

@ -17,14 +17,32 @@ use std::thread::sleep;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use std::{fs, io}; use std::{fs, io};
use thiserror::Error; use thiserror::Error;
use tracing_subscriber::EnvFilter; use tracing_subscriber::{filter::LevelFilter, EnvFilter};
mod env_runtime; mod env_runtime;
#[derive(Deserialize)]
struct RawConfig {
max_position_embeddings: Option<usize>,
n_positions: Option<usize>,
max_seq_len: Option<usize>,
}
#[derive(Deserialize)] #[derive(Deserialize)]
struct Config { struct Config {
max_position_embeddings: Option<usize>, max_position_embeddings: Option<usize>,
max_seq_len: Option<usize>, }
impl From<RawConfig> for Config {
fn from(other: RawConfig) -> Self {
let max_position_embeddings = other
.max_position_embeddings
.or(other.max_seq_len)
.or(other.n_positions);
Config {
max_position_embeddings,
}
}
} }
#[derive(Clone, Copy, Debug, ValueEnum)] #[derive(Clone, Copy, Debug, ValueEnum)]
@ -37,11 +55,17 @@ enum Quantization {
/// Should be a drop-in replacement to bitsandbytes with much better performance. /// Should be a drop-in replacement to bitsandbytes with much better performance.
/// Kernels are from <https://github.com/NetEase-FuXi/EETQ.git> /// Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
Eetq, Eetq,
/// Variable bit quantization. Requires a specific EXL2 quantized model:
/// <https://hf.co/models?search=exl2>. Requires exllama2 kernels and does
/// not support tensor parallelism (num_shard > 1).
Exl2,
/// 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>. /// 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>.
/// text-generation-inference will use exllama (faster) kernels wherever possible, and use /// text-generation-inference will use exllama (faster) kernels wherever possible, and use
/// triton kernel (wider support) when it's not. /// triton kernel (wider support) when it's not.
/// AWQ has faster kernels. /// AWQ has faster kernels.
Gptq, Gptq,
/// 4 bit quantization. Requires a specific Marlin quantized model: <https://hf.co/models?search=marlin>.
Marlin,
/// Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, /// Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half,
/// but it is known that the model will be much slower to run than the native f16. /// but it is known that the model will be much slower to run than the native f16.
#[deprecated( #[deprecated(
@ -77,9 +101,15 @@ impl std::fmt::Display for Quantization {
Quantization::BitsandbytesFP4 => { Quantization::BitsandbytesFP4 => {
write!(f, "bitsandbytes-fp4") write!(f, "bitsandbytes-fp4")
} }
Quantization::Exl2 => {
write!(f, "exl2")
}
Quantization::Gptq => { Quantization::Gptq => {
write!(f, "gptq") write!(f, "gptq")
} }
Quantization::Marlin => {
write!(f, "marlin")
}
Quantization::Awq => { Quantization::Awq => {
write!(f, "awq") write!(f, "awq")
} }
@ -228,7 +258,7 @@ struct Args {
max_stop_sequences: usize, max_stop_sequences: usize,
/// This is the maximum allowed value for clients to set `top_n_tokens`. /// This is the maximum allowed value for clients to set `top_n_tokens`.
/// `top_n_tokens is used to return information about the the `n` most likely /// `top_n_tokens` is used to return information about the the `n` most likely
/// tokens at each generation step, instead of just the sampled token. This /// tokens at each generation step, instead of just the sampled token. This
/// information can be used for downstream tasks like for classification or /// information can be used for downstream tasks like for classification or
/// ranking. /// ranking.
@ -470,7 +500,9 @@ fn shard_manager(
rope_factor: Option<f32>, rope_factor: Option<f32>,
max_total_tokens: usize, max_total_tokens: usize,
max_batch_size: Option<usize>, max_batch_size: Option<usize>,
max_input_tokens: usize,
otlp_endpoint: Option<String>, otlp_endpoint: Option<String>,
log_level: LevelFilter,
status_sender: mpsc::Sender<ShardStatus>, status_sender: mpsc::Sender<ShardStatus>,
shutdown: Arc<AtomicBool>, shutdown: Arc<AtomicBool>,
_shutdown_sender: mpsc::Sender<()>, _shutdown_sender: mpsc::Sender<()>,
@ -493,7 +525,7 @@ fn shard_manager(
"--uds-path".to_string(), "--uds-path".to_string(),
uds_path, uds_path,
"--logger-level".to_string(), "--logger-level".to_string(),
"INFO".to_string(), log_level.to_string().to_uppercase(),
"--json-output".to_string(), "--json-output".to_string(),
]; ];
@ -551,6 +583,10 @@ fn shard_manager(
shard_args.push(otlp_endpoint); shard_args.push(otlp_endpoint);
} }
// In case we use sliding window, we may ignore the sliding in flash for some backends depending on the parameter.
shard_args.push("--max-input-tokens".to_string());
shard_args.push(max_input_tokens.to_string());
// Copy current process env // Copy current process env
let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect(); let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
@ -781,13 +817,13 @@ struct PythonLogMessage {
impl PythonLogMessage { impl PythonLogMessage {
fn trace(&self) { fn trace(&self) {
match self.record.level.name { match self.record.level.name {
PythonLogLevelEnum::Trace => tracing::trace!("{}", self.text), PythonLogLevelEnum::Trace => tracing::trace!("{}", self.text.trim_end()),
PythonLogLevelEnum::Debug => tracing::debug!("{}", self.text), PythonLogLevelEnum::Debug => tracing::debug!("{}", self.text.trim_end()),
PythonLogLevelEnum::Info => tracing::info!("{}", self.text), PythonLogLevelEnum::Info => tracing::info!("{}", self.text.trim_end()),
PythonLogLevelEnum::Success => tracing::info!("{}", self.text), PythonLogLevelEnum::Success => tracing::info!("{}", self.text.trim_end()),
PythonLogLevelEnum::Warning => tracing::warn!("{}", self.text), PythonLogLevelEnum::Warning => tracing::warn!("{}", self.text.trim_end()),
PythonLogLevelEnum::Error => tracing::error!("{}", self.text), PythonLogLevelEnum::Error => tracing::error!("{}", self.text.trim_end()),
PythonLogLevelEnum::Critical => tracing::error!("{}", self.text), PythonLogLevelEnum::Critical => tracing::error!("{}", self.text.trim_end()),
} }
} }
} }
@ -1007,6 +1043,8 @@ fn spawn_shards(
args: &Args, args: &Args,
cuda_graphs: Vec<usize>, cuda_graphs: Vec<usize>,
max_total_tokens: usize, max_total_tokens: usize,
max_input_tokens: usize,
max_log_level: LevelFilter,
shutdown: Arc<AtomicBool>, shutdown: Arc<AtomicBool>,
shutdown_receiver: &mpsc::Receiver<()>, shutdown_receiver: &mpsc::Receiver<()>,
shutdown_sender: mpsc::Sender<()>, shutdown_sender: mpsc::Sender<()>,
@ -1067,7 +1105,9 @@ fn spawn_shards(
rope_factor, rope_factor,
max_total_tokens, max_total_tokens,
max_batch_size, max_batch_size,
max_input_tokens,
otlp_endpoint, otlp_endpoint,
max_log_level,
status_sender, status_sender,
shutdown, shutdown,
shutdown_sender, shutdown_sender,
@ -1298,8 +1338,22 @@ fn main() -> Result<(), LauncherError> {
let args: Args = Args::parse(); let args: Args = Args::parse();
// Filter events with LOG_LEVEL // Filter events with LOG_LEVEL
let env_filter = let varname = "LOG_LEVEL";
EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info")); let env_filter = if let Ok(log_level) = std::env::var(varname) {
// Override to avoid simple logs to be spammed with tokio level informations
let log_level = match &log_level[..] {
"warn" => "text_generation_launcher=warn,text_generation_router=warn",
"info" => "text_generation_launcher=info,text_generation_router=info",
"debug" => "text_generation_launcher=debug,text_generation_router=debug",
log_level => log_level,
};
EnvFilter::builder()
.with_default_directive(LevelFilter::INFO.into())
.parse_lossy(log_level)
} else {
EnvFilter::new("info")
};
let max_log_level = env_filter.max_level_hint().unwrap_or(LevelFilter::INFO);
if args.json_output { if args.json_output {
tracing_subscriber::fmt() tracing_subscriber::fmt()
@ -1342,13 +1396,13 @@ fn main() -> Result<(), LauncherError> {
}; };
let content = std::fs::read_to_string(filename)?; let content = std::fs::read_to_string(filename)?;
let config: Config = serde_json::from_str(&content)?; let config: RawConfig = serde_json::from_str(&content)?;
let config: Config = config.into();
// Quantization usually means you're even more RAM constrained. // Quantization usually means you're even more RAM constrained.
let max_default = 4096; let max_default = 4096;
let max_position_embeddings = match (config.max_position_embeddings, config.max_seq_len) { if let Some(max_position_embeddings) = config.max_position_embeddings {
(Some(max_position_embeddings), _) | (None, Some(max_position_embeddings)) => {
if max_position_embeddings > max_default { if max_position_embeddings > max_default {
let max = max_position_embeddings; let max = max_position_embeddings;
if args.max_input_tokens.is_none() if args.max_input_tokens.is_none()
@ -1357,18 +1411,15 @@ fn main() -> Result<(), LauncherError> {
{ {
tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1); tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
} }
max_default Ok(max_default)
} else { } else {
max_position_embeddings
}
}
_ => {
return Err(Box::new(LauncherError::ArgumentValidation(
"no max defined".to_string(),
)));
}
};
Ok(max_position_embeddings) Ok(max_position_embeddings)
}
} else {
Err(Box::new(LauncherError::ArgumentValidation(
"no max defined".to_string(),
)))
}
}; };
let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096); let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096);
@ -1462,6 +1513,11 @@ fn main() -> Result<(), LauncherError> {
let num_shard = find_num_shards(args.sharded, args.num_shard)?; let num_shard = find_num_shards(args.sharded, args.num_shard)?;
if num_shard > 1 { if num_shard > 1 {
if matches!(args.quantize, Some(Quantization::Exl2)) {
return Err(LauncherError::ArgumentValidation(
"Sharding is currently not supported with `exl2` quantization".into(),
));
}
tracing::info!("Sharding model on {num_shard} processes"); tracing::info!("Sharding model on {num_shard} processes");
} }
@ -1524,6 +1580,8 @@ fn main() -> Result<(), LauncherError> {
&args, &args,
cuda_graphs, cuda_graphs,
max_total_tokens, max_total_tokens,
max_input_tokens,
max_log_level,
shutdown.clone(), shutdown.clone(),
&shutdown_receiver, &shutdown_receiver,
shutdown_sender, shutdown_sender,

265
proto/v3/generate.proto Normal file
View File

@ -0,0 +1,265 @@
syntax = "proto3";
package generate.v3;
service TextGenerationService {
/// Model Info
rpc Info (InfoRequest) returns (InfoResponse) {}
/// Service discovery
rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {}
/// Empties batch cache
rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse);
/// Remove requests from a cached batch
rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse);
/// Warmup the model and compute max cache size
rpc Warmup (WarmupRequest) returns (WarmupResponse);
/// Prefill batch and decode first token
rpc Prefill (PrefillRequest) returns (PrefillResponse);
/// Decode token for a list of prefilled batches
rpc Decode (DecodeRequest) returns (DecodeResponse);
/// Health check
rpc Health (HealthRequest) returns (HealthResponse);
}
message HealthRequest {}
message HealthResponse {}
/// Empty request
message InfoRequest {}
message InfoResponse {
bool requires_padding = 1;
string dtype = 2;
string device_type = 3;
optional uint32 window_size = 4;
uint32 speculate = 5;
}
/// Empty request
message ServiceDiscoveryRequest {}
message ServiceDiscoveryResponse {
/// Other shards urls
repeated string urls = 1;
}
message ClearCacheRequest {
/// Optional batch id
optional uint64 id = 1;
}
/// Empty response
message ClearCacheResponse {}
message Image {
/// Binary image data.
bytes data = 1;
/// Image MIME type.
string mimetype = 2;
}
message InputChunk {
oneof chunk {
/// Plain text data
string text = 1;
/// Image data
Image image = 2;
}
}
message Input {
repeated InputChunk chunks = 1;
}
enum GrammarType {
GRAMMAR_TYPE_NONE = 0;
GRAMMAR_TYPE_JSON = 1;
GRAMMAR_TYPE_REGEX = 2;
}
message NextTokenChooserParameters {
/// exponential scaling output probability distribution
float temperature = 1;
/// restricting to the k highest probability elements
uint32 top_k = 2;
/// restricting to top tokens summing to prob_cut_off <= prob_cut_off
float top_p = 3;
/// restricting to top tokens summing to prob_cut_off <= prob_cut_off
float typical_p = 4;
/// apply sampling on the logits
bool do_sample = 5;
/// random seed for sampling
uint64 seed = 6;
/// repetition penalty
float repetition_penalty = 7;
/// frequency penalty
float frequency_penalty = 9;
/// token watermarking using "A Watermark for Large Language Models"
bool watermark = 8;
/// grammar (applied if not empty)
string grammar = 10;
/// grammar type
GrammarType grammar_type = 11;
}
message StoppingCriteriaParameters {
/// Maximum number of generated tokens
uint32 max_new_tokens = 1;
/// Optional stopping sequences
repeated string stop_sequences = 2;
/// Ignore end of sequence token
/// used for benchmarking
bool ignore_eos_token = 3;
}
message Request {
/// Request ID
uint64 id = 1;
/// The generation context as chunks
Input input_chunks = 8;
/// The generation context, stringified input_chunks
string inputs = 2;
/// Context truncation
uint32 truncate = 3;
/// Next Token Chooser Parameters
NextTokenChooserParameters parameters = 4;
/// Stopping Criteria Parameters
StoppingCriteriaParameters stopping_parameters = 5;
/// Return prefill logprobs
bool prefill_logprobs = 6;
/// Return most likely n tokens
uint32 top_n_tokens = 7;
/// Paged attention blocks
repeated uint32 blocks = 9;
/// Paged attention slots
repeated uint32 slots = 10;
}
message Batch {
/// Batch ID
uint64 id = 1;
/// Individual requests
repeated Request requests = 2;
/// Batch size (==len(requests))
uint32 size = 3;
/// Maximum number of tokens this batch will grow to
uint32 max_tokens = 4;
/// Maximum number of Paged Attention blocks
uint32 max_blocks = 5;
}
message CachedBatch {
/// Batch ID
uint64 id = 1;
/// Individual requests ids
repeated uint64 request_ids = 2;
/// Batch size (==len(requests))
uint32 size = 3;
/// Maximum number of tokens this batch will grow to
uint32 max_tokens = 4;
}
enum FinishReason {
FINISH_REASON_LENGTH = 0;
FINISH_REASON_EOS_TOKEN = 1;
FINISH_REASON_STOP_SEQUENCE = 2;
}
message GeneratedText {
/// Output
string text = 1;
/// Number of generated tokens
uint32 generated_tokens = 2;
/// Finish reason
FinishReason finish_reason = 3;
/// Seed
optional uint64 seed = 4;
}
message Tokens {
/// Token IDs
repeated uint32 ids = 1;
/// Logprobs
repeated float logprobs = 2;
/// tokens
repeated string texts = 3;
/// special
repeated bool is_special = 4;
}
message Generation {
/// Request ID
uint64 request_id = 1;
/// Prefill tokens (optional)
Tokens prefill_tokens = 2;
Tokens tokens = 3;
/// Complete generated text
optional GeneratedText generated_text = 4;
/// Top tokens
repeated Tokens top_tokens = 5;
}
message FilterBatchRequest {
/// Batch ID
uint64 batch_id = 1;
/// Requests to keep
repeated uint64 request_ids = 2;
}
message FilterBatchResponse {
/// Filtered Batch (cached)
CachedBatch batch = 1;
}
message PrefillRequest {
/// Batch
Batch batch = 1;
}
message PrefillResponse {
/// Generation
repeated Generation generations = 1;
/// Next batch (cached)
optional CachedBatch batch = 2;
/// Forward elapsed time in nanoseconds
uint64 forward_ns = 3;
/// Decode elapsed time in nanoseconds
uint64 decode_ns = 4;
/// Total elapsed time in nanoseconds
uint64 total_ns = 5;
}
message DecodeRequest {
/// Cached batches
repeated CachedBatch batches = 1;
}
message DecodeResponse {
/// Decodes
repeated Generation generations = 1;
/// Next batch (cached)
optional CachedBatch batch = 2;
/// Forward elapsed time in nanoseconds
uint64 forward_ns = 3;
/// Decode elapsed time in nanoseconds
uint64 decode_ns = 4;
/// Total elapsed time in nanoseconds
uint64 total_ns = 5;
/// Concatenate elapsed time in nanoseconds
optional uint64 concat_ns = 6;
}
message WarmupRequest {
/// Batch to warmup on
Batch batch = 1;
uint32 max_input_length = 2;
uint32 max_prefill_tokens = 3;
uint32 max_total_tokens = 4;
}
message WarmupResponse {
/// Maximum number of tokens supported by the model
optional uint32 max_supported_total_tokens = 1;
}

View File

@ -16,8 +16,8 @@ path = "src/main.rs"
[dependencies] [dependencies]
async-stream = "0.3.5" async-stream = "0.3.5"
axum = { version = "0.6.20", features = ["json"] } axum = { version = "0.7", features = ["json"] }
axum-tracing-opentelemetry = "0.14.1" axum-tracing-opentelemetry = "0.16"
text-generation-client = { path = "client" } text-generation-client = { path = "client" }
clap = { version = "4.4.5", features = ["derive", "env"] } clap = { version = "4.4.5", features = ["derive", "env"] }
futures = "0.3.28" futures = "0.3.28"
@ -36,20 +36,21 @@ thiserror = "1.0.48"
tokenizers = { workspace = true} tokenizers = { workspace = true}
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.14" tokio-stream = "0.1.14"
tower-http = { version = "0.4.4", features = ["cors"] } tower-http = { version = "0.5.1", features = ["cors"] }
tracing = "0.1.37" tracing = "0.1.37"
tracing-opentelemetry = "0.21.0" tracing-opentelemetry = "0.21.0"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] } tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
utoipa = { version = "3.5.0", features = ["axum_extras"] } utoipa = { version = "4.2.0", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] } utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
ngrok = { version = "0.13.1", features = ["axum"], optional = true } ngrok = { version = "0.13.1", features = ["axum"], optional = true }
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] } init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
minijinja = { git = "https://github.com/mitsuhiko/minijinja.git", rev = "5cd4efb" } minijinja = { version = "2.0.2" }
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
futures-util = "0.3.30" futures-util = "0.3.30"
regex = "1.10.3" regex = "1.10.3"
once_cell = "1.19.0" once_cell = "1.19.0"
image = "0.25.1" image = "0.25.1"
base64 = "0.22.0" base64 = { workspace = true }
[build-dependencies] [build-dependencies]
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] } vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
@ -58,3 +59,4 @@ vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
default = ["ngrok"] default = ["ngrok"]
ngrok = ["dep:ngrok"] ngrok = ["dep:ngrok"]
google = [] google = []
kserve = []

View File

@ -6,6 +6,8 @@ authors.workspace = true
homepage.workspace = true homepage.workspace = true
[dependencies] [dependencies]
async-trait = "^0.1"
base64 = { workspace = true }
futures = "^0.3" futures = "^0.3"
grpc-metadata = { path = "../grpc-metadata" } grpc-metadata = { path = "../grpc-metadata" }
prost = "^0.12" prost = "^0.12"

View File

@ -1,18 +1,34 @@
use std::fs; use std::fs;
fn main() -> Result<(), Box<dyn std::error::Error>> { fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("cargo:rerun-if-changed=../../proto/generate.proto"); println!("cargo:rerun-if-changed=../../proto/");
fs::create_dir("src/pb").unwrap_or(());
fs::create_dir_all("src/v2/pb").unwrap_or(());
let mut config = prost_build::Config::new(); let mut config = prost_build::Config::new();
config.protoc_arg("--experimental_allow_proto3_optional"); config.protoc_arg("--experimental_allow_proto3_optional");
tonic_build::configure() tonic_build::configure()
.build_client(true) .build_client(true)
.build_server(false) .build_server(false)
.out_dir("src/pb") .out_dir("src/v2/pb")
.include_file("mod.rs") .include_file("mod.rs")
.compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"]) .compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"])
.map_err(|e| match e.kind(){
std::io::ErrorKind::NotFound => {panic!("`protoc` not found, install libprotoc")},
std::io::ErrorKind::Other => {panic!("`protoc` version unsupported, upgrade protoc: https://github.com/protocolbuffers/protobuf/releases")},
e => {e}
}).unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
fs::create_dir_all("src/v3/pb").unwrap_or(());
let mut config = prost_build::Config::new();
config.protoc_arg("--experimental_allow_proto3_optional");
tonic_build::configure()
.build_client(true)
.build_server(false)
.out_dir("src/v3/pb")
.include_file("mod.rs")
.compile_with_config(config, &["../../proto/v3/generate.proto"], &["../../proto"])
.unwrap_or_else(|e| panic!("protobuf compilation failed: {e}")); .unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
Ok(()) Ok(())

View File

@ -1,22 +1,35 @@
//! Text Generation gRPC client library //! Text Generation gRPC client library
mod client; pub mod v2;
#[allow(clippy::derive_partial_eq_without_eq)] pub mod v3;
mod pb;
mod sharded_client;
pub use client::Client; use async_trait::async_trait;
pub use pb::generate::v2::HealthResponse; use base64::{engine::general_purpose::STANDARD, Engine};
pub use pb::generate::v2::InfoResponse as ShardInfo;
pub use pb::generate::v2::{
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
};
pub use sharded_client::ShardedClient;
use thiserror::Error; use thiserror::Error;
use tonic::transport; use tonic::transport;
use tonic::Status; use tonic::Status;
pub use v3::{Chunk, Image, Input, InputChunk};
#[async_trait]
pub trait Health {
/// Check if a generate server is healthy by asking it to allocate a tensor on device
async fn device_health(&self) -> Result<()>;
/// Check if a generate server is healthy by doing a forward pass.
/// EXPENSIVE
async fn model_health(&self) -> Result<()>;
}
#[derive(Debug)]
pub struct ShardInfo {
pub requires_padding: bool,
pub dtype: String,
pub device_type: String,
pub window_size: Option<u32>,
pub speculate: u32,
}
#[derive(Error, Debug, Clone)] #[derive(Error, Debug, Clone)]
pub enum ClientError { pub enum ClientError {
#[error("Could not connect to Text Generation server: {0}")] #[error("Could not connect to Text Generation server: {0}")]
@ -43,4 +56,36 @@ impl From<transport::Error> for ClientError {
} }
} }
// Small convenience re-wrapping of `Chunk`.
impl From<Chunk> for InputChunk {
fn from(chunk: Chunk) -> Self {
InputChunk { chunk: Some(chunk) }
}
}
/// Convert input chunks to a stringly-typed input for backwards
/// compat for backends that haven't implemented chunked inputs.
pub trait ChunksToString {
/// Convert chunks to string.
fn chunks_to_string(&self) -> String;
}
impl ChunksToString for Vec<InputChunk> {
fn chunks_to_string(&self) -> String {
let mut output = String::new();
self.iter().for_each(|c| match &c.chunk {
Some(Chunk::Text(text)) => output.push_str(text),
Some(Chunk::Image(Image { data, mimetype })) => {
let encoded = STANDARD.encode(data);
output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
}
// We don't create empty chunks, so this should be unreachable.
None => unreachable!("Chunks should never be empty"),
});
output
}
}
static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
pub type Result<T> = std::result::Result<T, ClientError>; pub type Result<T> = std::result::Result<T, ClientError>;

View File

@ -1 +0,0 @@
*.rs

View File

@ -1,8 +1,11 @@
/// Single shard Client /// Single shard Client
use crate::pb::generate::v2::text_generation_service_client::TextGenerationServiceClient; use crate::v2::pb;
use crate::pb::generate::v2::*; use crate::{ClientError, Result};
use crate::Result;
use crate::WARMUP_IMAGE_BASE64;
use grpc_metadata::InjectTelemetryContext; use grpc_metadata::InjectTelemetryContext;
use pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;
use pb::generate::v2::*;
use std::cmp::min; use std::cmp::min;
use std::time::Duration; use std::time::Duration;
use tonic::transport::{Channel, Uri}; use tonic::transport::{Channel, Uri};
@ -42,7 +45,9 @@ impl Client {
#[instrument(skip(self))] #[instrument(skip(self))]
pub async fn service_discovery(&mut self) -> Result<Vec<String>> { pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context(); let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
let response = self.stub.service_discovery(request).await?; let response = self.stub.service_discovery(request).await.map_err(|_| {
ClientError::Connection("Server does not support v2 interface".to_string())
})?;
let urls = response let urls = response
.into_inner() .into_inner()
.urls .urls
@ -118,13 +123,15 @@ impl Client {
if n_tokens == 0 { if n_tokens == 0 {
// 1 request is enough to test vision heads. // 1 request is enough to test vision heads.
// Sending images on other queries messes up easily with truncation. // Sending images on other queries messes up easily with truncation.
inputs.push_str("![](data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=)"); inputs.push_str(&format!(
"![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})",
));
} }
requests.push(Request { requests.push(Request {
id: 0, id: 0,
// We truncate the input on the server side to be sure that it has the correct size
inputs, inputs,
// We truncate the input on the server side to be sure that it has the correct size
truncate, truncate,
// Set sampling parameters to also take these ops into account in the max memory // Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters { parameters: Some(NextTokenChooserParameters {

View File

@ -0,0 +1,13 @@
#[allow(clippy::derive_partial_eq_without_eq)]
mod pb;
mod client;
mod sharded_client;
pub use client::Client;
pub use pb::generate::v2::HealthResponse;
pub use pb::generate::v2::{
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, InfoResponse,
NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
};
pub use sharded_client::ShardedClient;

1
router/client/src/v2/pb/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
*

View File

@ -1,10 +1,17 @@
use crate::client::{DecodeTimings, PrefillTimings};
/// Multi shard Client /// Multi shard Client
use crate::{Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo}; use crate::{v2, Health, ShardInfo};
use crate::{ClientError, Result}; use crate::{ClientError, Result};
use crate::v2::InfoResponse;
use async_trait::async_trait;
use futures::future::join_all; use futures::future::join_all;
use tonic::transport::Uri; use tonic::transport::Uri;
use tracing::instrument; use tracing::instrument;
use v2::client::{DecodeTimings, PrefillTimings};
use v2::{
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
/// Text Generation Inference gRPC multi client /// Text Generation Inference gRPC multi client
@ -47,7 +54,7 @@ impl ShardedClient {
.iter_mut() .iter_mut()
.map(|client| client.info()) .map(|client| client.info())
.collect(); .collect();
join_all(futures).await.pop().unwrap() join_all(futures).await.pop().unwrap().map(ShardInfo::from)
} }
/// GRPC health check /// GRPC health check
@ -185,3 +192,60 @@ impl ShardedClient {
Ok((generations, next_batch, timings)) Ok((generations, next_batch, timings))
} }
} }
impl From<InfoResponse> for ShardInfo {
fn from(value: InfoResponse) -> Self {
Self {
requires_padding: value.requires_padding,
dtype: value.dtype,
device_type: value.device_type,
window_size: value.window_size,
speculate: value.speculate,
}
}
}
#[async_trait]
impl Health for ShardedClient {
async fn device_health(&self) -> Result<()> {
self.clone().health().await?;
Ok(())
}
async fn model_health(&self) -> Result<()> {
// Dummy batch of 1 token and 1 generated token
let liveness_request = Request {
id: u64::MAX,
inputs: "liveness".to_string(),
truncate: 10,
prefill_logprobs: false,
parameters: Some(NextTokenChooserParameters {
temperature: 1.0,
top_k: 0,
top_p: 1.0,
typical_p: 1.0,
do_sample: false,
seed: 0,
repetition_penalty: 1.0,
frequency_penalty: 0.0,
watermark: false,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 1,
stop_sequences: vec![],
ignore_eos_token: false,
}),
top_n_tokens: 0,
};
let batch = Batch {
id: u64::MAX,
requests: vec![liveness_request],
size: 1,
max_tokens: 2,
};
self.clone().prefill(batch).await?;
Ok(())
}
}

View File

@ -0,0 +1,282 @@
use crate::v3::{pb, Chunk};
use crate::{ClientError, Result, WARMUP_IMAGE_BASE64};
/// Single shard Client
use base64::engine::general_purpose::STANDARD;
use base64::Engine;
use grpc_metadata::InjectTelemetryContext;
use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
use pb::generate::v3::*;
use std::cmp::min;
use std::time::Duration;
use tonic::transport::{Channel, Uri};
use tracing::instrument;
/// Text Generation Inference gRPC client
#[derive(Debug, Clone)]
pub struct Client {
stub: TextGenerationServiceClient<Channel>,
}
impl Client {
/// Returns a client connected to the given url
pub async fn connect(uri: Uri) -> Result<Self> {
let channel = Channel::builder(uri).connect().await?;
Ok(Self {
stub: TextGenerationServiceClient::new(channel),
})
}
/// Returns a client connected to the given unix socket
pub async fn connect_uds(path: String) -> Result<Self> {
let channel = Channel::from_shared("http://[::]:50051".to_string())
.unwrap()
.connect_with_connector(tower::service_fn(move |_: Uri| {
tokio::net::UnixStream::connect(path.clone())
}))
.await?;
Ok(Self {
stub: TextGenerationServiceClient::new(channel),
})
}
/// Returns a list of uris or unix sockets of all shards
#[instrument(skip(self))]
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
let response = self.stub.service_discovery(request).await.map_err(|_| {
ClientError::Connection("Server does not support v3 interface".to_string())
})?;
let urls = response
.into_inner()
.urls
.into_iter()
// Remove unix socket prefix
.map(|url| match url.strip_prefix("unix://") {
None => url,
Some(stripped_url) => stripped_url.to_string(),
})
.collect();
Ok(urls)
}
/// Get model info
#[instrument(skip(self))]
pub async fn info(&mut self) -> Result<InfoResponse> {
let request = tonic::Request::new(InfoRequest {}).inject_context();
let response = self.stub.info(request).await?.into_inner();
Ok(response)
}
/// Get model health
#[instrument(skip(self))]
pub async fn health(&mut self) -> Result<HealthResponse> {
let request = tonic::Request::new(HealthRequest {}).inject_context();
let response = self.stub.health(request).await?.into_inner();
Ok(response)
}
/// Clear the past generations cache
#[instrument(skip(self))]
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
self.stub.clear_cache(request).await?;
Ok(())
}
/// Filter a cached batch
#[instrument(skip(self))]
pub async fn filter_batch(
&mut self,
batch_id: u64,
request_ids: Vec<u64>,
) -> Result<Option<CachedBatch>> {
let request = tonic::Request::new(FilterBatchRequest {
batch_id,
request_ids,
})
.inject_context();
let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
Ok(filtered_batch.batch)
}
/// Warmup on a max size batch
///
/// Returns the maximum amount of tokens supported by the hardware
#[instrument(skip_all)]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let mut n_tokens = 0;
let mut requests = Vec::new();
// Create requests
while n_tokens < max_prefill_tokens {
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
let mut input_chunks = Vec::new();
input_chunks
.push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
if n_tokens == 0 {
input_chunks.push(
Chunk::Image(Image {
// Safe unwrap, because we control the data.
data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(),
mimetype: "image/jpeg;base64".to_string(),
})
.into(),
);
}
// Send stringly-typed inputs for compatibility for backends that haven't
// been updated to support chunks.
let mut inputs = String::new();
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
if n_tokens == 0 {
// 1 request is enough to test vision heads.
// Sending images on other queries messes up easily with truncation.
inputs.push_str(&format!(
"![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})",
));
}
requests.push(Request {
id: 0,
inputs,
input_chunks: Some(Input {
chunks: input_chunks,
}),
// We truncate the input on the server side to be sure that it has the correct size
truncate,
// Blocks and slots will be set on the server side if we use paged attention
blocks: vec![],
slots: vec![],
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,
top_k: 10,
top_p: 0.9,
typical_p: 0.9,
do_sample: false,
seed: 0,
repetition_penalty: 1.2,
frequency_penalty: 0.1,
watermark: true,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: max_total_tokens - truncate,
stop_sequences: vec![],
ignore_eos_token: true,
}),
prefill_logprobs: true,
top_n_tokens: 20,
});
n_tokens += max_input_length;
// Check max_batch_size
if Some(requests.len()) == max_batch_size {
break;
}
}
let batch = Batch {
id: 0,
size: requests.len() as u32,
requests,
max_tokens: max_input_length,
max_blocks: 0,
};
let request = tonic::Request::new(WarmupRequest {
batch: Some(batch),
max_input_length,
max_prefill_tokens,
max_total_tokens,
})
.inject_context();
let response = self.stub.warmup(request).await?.into_inner();
Ok(response.max_supported_total_tokens)
}
/// Generate one token for each request in the given batch
///
/// Returns Generation for each request in batch
/// and the next cached batch
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
pub async fn prefill(
&mut self,
batch: Batch,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
let response = self.stub.prefill(request).await?.into_inner();
Ok((
response.generations,
response.batch,
PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
))
}
/// Generate one token for each request in the given cached batches
///
/// Returns Generation for each request in batches
/// and the next cached batch
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
pub async fn decode(
&mut self,
batches: Vec<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
let response = self.stub.decode(request).await?.into_inner();
Ok((
response.generations,
response.batch,
DecodeTimings::new(
response.concat_ns,
response.forward_ns,
response.decode_ns,
response.total_ns,
),
))
}
}
pub struct PrefillTimings {
pub forward: Duration,
pub decode: Duration,
pub total: Duration,
}
impl PrefillTimings {
fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
Self {
forward: Duration::from_nanos(forward_ns),
decode: Duration::from_nanos(decode_ns),
total: Duration::from_nanos(total_ns),
}
}
}
pub struct DecodeTimings {
pub concat: Option<Duration>,
pub forward: Duration,
pub decode: Duration,
pub total: Duration,
}
impl DecodeTimings {
fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
Self {
concat: concat_ns.map(Duration::from_nanos),
forward: Duration::from_nanos(forward_ns),
decode: Duration::from_nanos(decode_ns),
total: Duration::from_nanos(total_ns),
}
}
}

View File

@ -0,0 +1,13 @@
#[allow(clippy::derive_partial_eq_without_eq)]
mod pb;
mod client;
mod sharded_client;
pub use client::Client;
pub use pb::generate::v3::{
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
StoppingCriteriaParameters, Tokens,
};
pub use sharded_client::ShardedClient;

1
router/client/src/v3/pb/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
*

View File

@ -0,0 +1,258 @@
/// Multi shard Client
use crate::{v3, Health, ShardInfo};
use crate::{ClientError, Result};
use crate::v3::{Chunk, InfoResponse, Input};
use async_trait::async_trait;
use futures::future::join_all;
use tonic::transport::Uri;
use tracing::instrument;
use v3::client::{DecodeTimings, PrefillTimings};
use v3::{
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
#[derive(Debug, Clone)]
/// Text Generation Inference gRPC multi client
pub struct ShardedClient {
clients: Vec<Client>,
}
impl ShardedClient {
fn new(clients: Vec<Client>) -> Self {
Self { clients }
}
/// Create a new ShardedClient from a master client. The master client will communicate with
/// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
async fn from_master_client(mut master_client: Client) -> Result<Self> {
// Get all uris/unix sockets from the master client
let uris = master_client.service_discovery().await?;
let futures = uris.into_iter().map(Client::connect_uds);
let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
Ok(Self::new(clients?))
}
/// Returns a client connected to the given uri
pub async fn connect(uri: Uri) -> Result<Self> {
let master_client = Client::connect(uri).await?;
Self::from_master_client(master_client).await
}
/// Returns a client connected to the given unix socket
pub async fn connect_uds(path: String) -> Result<Self> {
let master_client = Client::connect_uds(path).await?;
Self::from_master_client(master_client).await
}
/// Get the model info
#[instrument(skip(self))]
pub async fn info(&mut self) -> Result<ShardInfo> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| client.info())
.collect();
join_all(futures).await.pop().unwrap().map(ShardInfo::from)
}
/// GRPC health check
#[instrument(skip(self))]
pub async fn health(&mut self) -> Result<HealthResponse> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| client.health())
.collect();
join_all(futures).await.pop().unwrap()
}
/// Clear the past generations cache
#[instrument(skip(self))]
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| client.clear_cache(batch_id))
.collect();
join_all(futures).await.into_iter().collect()
}
/// Filter a cached batch
#[instrument(skip(self))]
pub async fn filter_batch(
&mut self,
batch_id: u64,
request_ids: Vec<u64>,
) -> Result<Option<CachedBatch>> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
.collect();
// all shards return the same message
join_all(futures).await.pop().unwrap()
}
/// Warmup on a max size batch
///
/// Returns the maximum amount of tokens supported by the hardware
#[instrument(skip(self))]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| {
Box::pin(client.warmup(
max_input_length,
max_prefill_tokens,
max_total_tokens,
max_batch_size,
))
})
.collect();
// Take the minimum value
let results = join_all(futures)
.await
.into_iter()
.collect::<Result<Vec<Option<u32>>>>()?;
Ok(results.into_iter().flatten().min())
}
/// Generate one token for each request in the given batch
///
/// Returns Generation for each request in batch
/// and the next cached batch
#[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]
pub async fn prefill(
&mut self,
batch: Batch,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.prefill(batch.clone())))
.collect();
#[allow(clippy::type_complexity)]
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
join_all(futures).await.into_iter().collect();
let mut results = results?;
let (mut generations, next_batch, mut timings) =
results.pop().ok_or(ClientError::EmptyResults)?;
// Merge generations from different model shards
for (mut shard_generations, _, shard_timings) in results.into_iter() {
generations.append(&mut shard_generations);
// Return the timings of the slowest shard
if shard_timings.total > timings.total {
timings = shard_timings;
}
}
Ok((generations, next_batch, timings))
}
/// Generate one token for each request in the given cached batches
///
/// Returns Generation for each request in batches
/// and the next cached batch
#[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
pub async fn decode(
&mut self,
batches: Vec<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.decode(batches.clone())))
.collect();
#[allow(clippy::type_complexity)]
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
join_all(futures).await.into_iter().collect();
let mut results = results?;
let (mut generations, next_batch, mut timings) =
results.pop().ok_or(ClientError::EmptyResults)?;
// Merge generations from different model shards
for (mut shard_generations, _, shard_timings) in results.into_iter() {
generations.append(&mut shard_generations);
// Return the timings of the slowest shard
if shard_timings.total > timings.total {
timings = shard_timings;
}
}
Ok((generations, next_batch, timings))
}
}
impl From<InfoResponse> for ShardInfo {
fn from(value: InfoResponse) -> Self {
Self {
requires_padding: value.requires_padding,
dtype: value.dtype,
device_type: value.device_type,
window_size: value.window_size,
speculate: value.speculate,
}
}
}
#[async_trait]
impl Health for ShardedClient {
async fn device_health(&self) -> Result<()> {
self.clone().health().await?;
Ok(())
}
async fn model_health(&self) -> Result<()> {
// Dummy batch of 1 token and 1 generated token
let liveness_request = Request {
id: u64::MAX,
inputs: "liveness".to_string(),
input_chunks: Some(Input {
chunks: vec![Chunk::Text("liveness".into()).into()],
}),
truncate: 10,
prefill_logprobs: false,
parameters: Some(NextTokenChooserParameters {
temperature: 1.0,
top_k: 0,
top_p: 1.0,
typical_p: 1.0,
do_sample: false,
seed: 0,
repetition_penalty: 1.0,
frequency_penalty: 0.0,
watermark: false,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 1,
stop_sequences: vec![],
ignore_eos_token: false,
}),
top_n_tokens: 0,
// Block 0 is reserved for health checks
blocks: vec![0],
slots: (0..16).collect(),
};
let batch = Batch {
id: u64::MAX,
requests: vec![liveness_request],
size: 1,
max_tokens: 2,
max_blocks: 1,
};
self.clone().prefill(batch).await?;
Ok(())
}
}

View File

@ -4,9 +4,9 @@ use serde::{Deserialize, Serialize};
#[serde(tag = "model_type")] #[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub struct LlavaNext { pub struct LlavaNext {
text_config: TextConfig, pub(crate) text_config: TextConfig,
vision_config: VisionConfig, pub(crate) vision_config: VisionConfig,
image_grid_pinpoints: Vec<(usize, usize)>, pub(crate) image_grid_pinpoints: Vec<(usize, usize)>,
} }
fn get_anyres_image_grid_shape( fn get_anyres_image_grid_shape(
@ -119,13 +119,13 @@ impl Idefics2 {
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub struct PaliTextConfig { pub struct PaliTextConfig {
num_image_tokens: usize, pub(crate) num_image_tokens: usize,
} }
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub struct Paligemma { pub struct Paligemma {
text_config: PaliTextConfig, pub(crate) text_config: PaliTextConfig,
} }
impl Paligemma { impl Paligemma {
@ -175,8 +175,8 @@ pub struct TextConfig {}
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub struct VisionConfig { pub struct VisionConfig {
image_size: usize, pub(crate) image_size: usize,
patch_size: usize, pub(crate) patch_size: usize,
} }
#[cfg(test)] #[cfg(test)]

View File

@ -1,72 +0,0 @@
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use text_generation_client::GrammarType as ProtoGrammarType;
use text_generation_client::{
Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
};
// Note: Request ids and batch ids cannot collide.
const LIVENESS_ID: u64 = u64::MAX;
const BATCH_ID: u64 = u64::MAX;
#[derive(Clone, Debug)]
pub(crate) struct Health {
client: ShardedClient,
generation_health: Arc<AtomicBool>,
}
impl Health {
pub(crate) fn new(client: ShardedClient, generation_health: Arc<AtomicBool>) -> Self {
Self {
client,
generation_health,
}
}
pub(crate) async fn check(&mut self) -> bool {
if self.generation_health.load(Ordering::SeqCst) {
// Generation is healthy, we only check that the shards are answering gRPC calls
self.client.health().await.is_ok()
} else {
// Generation is unhealthy or have not sent any generation request yet
// Dummy batch of 1 token and 1 generated token
let liveness_request = Request {
id: LIVENESS_ID,
inputs: "liveness".to_string(),
truncate: 10,
prefill_logprobs: false,
parameters: Some(NextTokenChooserParameters {
temperature: 1.0,
top_k: 0,
top_p: 1.0,
typical_p: 1.0,
do_sample: false,
seed: 0,
repetition_penalty: 1.0,
frequency_penalty: 0.0,
watermark: false,
grammar: String::new(),
grammar_type: ProtoGrammarType::None as i32,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 1,
stop_sequences: vec![],
ignore_eos_token: false,
}),
top_n_tokens: 0,
};
let batch = Batch {
id: BATCH_ID,
requests: vec![liveness_request],
size: 1,
max_tokens: 2,
};
// Skips the queue
let value = self.client.prefill(batch).await.is_ok();
// Update generation health
self.generation_health.store(value, Ordering::SeqCst);
value
}
}
}

View File

@ -0,0 +1,34 @@
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use text_generation_client::Health;
#[derive(Clone)]
pub(crate) struct HealthCheck {
client: Arc<dyn Health + Send + Sync>,
generation_health: Arc<AtomicBool>,
}
impl HealthCheck {
pub(crate) fn new(
client: Arc<dyn Health + Send + Sync>,
generation_health: Arc<AtomicBool>,
) -> Self {
Self {
client,
generation_health,
}
}
pub(crate) async fn check(&mut self) -> bool {
let value = if self.generation_health.load(Ordering::SeqCst) {
// Generation is healthy, we only check that the shards can allocate on device
self.client.device_health().await
} else {
self.client.model_health().await
}
.is_ok();
// Update generation health
self.generation_health.store(value, Ordering::SeqCst);
value
}
}

519
router/src/infer/mod.rs Normal file
View File

@ -0,0 +1,519 @@
mod health;
pub(crate) mod v2;
pub(crate) mod v3;
pub(crate) use health::HealthCheck;
use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
use crate::{
ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig,
HubTokenizerConfig, Message, MessageChunk, PrefillToken, Text, TextMessage, Token,
};
use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
use futures::future::try_join_all;
use minijinja::{Environment, ErrorKind, Template};
use minijinja_contrib::pycompat;
use serde_json::{json, Map, Value};
use std::collections::HashMap;
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;
use tracing::instrument;
pub(crate) trait Scheduler {
fn schedule(
&self,
request: ValidGenerateRequest,
permit: OwnedSemaphorePermit,
) -> Result<GenerateStreamResponse, InferError>;
}
/// Inference struct
#[derive(Clone)]
pub struct Infer {
/// Validation
validation: Validation,
/// Request scheduler
scheduler: Arc<dyn Scheduler + Send + Sync>,
/// Chat template
chat_template: Option<ChatTemplate>,
/// Inference limit
limit_concurrent_requests: Arc<Semaphore>,
}
impl Infer {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
scheduler: Arc<dyn Scheduler + Send + Sync>,
validation: Validation,
max_concurrent_requests: usize,
tokenizer_config: HubTokenizerConfig,
processor_config: HubProcessorConfig,
) -> Self {
let chat_template = tokenizer_config
.chat_template
.or(processor_config.chat_template)
.and_then(|t| match t {
ChatTemplateVersions::Single(template) => Some(template),
ChatTemplateVersions::Multiple(templates) => templates
.into_iter()
.find(|t| t.name == "default")
.map(|t| t.template),
})
.map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token));
// Inference limit with a semaphore
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
Self {
validation,
scheduler,
chat_template,
limit_concurrent_requests: semaphore,
}
}
/// Add a new request to the queue and return a stream of InferStreamResponse
#[instrument(skip_all)]
pub(crate) async fn generate_stream(
&self,
request: GenerateRequest,
) -> Result<GenerateStreamResponse, InferError> {
// Limit concurrent requests by acquiring a permit from the semaphore
let permit = self
.clone()
.limit_concurrent_requests
.try_acquire_owned()
.map_err(|err| {
metrics::increment_counter!("tgi_request_failure", "err" => "overloaded");
tracing::error!("{err}");
err
})?;
// Validate request
let valid_request = self.validation.validate(request).await.map_err(|err| {
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
tracing::error!("{err}");
err
})?;
self.scheduler.schedule(valid_request, permit)
}
/// Tokenizer the input
#[instrument(skip_all)]
pub(crate) async fn tokenize(
&self,
request: GenerateRequest,
) -> Result<Option<tokenizers::Encoding>, InferError> {
// Tokenize request
let inputs = request.inputs;
let truncate = request.parameters.truncate;
let encoding = self
.validation
.tokenize(inputs, truncate)
.await
.map_err(|err| {
tracing::error!("Tokenization {err}");
err
})?;
// Return Encoding
Ok(encoding.map(|(encoding, _)| encoding))
}
/// Apply the chat template to the chat request
#[instrument(skip_all)]
pub(crate) fn apply_chat_template(
&self,
messages: Vec<Message>,
grammar_with_prompt: Option<(GrammarType, String)>,
) -> Result<String, InferError> {
self.chat_template
.as_ref()
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
.apply(messages, grammar_with_prompt)
.map_err(|e| {
metrics::increment_counter!("tgi_request_failure", "err" => "template");
tracing::error!("{e}");
e
})
}
/// Add a new request to the queue and return a InferResponse
#[instrument(skip_all)]
pub(crate) async fn generate(
&self,
request: GenerateRequest,
) -> Result<InferResponse, InferError> {
let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
// Create stream and keep semaphore permit as long as generate lives
let (_permit, _input_length, mut stream) = self.generate_stream(request).await?;
// Return values
let mut result_prefill = Vec::new();
let mut result_tokens = Vec::new();
let mut result_top_tokens = Vec::new();
let mut result_generated_text = None;
let mut result_start = None;
let mut result_queued = None;
// Iterate on stream
while let Some(response) = stream.next().await {
match response? {
// Add prefill tokens
InferStreamResponse::Prefill(prefill_tokens) => {
result_prefill = prefill_tokens;
}
// Push last token
InferStreamResponse::Intermediate { token, top_tokens } => {
result_tokens.push(token);
result_top_tokens.push(top_tokens);
}
// Final message
// Set return values
InferStreamResponse::End {
token,
generated_text,
start,
queued,
top_tokens,
} => {
result_tokens.push(token);
result_top_tokens.push(top_tokens);
result_generated_text = Some(generated_text);
result_start = Some(start);
result_queued = Some(queued)
}
}
}
// Check that we received a `InferStreamResponse::End` message
if let (Some(generated_text), Some(queued), Some(start)) =
(result_generated_text, result_queued, result_start)
{
Ok(InferResponse {
prefill: result_prefill,
_input_length,
tokens: result_tokens,
generated_text,
queued,
start,
top_tokens: if use_top_tokens {
result_top_tokens
} else {
Vec::new()
},
})
} else {
let err = InferError::IncompleteGeneration;
metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
tracing::error!("{err}");
Err(err)
}
}
/// Add best_of new requests to the queue and return a InferResponse of the sequence with
/// the highest log probability per token
#[instrument(skip(self, request))]
pub(crate) async fn generate_best_of(
&self,
request: GenerateRequest,
best_of: usize,
) -> Result<(InferResponse, Vec<InferResponse>), InferError> {
// validate best_of parameter separately
let best_of = self.validation.validate_best_of(best_of)?;
// create multiple generate requests
let mut infer_responses: Vec<InferResponse> =
try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?;
// get the sequence with the highest log probability per token
let mut max_index = 0;
let mut max_logprob: f32 = f32::MIN;
for (i, response) in infer_responses.iter().enumerate() {
// mean logprobs of the generated tokens
let sequence_logprob = response
.tokens
.iter()
.map(|token| token.logprob)
.sum::<f32>()
/ response.tokens.len() as f32;
// set best sequence
if sequence_logprob > max_logprob {
max_index = i;
max_logprob = sequence_logprob;
}
}
let best_response = infer_responses.remove(max_index);
Ok((best_response, infer_responses))
}
}
/// Raise a exception (custom function) used in the chat templates
fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
}
#[derive(Clone)]
struct ChatTemplate {
template: Template<'static, 'static>,
bos_token: Option<String>,
eos_token: Option<String>,
use_default_tool_template: bool,
}
impl ChatTemplate {
fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
let mut env = Box::new(Environment::new());
// enable things like .strip() or .capitalize()
env.set_unknown_method_callback(pycompat::unknown_method_callback);
let template_str = template.into_boxed_str();
env.add_function("raise_exception", raise_exception);
// check if contains the tools variable within the template
let use_default_tool_template =
!template_str.as_ref().replace(' ', "").contains("{{tools}}");
// leaking env and template_str as read-only, static resources for performance.
let template = Box::leak(env)
.template_from_str(Box::leak(template_str))
.unwrap();
Self {
template,
bos_token,
eos_token,
use_default_tool_template,
}
}
fn apply(
&self,
mut messages: Vec<Message>,
grammar_with_prompt: Option<(GrammarType, String)>,
) -> Result<String, InferError> {
if self.use_default_tool_template {
if let Some(last_message) = messages.last_mut() {
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
last_message.content.push(MessageChunk::Text(Text {
text: format!("\n---\n{}\n{}", tool_prompt, tools),
}));
}
}
}
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
self.template
.render(ChatTemplateInputs {
messages,
bos_token: self.bos_token.as_deref(),
eos_token: self.eos_token.as_deref(),
add_generation_prompt: true,
tools: None,
tools_prompt: None,
})
.map_err(InferError::TemplateError)
}
}
pub struct ToolGrammar {}
impl ToolGrammar {
pub fn apply(
tools: Option<Vec<Tool>>,
tool_choice: Option<ToolType>,
) -> Result<Option<Tools>, InferError> {
if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) {
// let tool_prompt = tool_prompt.unwrap_or_default();
let tools_to_use = match tool_choice {
ToolType::FunctionName(name) => {
vec![req_tools
.iter()
.find(|tool| tool.function.name == *name)
.unwrap_or_else(|| panic!("Tool with name {} not found", name))
.clone()]
}
ToolType::OneOf => req_tools.to_owned(),
};
// adds the error notification function for LLM feedback if required
let mut text_response_properties = Map::new();
text_response_properties.insert(
"error".to_string(),
serde_json::json!({
"type": "string",
"description": "The error or issue to notify"
}),
);
text_response_properties.insert(
"_name".to_string(),
serde_json::json!({
"type": "string",
"const": "notify_error"
}),
);
let functions: HashMap<String, serde_json::Value> = tools_to_use
.iter()
.map(|tool| {
let func = tool.function.clone();
// Clone the existing parameters, which are expected to be a JSON object
let mut params = if let Value::Object(params) = &func.arguments {
params.clone()
} else {
Map::new()
};
// Insert the function's description at the top level, outside of properties
params.insert(
"description".to_string(),
Value::String(func.description.clone().unwrap_or_default()),
);
// Ensure 'properties' exists and is an object
let properties = params
.entry("properties".to_string())
.or_insert_with(|| json!({}))
.as_object_mut()
.unwrap();
// Insert the constant for the function name inside 'properties'
properties.insert(
"_name".to_string(),
json!({
"type": "string",
"const": func.name.clone(),
// "description": "The name of the function"
}),
);
// Check if 'required' exists, and it is an array. If not, create an empty array.
let required = params
.entry("required".to_string())
.or_insert_with(|| json!([]))
.as_array_mut()
.unwrap();
// Add 'name' to the 'required' array if it is not already present
if !required.iter().any(|r| r == "_name") {
required.push(json!("_name"));
}
(func.name, Value::Object(params))
})
.chain([(
"notify_error".to_string(),
serde_json::json!({
"properties": text_response_properties,
"required": ["error", "_name"],
"type": "object"
}),
)])
.collect();
let tools = Tools {
functions_map: FunctionsMap { functions },
properties: Properties {
function: tools_to_use
.iter()
.map(|tool| FunctionRef {
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
})
.chain(std::iter::once(FunctionRef {
ref_path: "#/$functions/notify_error".to_string(),
}))
.collect(),
},
};
return Ok(Some(tools));
}
// Err(InferError::ToolError("No tools provided".to_string()))
Ok(None)
}
}
/// Type alias for generation responses
pub(crate) type GenerateStreamResponse = (
OwnedSemaphorePermit,
u32, // input_length
UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
);
#[derive(Debug)]
pub(crate) struct GeneratedText {
pub(crate) text: String,
pub(crate) generated_tokens: u32,
pub(crate) finish_reason: FinishReason,
pub(crate) seed: Option<u64>,
}
#[derive(Debug)]
pub(crate) enum InferStreamResponse {
// Optional first message
Prefill(Vec<PrefillToken>),
// Intermediate messages
Intermediate {
token: Token,
top_tokens: Vec<Token>,
},
// Last message
End {
token: Token,
top_tokens: Vec<Token>,
generated_text: GeneratedText,
start: Instant,
queued: Instant,
},
}
#[derive(Debug)]
pub(crate) struct InferResponse {
/// input_length is the input as perceived by the rust tokenizer in the
/// validation pathway. It is redundant with prefill.len() but prefill
/// has data only if the user asked for it. This will always be filled.
pub(crate) _input_length: u32,
pub(crate) prefill: Vec<PrefillToken>,
pub(crate) tokens: Vec<Token>,
pub(crate) generated_text: GeneratedText,
pub(crate) queued: Instant,
pub(crate) start: Instant,
pub(crate) top_tokens: Vec<Vec<Token>>,
}
#[derive(Debug, Error)]
pub enum InferError {
#[error("Request failed during generation: {0}")]
GenerationError(String),
#[error("Model is overloaded")]
Overloaded(#[from] TryAcquireError),
#[error("Input validation error: {0}")]
ValidationError(#[from] ValidationError),
#[error("Incomplete generation")]
IncompleteGeneration,
#[error("Template error: {0}")]
TemplateError(#[from] minijinja::Error),
#[error("Tool error: {0}")]
ToolError(String),
}
impl InferError {
pub(crate) fn error_type(&self) -> &str {
match self {
InferError::GenerationError(_) => "generation",
InferError::Overloaded(_) => "overloaded",
InferError::ValidationError(_) => "validation",
InferError::IncompleteGeneration => "incomplete_generation",
InferError::TemplateError(_) => "template_error",
InferError::ToolError(_) => "tool_error",
}
}
}

View File

@ -0,0 +1,4 @@
mod queue;
mod scheduler;
pub(crate) use scheduler::SchedulerV2;

View File

@ -1,10 +1,14 @@
use crate::infer::InferError; use crate::infer::{InferError, InferStreamResponse};
use crate::infer::InferStreamResponse; use crate::validation::{
use crate::validation::ValidGenerateRequest; ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
};
use nohash_hasher::{BuildNoHashHasher, IntMap}; use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::min; use std::cmp::min;
use std::collections::VecDeque; use std::collections::VecDeque;
use text_generation_client::{Batch, Request}; use text_generation_client::v2::{
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use text_generation_client::ChunksToString;
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant; use tokio::time::Instant;
use tracing::{info_span, instrument, Span}; use tracing::{info_span, instrument, Span};
@ -55,7 +59,6 @@ impl Queue {
Self { queue_sender } Self { queue_sender }
} }
/// Append an entry to the queue
#[instrument(skip_all)] #[instrument(skip_all)]
pub(crate) fn append(&self, entry: Entry) { pub(crate) fn append(&self, entry: Entry) {
// Send append command to the background task managing the state // Send append command to the background task managing the state
@ -278,10 +281,14 @@ impl State {
batch_requests.push(Request { batch_requests.push(Request {
id, id,
prefill_logprobs: entry.request.decoder_input_details, prefill_logprobs: entry.request.decoder_input_details,
inputs: entry.request.inputs.clone(), inputs: entry.request.inputs.chunks_to_string(),
truncate: entry.request.truncate, truncate: entry.request.truncate,
parameters: Some(entry.request.parameters.clone()), parameters: Some(NextTokenChooserParameters::from(
stopping_parameters: Some(entry.request.stopping_parameters.clone()), entry.request.parameters.clone(),
)),
stopping_parameters: Some(StoppingCriteriaParameters::from(
entry.request.stopping_parameters.clone(),
)),
top_n_tokens: entry.request.top_n_tokens, top_n_tokens: entry.request.top_n_tokens,
}); });
// Set batch_time // Set batch_time
@ -297,7 +304,7 @@ impl State {
// Empty batch // Empty batch
if batch_requests.is_empty() { if batch_requests.is_empty() {
tracing::debug!("Filterered out all entries"); tracing::debug!("Filtered out all entries");
return None; return None;
} }
@ -350,12 +357,46 @@ enum QueueCommand {
}, },
} }
impl From<ValidParameters> for NextTokenChooserParameters {
fn from(value: ValidParameters) -> Self {
let (grammar, grammar_type) = match value.grammar {
None => (String::new(), GrammarType::None),
Some(grammar) => match grammar {
ValidGrammar::Json(grammar_string) => (grammar_string, GrammarType::Json),
ValidGrammar::Regex(grammar_string) => (grammar_string, GrammarType::Regex),
},
};
Self {
temperature: value.temperature,
top_k: value.top_k,
top_p: value.top_p,
typical_p: value.typical_p,
do_sample: value.do_sample,
seed: value.seed,
repetition_penalty: value.repetition_penalty,
frequency_penalty: value.frequency_penalty,
watermark: value.watermark,
grammar,
grammar_type: grammar_type.into(),
}
}
}
impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
fn from(value: ValidStoppingParameters) -> Self {
Self {
max_new_tokens: value.max_new_tokens,
stop_sequences: value.stop_sequences,
ignore_eos_token: value.ignore_eos_token,
}
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use text_generation_client::{
GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
};
use tracing::info_span; use tracing::info_span;
fn default_entry() -> ( fn default_entry() -> (
@ -366,11 +407,11 @@ mod tests {
let entry = Entry { let entry = Entry {
request: ValidGenerateRequest { request: ValidGenerateRequest {
inputs: String::new(), inputs: vec![],
input_length: 0, input_length: 0,
truncate: 0, truncate: 0,
decoder_input_details: false, decoder_input_details: false,
parameters: NextTokenChooserParameters { parameters: ValidParameters {
temperature: 0.0, temperature: 0.0,
top_k: 0, top_k: 0,
top_p: 0.0, top_p: 0.0,
@ -380,10 +421,9 @@ mod tests {
repetition_penalty: 0.0, repetition_penalty: 0.0,
frequency_penalty: 0.0, frequency_penalty: 0.0,
watermark: false, watermark: false,
grammar: String::new(), grammar: None,
grammar_type: ProtoGrammarType::None as i32,
}, },
stopping_parameters: StoppingCriteriaParameters { stopping_parameters: ValidStoppingParameters {
ignore_eos_token: false, ignore_eos_token: false,
max_new_tokens: 1, max_new_tokens: 1,
stop_sequences: vec![], stop_sequences: vec![],

View File

@ -1,78 +1,46 @@
/// Batching and inference logic /// Batching and inference logic
use crate::validation::{Validation, ValidationError}; use crate::infer::v2::queue::{Entry, Queue};
use crate::{ use crate::infer::{
ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler,
HubTokenizerConfig, Message, MessageChunk, PrefillToken, Queue, Text, TextMessage, Token,
}; };
use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools}; use crate::validation::ValidGenerateRequest;
use futures::future::try_join_all; use crate::{FinishReason, PrefillToken, Token};
use minijinja::{Environment, ErrorKind, Template};
use nohash_hasher::IntMap; use nohash_hasher::IntMap;
use serde_json::{json, Map, Value};
use std::collections::HashMap;
use std::sync::{ use std::sync::{
atomic::{AtomicBool, Ordering}, atomic::{AtomicBool, Ordering},
Arc, Arc,
}; };
use text_generation_client::{ use text_generation_client::v2::{Batch, CachedBatch, Generation, ShardedClient};
Batch, CachedBatch, ClientError, GeneratedText, Generation, ShardedClient, Tokens, use text_generation_client::ClientError;
};
use thiserror::Error;
use tokio::sync::mpsc::error::SendError; use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError}; use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit};
use tokio::time::Instant; use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;
use tracing::{info_span, instrument, Instrument, Span}; use tracing::{info_span, instrument, Instrument, Span};
/// Inference struct pub(crate) struct SchedulerV2 {
#[derive(Clone)]
pub struct Infer {
/// Validation
validation: Validation,
/// Request queue /// Request queue
queue: Queue, queue: Queue,
/// Shared state /// Notify batcher on queue appends
shared: Arc<Shared>, batching_task_notifier: Arc<Notify>,
/// Chat template
chat_template: Option<ChatTemplate>,
/// Inference limit
limit_concurrent_requests: Arc<Semaphore>,
} }
/// Infer shared state impl SchedulerV2 {
struct Shared {
/// Batching background Tokio task notifier
batching_task: Notify,
}
/// Raise a exception (custom function) used in the chat templates
fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
}
impl Infer {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub(crate) fn new( pub(crate) fn new(
client: ShardedClient, client: ShardedClient,
validation: Validation,
waiting_served_ratio: f32, waiting_served_ratio: f32,
max_batch_prefill_tokens: u32, max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32, max_batch_total_tokens: u32,
max_waiting_tokens: usize, max_waiting_tokens: usize,
max_batch_size: Option<usize>, max_batch_size: Option<usize>,
max_concurrent_requests: usize,
requires_padding: bool, requires_padding: bool,
window_size: Option<u32>, window_size: Option<u32>,
speculate: u32, speculate: u32,
generation_health: Arc<AtomicBool>, generation_health: Arc<AtomicBool>,
tokenizer_config: HubTokenizerConfig,
) -> Self { ) -> Self {
// Infer shared state
let queue = Queue::new(requires_padding, 16, window_size, speculate); let queue = Queue::new(requires_padding, 16, window_size, speculate);
let shared = Arc::new(Shared { let batching_task_notifier = Arc::new(Notify::new());
batching_task: Notify::new(),
});
// Spawn batching background task that contains all the inference logic // Spawn batching background task that contains all the inference logic
tokio::spawn(batching_task( tokio::spawn(batching_task(
@ -83,68 +51,31 @@ impl Infer {
max_waiting_tokens, max_waiting_tokens,
max_batch_size, max_batch_size,
queue.clone(), queue.clone(),
shared.clone(), batching_task_notifier.clone(),
generation_health, generation_health,
)); ));
let chat_template = tokenizer_config
.chat_template
.and_then(|t| match t {
ChatTemplateVersions::Single(template) => Some(template),
ChatTemplateVersions::Multiple(templates) => templates
.into_iter()
.find(|t| t.name == "default")
.map(|t| t.template),
})
.map(|t| {
// .strip() is not supported in minijinja
let t = t.replace(".strip()", " | trim");
ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)
});
// Inference limit with a semaphore
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
Self { Self {
validation,
queue, queue,
shared, batching_task_notifier,
chat_template, }
limit_concurrent_requests: semaphore,
} }
} }
/// Add a new request to the queue and return a stream of InferStreamResponse impl Scheduler for SchedulerV2 {
#[instrument(skip_all)] #[instrument(skip_all)]
pub(crate) async fn generate_stream( fn schedule(
&self, &self,
request: GenerateRequest, request: ValidGenerateRequest,
permit: OwnedSemaphorePermit,
) -> Result<GenerateStreamResponse, InferError> { ) -> Result<GenerateStreamResponse, InferError> {
// Limit concurrent requests by acquiring a permit from the semaphore
let permit = self
.clone()
.limit_concurrent_requests
.try_acquire_owned()
.map_err(|err| {
metrics::increment_counter!("tgi_request_failure", "err" => "overloaded");
tracing::error!("{err}");
err
})?;
// Validate request
let valid_request = self.validation.validate(request).await.map_err(|err| {
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
tracing::error!("{err}");
err
})?;
// MPSC channel to communicate with the background batching task // MPSC channel to communicate with the background batching task
let (response_tx, response_rx) = mpsc::unbounded_channel(); let (response_tx, response_rx) = mpsc::unbounded_channel();
let input_length = valid_request.input_length; let input_length = request.input_length;
// Append the request to the queue // Append the request to the queue
self.queue.append(Entry { self.queue.append(Entry {
request: valid_request, request,
response_tx, response_tx,
span: Span::current(), span: Span::current(),
temp_span: None, temp_span: None,
@ -154,7 +85,7 @@ impl Infer {
// Notify the background task that we have a new entry in the queue that needs // Notify the background task that we have a new entry in the queue that needs
// to be batched // to be batched
self.shared.batching_task.notify_one(); self.batching_task_notifier.notify_one();
// Return stream // Return stream
Ok(( Ok((
@ -163,343 +94,6 @@ impl Infer {
UnboundedReceiverStream::new(response_rx), UnboundedReceiverStream::new(response_rx),
)) ))
} }
/// Tokenizer the input
#[instrument(skip_all)]
pub(crate) async fn tokenize(
&self,
request: GenerateRequest,
) -> Result<Option<tokenizers::Encoding>, InferError> {
// Tokenize request
let inputs = request.inputs;
let truncate = request.parameters.truncate;
let encoding = self
.validation
.tokenize(inputs, truncate)
.await
.map_err(|err| {
tracing::error!("Tokenization {err}");
err
})?;
// Return Encoding
Ok(encoding.map(|(encoding, _)| encoding))
}
/// Apply the chat template to the chat request
#[instrument(skip_all)]
pub(crate) fn apply_chat_template(
&self,
messages: Vec<Message>,
grammar_with_prompt: Option<(GrammarType, String)>,
) -> Result<String, InferError> {
self.chat_template
.as_ref()
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
.apply(messages, grammar_with_prompt)
.map_err(|e| {
metrics::increment_counter!("tgi_request_failure", "err" => "template");
tracing::error!("{e}");
e
})
}
/// Add a new request to the queue and return a InferResponse
#[instrument(skip_all)]
pub(crate) async fn generate(
&self,
request: GenerateRequest,
) -> Result<InferResponse, InferError> {
let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
// Create stream and keep semaphore permit as long as generate lives
let (_permit, _input_length, mut stream) = self.generate_stream(request).await?;
// Return values
let mut result_prefill = Vec::new();
let mut result_tokens = Vec::new();
let mut result_top_tokens = Vec::new();
let mut result_generated_text = None;
let mut result_start = None;
let mut result_queued = None;
// Iterate on stream
while let Some(response) = stream.next().await {
match response? {
// Add prefill tokens
InferStreamResponse::Prefill(tokens) => {
// Create Token objects
// We do that here instead of in the Python code as Rust for loops are faster
result_prefill = tokens
.ids
.into_iter()
.zip(tokens.logprobs.into_iter())
.zip(tokens.texts.into_iter())
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
.collect();
}
// Push last token
InferStreamResponse::Intermediate { token, top_tokens } => {
result_tokens.push(token);
result_top_tokens.push(top_tokens);
}
// Final message
// Set return values
InferStreamResponse::End {
token,
generated_text,
start,
queued,
top_tokens,
} => {
result_tokens.push(token);
result_top_tokens.push(top_tokens);
result_generated_text = Some(generated_text);
result_start = Some(start);
result_queued = Some(queued)
}
}
}
// Check that we received a `InferStreamResponse::End` message
if let (Some(generated_text), Some(queued), Some(start)) =
(result_generated_text, result_queued, result_start)
{
Ok(InferResponse {
prefill: result_prefill,
_input_length,
tokens: result_tokens,
generated_text,
queued,
start,
top_tokens: if use_top_tokens {
result_top_tokens
} else {
Vec::new()
},
})
} else {
let err = InferError::IncompleteGeneration;
metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
tracing::error!("{err}");
Err(err)
}
}
/// Add best_of new requests to the queue and return a InferResponse of the sequence with
/// the highest log probability per token
#[instrument(skip(self, request))]
pub(crate) async fn generate_best_of(
&self,
request: GenerateRequest,
best_of: usize,
) -> Result<(InferResponse, Vec<InferResponse>), InferError> {
// validate best_of parameter separately
let best_of = self.validation.validate_best_of(best_of)?;
// create multiple generate requests
let mut infer_responses: Vec<InferResponse> =
try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?;
// get the sequence with the highest log probability per token
let mut max_index = 0;
let mut max_logprob: f32 = f32::MIN;
for (i, response) in infer_responses.iter().enumerate() {
// mean logprobs of the generated tokens
let sequence_logprob = response
.tokens
.iter()
.map(|token| token.logprob)
.sum::<f32>()
/ response.tokens.len() as f32;
// set best sequence
if sequence_logprob > max_logprob {
max_index = i;
max_logprob = sequence_logprob;
}
}
let best_response = infer_responses.remove(max_index);
Ok((best_response, infer_responses))
}
}
#[derive(Clone)]
struct ChatTemplate {
template: Template<'static, 'static>,
bos_token: Option<String>,
eos_token: Option<String>,
use_default_tool_template: bool,
}
impl ChatTemplate {
fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
let mut env = Box::new(Environment::new());
let template_str = template.into_boxed_str();
env.add_function("raise_exception", raise_exception);
// check if contains the tools variable within the template
let use_default_tool_template =
!template_str.as_ref().replace(' ', "").contains("{{tools}}");
// leaking env and template_str as read-only, static resources for performance.
let template = Box::leak(env)
.template_from_str(Box::leak(template_str))
.unwrap();
Self {
template,
bos_token,
eos_token,
use_default_tool_template,
}
}
fn apply(
&self,
mut messages: Vec<Message>,
grammar_with_prompt: Option<(GrammarType, String)>,
) -> Result<String, InferError> {
if self.use_default_tool_template {
if let Some(last_message) = messages.last_mut() {
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
last_message.content.push(MessageChunk::Text(Text {
text: format!("\n---\n{}\n{}", tool_prompt, tools),
}));
}
}
}
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
self.template
.render(ChatTemplateInputs {
messages,
bos_token: self.bos_token.as_deref(),
eos_token: self.eos_token.as_deref(),
add_generation_prompt: true,
tools: None,
tools_prompt: None,
})
.map_err(InferError::TemplateError)
}
}
pub struct ToolGrammar {}
impl ToolGrammar {
pub fn apply(
tools: Option<Vec<Tool>>,
tool_choice: Option<ToolType>,
) -> Result<Option<Tools>, InferError> {
if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) {
// let tool_prompt = tool_prompt.unwrap_or_default();
let tools_to_use = match tool_choice {
ToolType::FunctionName(name) => {
vec![req_tools
.iter()
.find(|tool| tool.function.name == *name)
.unwrap_or_else(|| panic!("Tool with name {} not found", name))
.clone()]
}
ToolType::OneOf => req_tools.to_owned(),
};
// adds the error notification function for LLM feedback if required
let mut text_response_properties = Map::new();
text_response_properties.insert(
"error".to_string(),
serde_json::json!({
"type": "string",
"description": "The error or issue to notify"
}),
);
text_response_properties.insert(
"_name".to_string(),
serde_json::json!({
"type": "string",
"const": "notify_error"
}),
);
let functions: HashMap<String, serde_json::Value> = tools_to_use
.iter()
.map(|tool| {
let func = tool.function.clone();
// Clone the existing parameters, which are expected to be a JSON object
let mut params = if let Value::Object(params) = &func.arguments {
params.clone()
} else {
Map::new()
};
// Insert the function's description at the top level, outside of properties
params.insert(
"description".to_string(),
Value::String(func.description.clone().unwrap_or_default()),
);
// Ensure 'properties' exists and is an object
let properties = params
.entry("properties".to_string())
.or_insert_with(|| json!({}))
.as_object_mut()
.unwrap();
// Insert the constant for the function name inside 'properties'
properties.insert(
"_name".to_string(),
json!({
"type": "string",
"const": func.name.clone(),
// "description": "The name of the function"
}),
);
// Check if 'required' exists, and it is an array. If not, create an empty array.
let required = params
.entry("required".to_string())
.or_insert_with(|| json!([]))
.as_array_mut()
.unwrap();
// Add 'name' to the 'required' array if it is not already present
if !required.iter().any(|r| r == "_name") {
required.push(json!("_name"));
}
(func.name, Value::Object(params))
})
.chain([(
"notify_error".to_string(),
serde_json::json!({
"properties": text_response_properties,
"required": ["error", "_name"],
"type": "object"
}),
)])
.collect();
let tools = Tools {
functions_map: FunctionsMap { functions },
properties: Properties {
function: tools_to_use
.iter()
.map(|tool| FunctionRef {
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
})
.chain(std::iter::once(FunctionRef {
ref_path: "#/$functions/notify_error".to_string(),
}))
.collect(),
},
};
return Ok(Some(tools));
}
// Err(InferError::ToolError("No tools provided".to_string()))
Ok(None)
}
} }
/// Batching logic /// Batching logic
@ -507,7 +101,7 @@ impl ToolGrammar {
/// ///
/// Batches requests and sends them to the inference server /// Batches requests and sends them to the inference server
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
async fn batching_task( pub(crate) async fn batching_task(
mut client: ShardedClient, mut client: ShardedClient,
waiting_served_ratio: f32, waiting_served_ratio: f32,
max_batch_prefill_tokens: u32, max_batch_prefill_tokens: u32,
@ -515,13 +109,13 @@ async fn batching_task(
max_waiting_tokens: usize, max_waiting_tokens: usize,
max_batch_size: Option<usize>, max_batch_size: Option<usize>,
queue: Queue, queue: Queue,
shared: Arc<Shared>, notifier: Arc<Notify>,
generation_health: Arc<AtomicBool>, generation_health: Arc<AtomicBool>,
) { ) {
// Infinite loop // Infinite loop
loop { loop {
// Wait for a notification from the Infer struct // Wait for a notification from the Infer struct
shared.batching_task.notified().await; notifier.notified().await;
// Get the next batch from the queue // Get the next batch from the queue
// This batch might be smaller than the maximum batch size if there are not enough requests // This batch might be smaller than the maximum batch size if there are not enough requests
@ -787,6 +381,16 @@ fn send_responses(
let mut stopped = false; let mut stopped = false;
if let Some(prefill_tokens) = generation.prefill_tokens { if let Some(prefill_tokens) = generation.prefill_tokens {
// Create Token objects
// We do that here instead of in the Python code as Rust for loops are faster
let prefill_tokens = prefill_tokens
.ids
.into_iter()
.zip(prefill_tokens.logprobs)
.zip(prefill_tokens.texts)
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
.collect();
// Send message // Send message
entry entry
.response_tx .response_tx
@ -837,7 +441,7 @@ fn send_responses(
entry.response_tx.send(Ok(InferStreamResponse::End { entry.response_tx.send(Ok(InferStreamResponse::End {
token, token,
top_tokens, top_tokens,
generated_text: generated_text.clone(), generated_text: GeneratedText::from(generated_text.clone()),
queued: entry.queue_time, queued: entry.queue_time,
start: entry.batch_time.unwrap(), start: entry.batch_time.unwrap(),
}))?; }))?;
@ -872,64 +476,21 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
}); });
} }
#[derive(Debug)] impl From<text_generation_client::v2::GeneratedText> for GeneratedText {
pub(crate) enum InferStreamResponse { fn from(value: text_generation_client::v2::GeneratedText) -> Self {
// Optional first message let v2_finish_reason =
Prefill(Tokens), text_generation_client::v2::FinishReason::try_from(value.finish_reason).unwrap();
// Intermediate messages let finish_reason = match v2_finish_reason {
Intermediate { text_generation_client::v2::FinishReason::Length => FinishReason::Length,
token: Token, text_generation_client::v2::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
top_tokens: Vec<Token>, text_generation_client::v2::FinishReason::StopSequence => FinishReason::StopSequence,
}, };
// Last message
End {
token: Token,
top_tokens: Vec<Token>,
generated_text: GeneratedText,
start: Instant,
queued: Instant,
},
}
#[derive(Debug)] Self {
pub(crate) struct InferResponse { text: value.text,
/// input_length is the input as perceived by the rust tokenizer in the generated_tokens: value.generated_tokens,
/// validation pathway. It is redundant with prefill.len() but prefill finish_reason,
/// has data only if the user asked for it. This will always be filled. seed: value.seed,
pub(crate) _input_length: u32,
pub(crate) prefill: Vec<PrefillToken>,
pub(crate) tokens: Vec<Token>,
pub(crate) generated_text: GeneratedText,
pub(crate) queued: Instant,
pub(crate) start: Instant,
pub(crate) top_tokens: Vec<Vec<Token>>,
}
#[derive(Debug, Error)]
pub enum InferError {
#[error("Request failed during generation: {0}")]
GenerationError(String),
#[error("Model is overloaded")]
Overloaded(#[from] TryAcquireError),
#[error("Input validation error: {0}")]
ValidationError(#[from] ValidationError),
#[error("Incomplete generation")]
IncompleteGeneration,
#[error("Template error: {0}")]
TemplateError(#[from] minijinja::Error),
#[error("Tool error: {0}")]
ToolError(String),
}
impl InferError {
pub(crate) fn error_type(&self) -> &str {
match self {
InferError::GenerationError(_) => "generation",
InferError::Overloaded(_) => "overloaded",
InferError::ValidationError(_) => "validation",
InferError::IncompleteGeneration => "incomplete_generation",
InferError::TemplateError(_) => "template_error",
InferError::ToolError(_) => "tool_error",
} }
} }
} }

View File

@ -0,0 +1,136 @@
use std::cmp::min;
use tokio::sync::{mpsc, oneshot};
#[derive(Debug, Clone)]
pub(crate) struct BlockAllocation {
pub blocks: Vec<u32>,
pub slots: Vec<u32>,
block_allocator: BlockAllocator,
}
impl Drop for BlockAllocation {
fn drop(&mut self) {
self.block_allocator.free(self.blocks.clone())
}
}
#[derive(Debug, Clone)]
pub(crate) struct BlockAllocator {
/// Channel to communicate with the background task
block_allocator: mpsc::UnboundedSender<BlockAllocatorCommand>,
}
impl BlockAllocator {
pub(crate) fn new(
max_batch_total_tokens: u32,
block_size: u32,
window_size: Option<u32>,
) -> Self {
// Create channel
let (sender, receiver) = mpsc::unbounded_channel();
// Launch background queue task
tokio::spawn(block_allocator_task(
max_batch_total_tokens / block_size,
block_size,
window_size,
receiver,
));
Self {
block_allocator: sender,
}
}
pub(crate) async fn allocate(&self, tokens: u32) -> Option<BlockAllocation> {
let (response_sender, response_receiver) = oneshot::channel();
self.block_allocator
.send(BlockAllocatorCommand::Allocate {
tokens,
response_sender,
})
.unwrap();
response_receiver
.await
.unwrap()
.map(|(blocks, slots)| BlockAllocation {
blocks,
slots,
block_allocator: self.clone(),
})
}
pub(crate) fn free(&self, blocks: Vec<u32>) {
self.block_allocator
.send(BlockAllocatorCommand::Free { blocks })
.unwrap();
}
}
async fn block_allocator_task(
blocks: u32,
block_size: u32,
window_size: Option<u32>,
mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
) {
// Block 0 is reserved for health checks
let mut free_blocks: Vec<u32> = (1..blocks).collect();
while let Some(cmd) = receiver.recv().await {
match cmd {
BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks),
BlockAllocatorCommand::Allocate {
tokens,
response_sender,
} => {
// Apply window size
let (required_blocks, repeats) = {
let (tokens, repeats) = match window_size {
None => (tokens, 1),
Some(window_size) => {
let repeats = (tokens + window_size - 1) / window_size;
let tokens = min(tokens, window_size);
(tokens, repeats as usize)
}
};
// Pad to a multiple of block size
let required_blocks = (tokens + block_size - 1) / block_size;
(required_blocks, repeats)
};
let tokens = tokens as usize;
let allocation = if required_blocks > free_blocks.len() as u32 {
None
} else {
let blocks =
free_blocks.split_off(free_blocks.len() - required_blocks as usize);
let mut slots = Vec::with_capacity(
(required_blocks * block_size * repeats as u32) as usize,
);
'slots: for block_id in blocks.repeat(repeats).iter() {
for s in (block_id * block_size)..((block_id + 1) * block_size) {
slots.push(s);
if slots.len() == tokens {
break 'slots;
}
}
}
Some((blocks, slots))
};
response_sender.send(allocation).unwrap();
}
}
}
}
#[derive(Debug)]
enum BlockAllocatorCommand {
Free {
blocks: Vec<u32>,
},
Allocate {
tokens: u32,
response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>)>>,
},
}

View File

@ -0,0 +1,5 @@
mod block_allocator;
mod queue;
mod scheduler;
pub(crate) use scheduler::SchedulerV3;

View File

@ -0,0 +1,730 @@
use crate::infer::v3::block_allocator::{BlockAllocation, BlockAllocator};
use crate::infer::InferError;
use crate::infer::InferStreamResponse;
use crate::validation::{
ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
};
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::{max, min};
use std::collections::VecDeque;
use text_generation_client::v3::{
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use text_generation_client::ChunksToString;
use text_generation_client::Input;
use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant;
use tracing::{info_span, instrument, Instrument, Span};
/// Queue entry
#[derive(Debug)]
pub(crate) struct Entry {
/// Request
pub request: ValidGenerateRequest,
/// Response sender to communicate between the Infer struct and the batching_task
pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>,
/// Span that will live as long as entry
pub span: Span,
/// Temporary span used as a guard when logging inference, wait times...
pub temp_span: Option<Span>,
/// Instant when this entry was queued
pub queue_time: Instant,
/// Instant when this entry was added to a batch
pub batch_time: Option<Instant>,
/// Block Allocation
pub block_allocation: Option<BlockAllocation>,
}
/// Request Queue
#[derive(Debug, Clone)]
pub(crate) struct Queue {
/// Channel to communicate with the background queue task
queue_sender: mpsc::UnboundedSender<QueueCommand>,
}
impl Queue {
pub(crate) fn new(
requires_padding: bool,
block_size: u32,
window_size: Option<u32>,
speculate: u32,
max_batch_total_tokens: u32,
) -> Self {
// Create channel
let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
// Launch background queue task
tokio::spawn(queue_task(
requires_padding,
block_size,
window_size,
speculate,
max_batch_total_tokens,
queue_receiver,
));
Self { queue_sender }
}
/// Append an entry to the queue
#[instrument(skip_all)]
pub(crate) fn append(&self, entry: Entry) {
// Send append command to the background task managing the state
// Unwrap is safe here
self.queue_sender
.send(QueueCommand::Append(Box::new(entry), Span::current()))
.unwrap();
}
// Get the next batch
#[instrument(skip(self))]
pub(crate) async fn next_batch(
&self,
min_size: Option<usize>,
max_size: Option<usize>,
prefill_token_budget: u32,
token_budget: u32,
) -> Option<NextBatch> {
// Create response channel
let (response_sender, response_receiver) = oneshot::channel();
// Send next batch command to the background task managing the state
// Unwrap is safe here
self.queue_sender
.send(QueueCommand::NextBatch {
min_size,
max_size,
prefill_token_budget,
token_budget,
response_sender,
span: Span::current(),
})
.unwrap();
// Await on response channel
// Unwrap is safe here
response_receiver.await.unwrap()
}
}
// Background task responsible of the queue state
async fn queue_task(
requires_padding: bool,
block_size: u32,
window_size: Option<u32>,
speculate: u32,
max_batch_total_tokens: u32,
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
) {
let mut state = State::new(
requires_padding,
block_size,
window_size,
speculate,
max_batch_total_tokens,
);
while let Some(cmd) = receiver.recv().await {
match cmd {
QueueCommand::Append(entry, span) => {
span.in_scope(|| state.append(*entry));
metrics::increment_gauge!("tgi_queue_size", 1.0);
}
QueueCommand::NextBatch {
min_size,
max_size,
prefill_token_budget,
token_budget,
response_sender,
span,
} => {
let next_batch = state
.next_batch(min_size, max_size, prefill_token_budget, token_budget)
.instrument(span)
.await;
response_sender.send(next_batch).unwrap();
metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
}
}
}
}
/// Queue State
#[derive(Debug)]
struct State {
/// Queue entries organized in a Vec
entries: VecDeque<(u64, Entry)>,
/// Id of the next entry
next_id: u64,
/// Id of the next batch
next_batch_id: u64,
/// Paged Attention block size
block_size: u32,
/// Sliding window
window_size: Option<u32>,
/// Speculation amount
speculate: u32,
/// Paged Attention Block Allocation
block_allocator: Option<BlockAllocator>,
}
impl State {
fn new(
requires_padding: bool,
block_size: u32,
window_size: Option<u32>,
speculate: u32,
max_batch_total_tokens: u32,
) -> Self {
let block_allocator = (!requires_padding)
.then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size));
Self {
entries: VecDeque::with_capacity(128),
next_id: 0,
next_batch_id: 0,
block_size,
window_size,
speculate,
block_allocator,
}
}
/// Append an entry to the queue
fn append(&mut self, mut entry: Entry) {
// Create a span that will live as long as the entry is in the queue waiting to be batched
let queue_span = info_span!(parent: &entry.span, "queued");
entry.temp_span = Some(queue_span);
// Push entry in the queue
self.entries.push_back((self.next_id, entry));
self.next_id += 1;
}
// Get the next batch
async fn next_batch(
&mut self,
min_size: Option<usize>,
max_size: Option<usize>,
prefill_token_budget: u32,
token_budget: u32,
) -> Option<NextBatch> {
if self.entries.is_empty() {
tracing::debug!("No queue");
return None;
}
// Check if we have enough entries
if let Some(min_size) = min_size {
if self.entries.len() < min_size {
tracing::debug!("Not enough entries");
return None;
}
}
// Pad prefill_token_budget to be a multiple of block size
let prefill_token_budget =
((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
// Create span for this batch to add context to inference calls
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
next_batch_span.follows_from(&Span::current());
let mut batch_requests = Vec::with_capacity(self.entries.len());
let mut batch_entries =
IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
let mut max_input_length = 0;
let mut prefill_tokens: u32 = 0;
let mut decode_tokens: u32 = 0;
let mut max_blocks = 0;
// Pop entries starting from the front of the queue
'entry_loop: while let Some((id, mut entry)) = self.entries.pop_front() {
// Filter entries where the response receiver was dropped (== entries where the request
// was dropped by the client)
if entry.response_tx.is_closed() {
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
tracing::debug!("Dropping entry");
continue;
}
let block_allocation = match &self.block_allocator {
None => {
// We pad to max input length in the Python shards
// We need to take these padding tokens into the equation
max_input_length = max_input_length.max(entry.request.input_length);
prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length;
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
let total_tokens = prefill_tokens + decode_tokens + self.speculate;
if prefill_tokens > prefill_token_budget || total_tokens > token_budget {
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
self.entries.push_front((id, entry));
break 'entry_loop;
}
None
}
Some(block_allocator) => {
prefill_tokens += entry.request.input_length;
let max_new_tokens = match self.window_size {
None => entry.request.stopping_parameters.max_new_tokens,
Some(window_size) => min(
window_size.saturating_sub(entry.request.input_length),
entry.request.stopping_parameters.max_new_tokens,
),
};
decode_tokens += max_new_tokens;
if prefill_tokens > prefill_token_budget
|| (prefill_tokens + decode_tokens + self.speculate) > token_budget
{
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
self.entries.push_front((id, entry));
break;
}
let tokens = entry.request.input_length
+ entry.request.stopping_parameters.max_new_tokens
+ self.speculate
- 1;
match block_allocator.allocate(tokens).await {
None => {
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: not enough free blocks");
self.entries.push_front((id, entry));
break 'entry_loop;
}
Some(block_allocation) => {
tracing::debug!("Allocation: {block_allocation:?}");
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
Some(block_allocation)
}
}
}
};
tracing::debug!("Accepting entry");
// Create a new span to link the batch back to this entry
let entry_batch_span = info_span!(parent: &entry.span, "infer");
// Add relationships
next_batch_span.follows_from(&entry_batch_span);
entry_batch_span.follows_from(&next_batch_span);
// Update entry
entry.temp_span = Some(entry_batch_span);
let (blocks, slots) = match &block_allocation {
None => (Vec::new(), Vec::new()),
Some(block_allocation) => (
block_allocation.blocks.clone(),
block_allocation.slots.clone(),
),
};
entry.block_allocation = block_allocation;
batch_requests.push(Request {
id,
prefill_logprobs: entry.request.decoder_input_details,
input_chunks: Some(Input {
chunks: entry.request.inputs.clone(),
}),
inputs: entry.request.inputs.chunks_to_string(),
truncate: entry.request.truncate,
parameters: Some(NextTokenChooserParameters::from(
entry.request.parameters.clone(),
)),
stopping_parameters: Some(StoppingCriteriaParameters::from(
entry.request.stopping_parameters.clone(),
)),
top_n_tokens: entry.request.top_n_tokens,
blocks,
slots,
});
// Set batch_time
entry.batch_time = Some(Instant::now());
// Insert in batch_entries IntMap
batch_entries.insert(id, entry);
// Check if max_size
if Some(batch_requests.len()) == max_size {
break;
}
}
// Empty batch
if batch_requests.is_empty() {
tracing::debug!("Filterered out all entries");
return None;
}
// Check if our batch is big enough
if let Some(min_size) = min_size {
// Batch is too small
if batch_requests.len() < min_size {
// Add back entries to the queue in the correct order
for r in batch_requests.into_iter().rev() {
let id = r.id;
let entry = batch_entries.remove(&id).unwrap();
self.entries.push_front((id, entry));
}
return None;
}
}
// Final batch size
let size = batch_requests.len() as u32;
next_batch_span.record("batch_size", size);
let batch = Batch {
id: self.next_batch_id,
requests: batch_requests,
size,
max_tokens: (prefill_tokens + decode_tokens),
max_blocks,
};
// Increment batch id
self.next_batch_id += 1;
metrics::histogram!("tgi_batch_next_size", batch.size as f64);
Some((batch_entries, batch, next_batch_span))
}
}
type NextBatch = (IntMap<u64, Entry>, Batch, Span);
#[derive(Debug)]
enum QueueCommand {
Append(Box<Entry>, Span),
NextBatch {
min_size: Option<usize>,
max_size: Option<usize>,
prefill_token_budget: u32,
token_budget: u32,
response_sender: oneshot::Sender<Option<NextBatch>>,
span: Span,
},
}
impl From<ValidParameters> for NextTokenChooserParameters {
fn from(value: ValidParameters) -> Self {
let (grammar, grammar_type) = match value.grammar {
None => (String::new(), GrammarType::None),
Some(grammar) => match grammar {
ValidGrammar::Json(grammar_string) => (grammar_string, GrammarType::Json),
ValidGrammar::Regex(grammar_string) => (grammar_string, GrammarType::Regex),
},
};
Self {
temperature: value.temperature,
top_k: value.top_k,
top_p: value.top_p,
typical_p: value.typical_p,
do_sample: value.do_sample,
seed: value.seed,
repetition_penalty: value.repetition_penalty,
frequency_penalty: value.frequency_penalty,
watermark: value.watermark,
grammar,
grammar_type: grammar_type.into(),
}
}
}
impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
fn from(value: ValidStoppingParameters) -> Self {
Self {
max_new_tokens: value.max_new_tokens,
stop_sequences: value.stop_sequences,
ignore_eos_token: value.ignore_eos_token,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use tracing::info_span;
fn default_entry() -> (
Entry,
mpsc::UnboundedReceiver<Result<InferStreamResponse, InferError>>,
) {
let (response_tx, receiver_tx) = mpsc::unbounded_channel();
let entry = Entry {
request: ValidGenerateRequest {
inputs: vec![],
input_length: 0,
truncate: 0,
decoder_input_details: false,
parameters: ValidParameters {
temperature: 0.0,
top_k: 0,
top_p: 0.0,
typical_p: 0.0,
do_sample: false,
seed: 0,
repetition_penalty: 0.0,
frequency_penalty: 0.0,
watermark: false,
grammar: None,
},
stopping_parameters: ValidStoppingParameters {
ignore_eos_token: false,
max_new_tokens: 1,
stop_sequences: vec![],
},
top_n_tokens: 0,
},
response_tx,
span: info_span!("entry"),
temp_span: None,
queue_time: Instant::now(),
batch_time: None,
block_allocation: None,
};
(entry, receiver_tx)
}
#[tokio::test]
async fn test_append() {
let mut state = State::new(false, 1, None, 0, 16);
let (entry, _guard) = default_entry();
assert_eq!(state.next_id, 0);
assert_eq!(state.entries.len(), 0);
state.append(entry);
assert_eq!(state.next_id, 1);
assert_eq!(state.entries.len(), 1);
let (id, _) = state.entries.remove(0).unwrap();
assert_eq!(id, 0);
}
#[tokio::test]
async fn test_next_batch_empty() {
let mut state = State::new(false, 1, None, 0, 16);
assert!(state.next_batch(None, None, 1, 1).await.is_none());
assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
}
#[tokio::test]
async fn test_next_batch_min_size() {
let mut state = State::new(false, 1, None, 0, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
state.append(entry2);
let (entries, batch, _) = state.next_batch(None, None, 2, 2).await.unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&0));
assert!(entries.contains_key(&1));
assert!(entries.get(&0).unwrap().batch_time.is_some());
assert!(entries.get(&1).unwrap().batch_time.is_some());
assert_eq!(batch.id, 0);
assert_eq!(batch.size, 2);
assert_eq!(state.next_id, 2);
assert_eq!(state.entries.len(), 0);
assert_eq!(state.next_batch_id, 1);
let (entry3, _guard3) = default_entry();
state.append(entry3);
assert!(state.next_batch(Some(2), None, 2, 2).await.is_none());
assert_eq!(state.next_id, 3);
assert_eq!(state.entries.len(), 1);
let (id, _) = state.entries.remove(0).unwrap();
assert_eq!(id, 2);
}
#[tokio::test]
async fn test_next_batch_max_size() {
let mut state = State::new(false, 1, None, 0, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
state.append(entry2);
let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).await.unwrap();
assert_eq!(entries.len(), 1);
assert!(entries.contains_key(&0));
assert!(entries.get(&0).unwrap().batch_time.is_some());
assert_eq!(batch.id, 0);
assert_eq!(batch.size, 1);
assert_eq!(state.next_id, 2);
assert_eq!(state.entries.len(), 1);
assert_eq!(state.next_batch_id, 1);
}
#[tokio::test]
async fn test_next_batch_token_budget() {
let mut state = State::new(false, 1, None, 0, 2);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
state.append(entry2);
let (entries, batch, _) = state.next_batch(None, None, 1, 1).await.unwrap();
assert_eq!(entries.len(), 1);
assert!(entries.contains_key(&0));
assert_eq!(batch.id, 0);
assert_eq!(batch.size, 1);
assert_eq!(state.next_id, 2);
assert_eq!(state.entries.len(), 1);
assert_eq!(state.next_batch_id, 1);
let (entry3, _guard3) = default_entry();
state.append(entry3);
let (entries, batch, _) = state.next_batch(None, None, 3, 3).await.unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&1));
assert!(entries.contains_key(&2));
assert_eq!(batch.id, 1);
assert_eq!(batch.size, 2);
assert_eq!(state.next_id, 3);
assert_eq!(state.entries.len(), 0);
assert_eq!(state.next_batch_id, 2);
}
#[tokio::test]
async fn test_queue_append() {
let queue = Queue::new(false, 1, None, 0, 16);
let (entry, _guard) = default_entry();
queue.append(entry);
}
#[tokio::test]
async fn test_queue_next_batch_empty() {
let queue = Queue::new(false, 1, None, 0, 16);
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
}
#[tokio::test]
async fn test_queue_next_batch_min_size() {
let queue = Queue::new(false, 1, None, 0, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
queue.append(entry2);
let (entries, batch, _) = queue.next_batch(None, None, 2, 2).await.unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&0));
assert!(entries.contains_key(&1));
assert!(entries.get(&0).unwrap().batch_time.is_some());
assert!(entries.get(&1).unwrap().batch_time.is_some());
assert_eq!(batch.id, 0);
assert_eq!(batch.size, 2);
let (entry3, _guard3) = default_entry();
queue.append(entry3);
// Not enough requests pending
assert!(queue.next_batch(Some(2), None, 2, 2).await.is_none());
// Not enough token budget
assert!(queue.next_batch(Some(1), None, 0, 0).await.is_none());
// Ok
let (entries2, batch2, _) = queue.next_batch(Some(1), None, 2, 2).await.unwrap();
assert_eq!(entries2.len(), 1);
assert!(entries2.contains_key(&2));
assert!(entries2.get(&2).unwrap().batch_time.is_some());
assert_eq!(batch2.id, 1);
assert_eq!(batch2.size, 1);
}
#[tokio::test]
async fn test_queue_next_batch_max_size() {
let queue = Queue::new(false, 1, None, 0, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
queue.append(entry2);
let (entries, batch, _) = queue.next_batch(None, Some(1), 2, 2).await.unwrap();
assert_eq!(entries.len(), 1);
assert!(entries.contains_key(&0));
assert!(entries.get(&0).unwrap().batch_time.is_some());
assert_eq!(batch.id, 0);
assert_eq!(batch.size, 1);
}
#[tokio::test]
async fn test_queue_next_batch_token_budget() {
let queue = Queue::new(false, 1, None, 0, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
queue.append(entry2);
let (entries, batch, _) = queue.next_batch(None, None, 1, 1).await.unwrap();
assert_eq!(entries.len(), 1);
assert!(entries.contains_key(&0));
assert_eq!(batch.id, 0);
assert_eq!(batch.size, 1);
let (entry3, _guard3) = default_entry();
queue.append(entry3);
let (entries, batch, _) = queue.next_batch(None, None, 3, 3).await.unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&1));
assert!(entries.contains_key(&2));
assert_eq!(batch.id, 1);
assert_eq!(batch.size, 2);
}
#[tokio::test]
async fn test_queue_next_batch_token_speculate() {
let queue = Queue::new(false, 1, None, 2, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
queue.append(entry2);
// Budget of 1 is not enough
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
let (entries, batch, _) = queue.next_batch(None, None, 6, 6).await.unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&0));
assert!(entries.contains_key(&1));
assert_eq!(batch.id, 0);
assert_eq!(batch.size, 2);
}
#[tokio::test]
async fn test_queue_next_batch_dropped_receiver() {
let queue = Queue::new(false, 1, None, 0, 16);
let (entry, _) = default_entry();
queue.append(entry);
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
}
}

File diff suppressed because it is too large Load Diff

247
router/src/kserve.rs Normal file
View File

@ -0,0 +1,247 @@
use crate::{
default_parameters,
server::{generate_internal, ComputeType},
Deserialize, ErrorResponse, GenerateParameters, GenerateRequest, Infer, Serialize, ToSchema,
};
use axum::extract::{Extension, Path};
use axum::response::{IntoResponse, Response};
use axum::Json;
use futures::stream::FuturesUnordered;
use futures::TryStreamExt;
use reqwest::header::HeaderMap;
use reqwest::StatusCode;
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct OutputChunk {
pub name: String,
pub shape: Vec<usize>,
pub datatype: String,
pub data: Vec<u8>,
}
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct InferenceOutput {
pub id: String,
pub outputs: Vec<OutputChunk>,
}
#[derive(Debug, Deserialize, ToSchema)]
pub(crate) struct InferenceRequest {
pub id: String,
#[serde(default = "default_parameters")]
pub parameters: GenerateParameters,
pub inputs: Vec<Input>,
pub outputs: Vec<Output>,
}
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub(crate) struct Input {
pub name: String,
pub shape: Vec<usize>,
pub datatype: String,
pub data: Vec<u8>,
}
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub(crate) struct Output {
pub name: String,
}
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct LiveResponse {
pub live: bool,
}
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct ReadyResponse {
pub live: bool,
}
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct MetadataServerResponse {
pub name: String,
pub version: String,
pub extensions: Vec<String>,
}
// Routes
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/v2/health/live",
responses(
(status = 200, description = "Service is live", body = LiveReponse),
(status = 404, description = "Service not found", body = ErrorResponse,
example = json!({"error": "No response"}))
)
)]
pub async fn kserve_health_live() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let data = LiveResponse { live: true };
Ok((HeaderMap::new(), Json(data)).into_response())
}
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/v2/health/ready",
responses(
(status = 200, description = "Service is ready", body = ReadyResponse),
(status = 404, description = "Service not found", body = ErrorResponse,
example = json!({"error": "No response"}))
)
)]
pub async fn kserve_health_ready() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let data = ReadyResponse { live: true };
Ok((HeaderMap::new(), Json(data)).into_response())
}
#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/v2",
responses(
(status = 200, description = "Metadata retrieved", body = MetadataServerResponse),
(status = 404, description = "Service not found", body = ErrorResponse,
example = json!({"error": "No response"}))
)
)]
pub async fn kerve_server_metadata() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let data = MetadataServerResponse {
name: "text-generation-inference".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
extensions: vec![
"health".to_string(),
"models".to_string(),
"metrics".to_string(),
],
};
Ok((HeaderMap::new(), Json(data)).into_response())
}
#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/v2/models/{model_name}/versions/{model_version}",
responses(
(status = 200, description = "Model version metadata retrieved", body = MetadataServerResponse),
(status = 404, description = "Model or version not found", body = ErrorResponse,
example = json!({"error": "No response"}))
)
)]
pub async fn kserve_model_metadata(
Path((model_name, model_version)): Path<(String, String)>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let data = MetadataServerResponse {
name: model_name,
version: model_version,
extensions: vec!["infer".to_string(), "ready".to_string()],
};
Ok((HeaderMap::new(), Json(data)).into_response())
}
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/v2/models/{model_name}/versions/{model_version}/infer",
request_body = Json<InferenceRequest>,
responses(
(status = 200, description = "Inference executed successfully", body = InferenceOutput),
(status = 404, description = "Model or version not found", body = ErrorResponse,
example = json!({"error": "No response"}))
)
)]
pub async fn kserve_model_infer(
infer: Extension<Infer>,
Extension(compute_type): Extension<ComputeType>,
Json(payload): Json<InferenceRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let id = payload.id.clone();
let str_inputs = payload
.inputs
.iter()
.map(|input| {
std::str::from_utf8(&input.data).map_err(|e| {
(
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: e.to_string(),
error_type: "utf8".to_string(),
}),
)
})
})
.collect::<Result<Vec<_>, _>>()?;
if str_inputs.len() != payload.outputs.len() {
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: "Inputs and outputs length mismatch".to_string(),
error_type: "length mismatch".to_string(),
}),
));
}
let output_chunks = str_inputs
.iter()
.zip(&payload.outputs)
.map(|(str_input, output)| {
let generate_request = GenerateRequest {
inputs: str_input.to_string(),
parameters: payload.parameters.clone(),
};
let infer = infer.clone();
let compute_type = compute_type.clone();
let span = tracing::Span::current();
async move {
generate_internal(infer, compute_type, Json(generate_request), span)
.await
.map(|(_, Json(generation))| {
let generation_as_bytes = generation.generated_text.as_bytes().to_vec();
OutputChunk {
name: output.name.clone(),
shape: vec![1, generation_as_bytes.len()],
datatype: "BYTES".to_string(),
data: generation_as_bytes,
}
})
.map_err(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "Incomplete generation".into(),
error_type: "Incomplete generation".into(),
}),
)
})
}
})
.collect::<FuturesUnordered<_>>()
.try_collect::<Vec<_>>()
.await?;
let inference_output = InferenceOutput {
id: id.clone(),
outputs: output_chunks,
};
Ok((HeaderMap::new(), Json(inference_output)).into_response())
}
#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/v2/models/{model_name}/versions/{model_version}/ready",
responses(
(status = 200, description = "Model version is ready", body = ReadyResponse),
(status = 404, description = "Model or version not found", body = ErrorResponse,
example = json!({"error": "No response"}))
)
)]
pub async fn kserve_model_metadata_ready(
Path((_model_name, _model_version)): Path<(String, String)>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let data = ReadyResponse { live: true };
Ok((HeaderMap::new(), Json(data)).into_response())
}

View File

@ -1,27 +1,17 @@
pub mod config;
mod health;
/// Text Generation Inference Webserver /// Text Generation Inference Webserver
pub mod config;
mod infer; mod infer;
mod queue;
pub mod server; pub mod server;
mod validation; mod validation;
use infer::{Infer, InferError, InferStreamResponse}; #[cfg(feature = "kserve")]
use queue::{Entry, Queue}; mod kserve;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::sync::OwnedSemaphorePermit;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::warn; use tracing::warn;
use utoipa::ToSchema; use utoipa::ToSchema;
use validation::Validation; use validation::Validation;
/// Type alias for generation responses
pub(crate) type GenerateStreamResponse = (
OwnedSemaphorePermit,
u32, // input_length
UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
);
#[derive(Clone, Deserialize, ToSchema)] #[derive(Clone, Deserialize, ToSchema)]
pub(crate) struct VertexInstance { pub(crate) struct VertexInstance {
#[schema(example = "What is Deep Learning?")] #[schema(example = "What is Deep Learning?")]
@ -80,6 +70,20 @@ impl HubTokenizerConfig {
} }
} }
#[derive(Debug, Clone, Deserialize, Default)]
pub struct HubProcessorConfig {
pub chat_template: Option<ChatTemplateVersions>,
pub image_seq_len: usize,
pub processor_class: Option<String>,
}
impl HubProcessorConfig {
pub fn from_file<P: AsRef<std::path::Path>>(filename: P) -> Option<Self> {
let content = std::fs::read_to_string(filename).ok()?;
serde_json::from_str(&content).ok()
}
}
#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)] #[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
#[serde(tag = "type", content = "value")] #[serde(tag = "type", content = "value")]
pub(crate) enum GrammarType { pub(crate) enum GrammarType {
@ -88,6 +92,7 @@ pub(crate) enum GrammarType {
/// JSON Schema is a declarative language that allows to annotate JSON documents /// JSON Schema is a declarative language that allows to annotate JSON documents
/// with types and descriptions. /// with types and descriptions.
#[serde(rename = "json")] #[serde(rename = "json")]
#[serde(alias = "json_object")]
#[schema(example = json ! ({"properties": {"location":{"type": "string"}}}))] #[schema(example = json ! ({"properties": {"location":{"type": "string"}}}))]
Json(serde_json::Value), Json(serde_json::Value),
#[serde(rename = "regex")] #[serde(rename = "regex")]
@ -144,7 +149,7 @@ pub struct Info {
#[schema(example = "4")] #[schema(example = "4")]
pub max_stop_sequences: usize, pub max_stop_sequences: usize,
#[schema(example = "1024")] #[schema(example = "1024")]
pub max_input_length: usize, pub max_input_tokens: usize,
#[schema(example = "2048")] #[schema(example = "2048")]
pub max_total_tokens: usize, pub max_total_tokens: usize,
#[schema(example = "1.2")] #[schema(example = "1.2")]
@ -402,6 +407,11 @@ pub struct CompletionRequest {
#[serde(default)] #[serde(default)]
#[schema(example = "1.0")] #[schema(example = "1.0")]
pub frequency_penalty: Option<f32>, pub frequency_penalty: Option<f32>,
/// Up to 4 sequences where the API will stop generating further tokens.
#[serde(default)]
#[schema(nullable = true, example = "null")]
pub stop: Option<Vec<String>>,
} }
#[derive(Clone, Deserialize, Serialize, ToSchema, Default)] #[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
@ -785,6 +795,13 @@ pub(crate) struct ChatRequest {
#[schema(nullable = true, example = "null")] #[schema(nullable = true, example = "null")]
#[serde(deserialize_with = "deserialize_tool_choice::deserialize")] #[serde(deserialize_with = "deserialize_tool_choice::deserialize")]
pub tool_choice: Option<ToolType>, pub tool_choice: Option<ToolType>,
/// Response format constraints for the generation.
///
/// NOTE: A request can use `response_format` OR `tools` but not both.
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub response_format: Option<GrammarType>,
} }
fn default_tool_prompt() -> Option<String> { fn default_tool_prompt() -> Option<String> {
@ -1068,7 +1085,7 @@ pub struct SimpleToken {
stop: usize, stop: usize,
} }
#[derive(Serialize, ToSchema)] #[derive(Debug, Serialize, ToSchema)]
#[serde(rename_all(serialize = "snake_case"))] #[serde(rename_all(serialize = "snake_case"))]
#[schema(example = "Length")] #[schema(example = "Length")]
pub(crate) enum FinishReason { pub(crate) enum FinishReason {

View File

@ -12,15 +12,14 @@ use std::fs::File;
use std::io::BufReader; use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use text_generation_client::{ClientError, ShardedClient};
use text_generation_router::config::Config; use text_generation_router::config::Config;
use text_generation_router::{server, HubModelInfo, HubTokenizerConfig}; use text_generation_router::{server, HubModelInfo, HubProcessorConfig, HubTokenizerConfig};
use thiserror::Error; use thiserror::Error;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use tower_http::cors::AllowOrigin; use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Layer}; use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
/// App Configuration /// App Configuration
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
@ -206,11 +205,18 @@ async fn main() -> Result<(), RouterError> {
}; };
// Load tokenizer and model info // Load tokenizer and model info
let (tokenizer_filename, config_filename, tokenizer_config_filename, model_info) = match api { let (
tokenizer_filename,
config_filename,
tokenizer_config_filename,
processor_config_filename,
model_info,
) = match api {
Type::None => ( Type::None => (
Some(local_path.join("tokenizer.json")), Some(local_path.join("tokenizer.json")),
Some(local_path.join("config.json")), Some(local_path.join("config.json")),
Some(local_path.join("tokenizer_config.json")), Some(local_path.join("tokenizer_config.json")),
Some(local_path.join("processor_config.json")),
None, None,
), ),
Type::Api(api) => { Type::Api(api) => {
@ -226,6 +232,7 @@ async fn main() -> Result<(), RouterError> {
}; };
let config_filename = api_repo.get("config.json").await.ok(); let config_filename = api_repo.get("config.json").await.ok();
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok(); let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
let model_info = if let Some(model_info) = get_model_info(&api_repo).await { let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
Some(model_info) Some(model_info)
@ -237,6 +244,7 @@ async fn main() -> Result<(), RouterError> {
tokenizer_filename, tokenizer_filename,
config_filename, config_filename,
tokenizer_config_filename, tokenizer_config_filename,
processor_config_filename,
model_info, model_info,
) )
} }
@ -250,6 +258,7 @@ async fn main() -> Result<(), RouterError> {
repo.get("tokenizer.json"), repo.get("tokenizer.json"),
repo.get("config.json"), repo.get("config.json"),
repo.get("tokenizer_config.json"), repo.get("tokenizer_config.json"),
repo.get("processor_config.json"),
None, None,
) )
} }
@ -286,6 +295,10 @@ async fn main() -> Result<(), RouterError> {
HubTokenizerConfig::default() HubTokenizerConfig::default()
}); });
let processor_config = processor_config_filename
.and_then(HubProcessorConfig::from_file)
.unwrap_or_default();
tracing::info!("Using config {config:?}"); tracing::info!("Using config {config:?}");
if tokenizer.is_none() { if tokenizer.is_none() {
tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}"); tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
@ -301,59 +314,6 @@ async fn main() -> Result<(), RouterError> {
Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation", Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
}; };
// Instantiate sharded client from the master unix socket
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
.await
.map_err(RouterError::Connection)?;
// Clear the cache; useful if the webserver rebooted
sharded_client
.clear_cache(None)
.await
.map_err(RouterError::Cache)?;
// Get info from the shard
let shard_info = sharded_client.info().await.map_err(RouterError::Info)?;
// Warmup model
tracing::info!("Warming up model");
let max_supported_batch_total_tokens = match sharded_client
.warmup(
max_input_tokens as u32,
max_batch_prefill_tokens,
max_total_tokens as u32,
max_batch_size,
)
.await
.map_err(RouterError::Warmup)?
{
// Older models do not support automatic max-batch-total-tokens
None => {
let max_batch_total_tokens = max_batch_total_tokens
.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
tracing::warn!("Model does not support automatic max batch total tokens");
max_batch_total_tokens
}
// Flash attention models return their max supported total tokens
Some(max_supported_batch_total_tokens) => {
// Warn if user added his own max-batch-total-tokens as we will ignore it
if max_batch_total_tokens.is_some() {
tracing::warn!(
"`--max-batch-total-tokens` is deprecated for Flash \
Attention models."
);
tracing::warn!(
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
);
}
if max_total_tokens as u32 > max_supported_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_supported_batch_total_tokens}")));
}
max_supported_batch_total_tokens
}
};
tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}");
tracing::info!("Connected");
// Determine the server port based on the feature and environment variable. // Determine the server port based on the feature and environment variable.
let port = if cfg!(feature = "google") { let port = if cfg!(feature = "google") {
std::env::var("AIP_HTTP_PORT") std::env::var("AIP_HTTP_PORT")
@ -373,8 +333,8 @@ async fn main() -> Result<(), RouterError> {
// Run server // Run server
server::run( server::run(
master_shard_uds_path,
model_info, model_info,
shard_info,
compat_return_full_text, compat_return_full_text,
max_concurrent_requests, max_concurrent_requests,
max_best_of, max_best_of,
@ -384,10 +344,9 @@ async fn main() -> Result<(), RouterError> {
max_total_tokens, max_total_tokens,
waiting_served_ratio, waiting_served_ratio,
max_batch_prefill_tokens, max_batch_prefill_tokens,
max_supported_batch_total_tokens, max_batch_total_tokens,
max_waiting_tokens, max_waiting_tokens,
max_batch_size, max_batch_size,
sharded_client,
tokenizer, tokenizer,
config, config,
validation_workers, validation_workers,
@ -397,6 +356,7 @@ async fn main() -> Result<(), RouterError> {
ngrok_authtoken, ngrok_authtoken,
ngrok_edge, ngrok_edge,
tokenizer_config, tokenizer_config,
processor_config,
messages_api_enabled, messages_api_enabled,
disable_grammar_support, disable_grammar_support,
max_client_batch_size, max_client_batch_size,
@ -454,8 +414,21 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
} }
// Filter events with LOG_LEVEL // Filter events with LOG_LEVEL
let env_filter = let varname = "LOG_LEVEL";
EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info")); let env_filter = if let Ok(log_level) = std::env::var(varname) {
// Override to avoid simple logs to be spammed with tokio level informations
let log_level = match &log_level[..] {
"warn" => "text_generation_launcher=warn,text_generation_router=warn",
"info" => "text_generation_launcher=info,text_generation_router=info",
"debug" => "text_generation_launcher=debug,text_generation_router=debug",
log_level => log_level,
};
EnvFilter::builder()
.with_default_directive(LevelFilter::INFO.into())
.parse_lossy(log_level)
} else {
EnvFilter::new("info")
};
tracing_subscriber::registry() tracing_subscriber::registry()
.with(env_filter) .with(env_filter)
@ -529,16 +502,8 @@ pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConf
enum RouterError { enum RouterError {
#[error("Argument validation error: {0}")] #[error("Argument validation error: {0}")]
ArgumentValidation(String), ArgumentValidation(String),
#[error("Unable to connect to the Python model shards: {0}")] #[error("WebServer error: {0}")]
Connection(ClientError), WebServer(#[from] server::WebServerError),
#[error("Unable to clear the Python model shards cache: {0}")]
Cache(ClientError),
#[error("Unable to get the Python model shards info: {0}")]
Info(ClientError),
#[error("Unable to warmup the Python model shards: {0}")]
Warmup(ClientError),
#[error("Tokio runtime failed to start: {0}")] #[error("Tokio runtime failed to start: {0}")]
Tokio(#[from] std::io::Error), Tokio(#[from] std::io::Error),
#[error("Axum webserver failed: {0}")]
Axum(#[from] axum::BoxError),
} }

View File

@ -1,13 +1,20 @@
use crate::config::Config;
/// HTTP Server logic /// HTTP Server logic
use crate::health::Health; use crate::config::Config;
use crate::infer::{InferError, InferResponse, InferStreamResponse, ToolGrammar}; use crate::infer::v2::SchedulerV2;
use crate::infer::v3::SchedulerV3;
use crate::infer::{HealthCheck, Scheduler};
use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, ToolGrammar};
#[cfg(feature = "kserve")]
use crate::kserve::{
kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
kserve_model_metadata, kserve_model_metadata_ready,
};
use crate::validation::ValidationError; use crate::validation::ValidationError;
use crate::{ use crate::{
BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
GenerateResponse, GrammarType, HubModelInfo, HubTokenizerConfig, Infer, Info, Message, GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info,
PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Usage, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse,
Validation, Usage, Validation,
}; };
use crate::{ use crate::{
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
@ -34,7 +41,8 @@ use std::convert::Infallible;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicBool;
use std::sync::Arc; use std::sync::Arc;
use text_generation_client::{ShardInfo, ShardedClient}; use text_generation_client::{v2, v3, ClientError, ShardInfo};
use thiserror::Error;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use tokio::select; use tokio::select;
use tokio::signal; use tokio::signal;
@ -115,7 +123,9 @@ example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
)] )]
#[instrument(skip(health))] #[instrument(skip(health))]
/// Health check method /// Health check method
async fn health(mut health: Extension<Health>) -> Result<(), (StatusCode, Json<ErrorResponse>)> { async fn health(
mut health: Extension<HealthCheck>,
) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
match health.check().await { match health.check().await {
true => Ok(()), true => Ok(()),
false => Err(( false => Err((
@ -167,7 +177,7 @@ async fn generate(
generate_internal(infer, ComputeType(compute_type), Json(req), span).await generate_internal(infer, ComputeType(compute_type), Json(req), span).await
} }
async fn generate_internal( pub(crate) async fn generate_internal(
infer: Extension<Infer>, infer: Extension<Infer>,
ComputeType(compute_type): ComputeType, ComputeType(compute_type): ComputeType,
Json(req): Json<GenerateRequest>, Json(req): Json<GenerateRequest>,
@ -213,9 +223,7 @@ async fn generate_internal(
BestOfSequence { BestOfSequence {
generated_text: output_text, generated_text: output_text,
finish_reason: FinishReason::from( finish_reason: response.generated_text.finish_reason,
response.generated_text.finish_reason,
),
generated_tokens: response.generated_text.generated_tokens, generated_tokens: response.generated_text.generated_tokens,
prefill: response.prefill, prefill: response.prefill,
tokens: response.tokens, tokens: response.tokens,
@ -227,7 +235,7 @@ async fn generate_internal(
}); });
Some(Details { Some(Details {
finish_reason: FinishReason::from(response.generated_text.finish_reason), finish_reason: response.generated_text.finish_reason,
generated_tokens: response.generated_text.generated_tokens, generated_tokens: response.generated_text.generated_tokens,
prefill: response.prefill, prefill: response.prefill,
tokens: response.tokens, tokens: response.tokens,
@ -468,7 +476,7 @@ async fn generate_stream_internal(
// Token details // Token details
let details = match details { let details = match details {
true => Some(StreamDetails { true => Some(StreamDetails {
finish_reason: FinishReason::from(generated_text.finish_reason), finish_reason: generated_text.finish_reason,
generated_tokens: generated_text.generated_tokens, generated_tokens: generated_text.generated_tokens,
seed: generated_text.seed, seed: generated_text.seed,
}), }),
@ -597,9 +605,22 @@ async fn completions(
let span = tracing::Span::current(); let span = tracing::Span::current();
metrics::increment_counter!("tgi_request_count"); metrics::increment_counter!("tgi_request_count");
let stream = req.stream; let CompletionRequest {
let max_new_tokens = req.max_tokens.or(Some(100)); max_tokens,
let seed = req.seed; seed,
stop,
stream,
temperature,
..
} = req;
let max_new_tokens = max_tokens.or(Some(100));
let stop = stop.unwrap_or_default();
// enable greedy only when temperature is 0
let (do_sample, temperature) = match temperature {
Some(temperature) if temperature == 0.0 => (false, None),
other => (true, other),
};
// if suffix is present throw an error // if suffix is present throw an error
if req.suffix.is_some() { if req.suffix.is_some() {
@ -635,16 +656,16 @@ async fn completions(
inputs: prompt.to_string(), inputs: prompt.to_string(),
parameters: GenerateParameters { parameters: GenerateParameters {
best_of: None, best_of: None,
temperature: req.temperature, temperature,
repetition_penalty: req.repetition_penalty, repetition_penalty: req.repetition_penalty,
frequency_penalty: req.frequency_penalty, frequency_penalty: req.frequency_penalty,
top_k: None, top_k: None,
top_p: req.top_p, top_p: req.top_p,
typical_p: None, typical_p: None,
do_sample: true, do_sample,
max_new_tokens, max_new_tokens,
return_full_text: None, return_full_text: None,
stop: Vec::new(), stop: stop.clone(),
truncate: None, truncate: None,
watermark: false, watermark: false,
details: true, details: true,
@ -1000,6 +1021,7 @@ async fn chat_completions(
tool_choice, tool_choice,
tool_prompt, tool_prompt,
temperature, temperature,
response_format,
.. ..
} = req; } = req;
@ -1014,6 +1036,18 @@ async fn chat_completions(
other => (true, other), other => (true, other),
}; };
// response_format and tools are mutually exclusive
if response_format.is_some() && tools.as_ref().is_some() {
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: "Grammar and tools are mutually exclusive".to_string(),
error_type: "grammar and tools".to_string(),
}),
));
}
// extract tool grammar if present // extract tool grammar if present
let tool_grammar = match ToolGrammar::apply(tools, tool_choice) { let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
Ok(grammar) => grammar, Ok(grammar) => grammar,
@ -1030,16 +1064,21 @@ async fn chat_completions(
} }
}; };
let grammar_with_prompt = tool_grammar // determine the appropriate arguments for apply_chat_template
let tools_grammar_prompt = tool_grammar
.as_ref() .as_ref()
.map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt)); .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt));
let typed_grammar = grammar_with_prompt let (tools_grammar_prompt, grammar) = match response_format {
.as_ref() Some(response_format) => (None, Some(response_format)),
.map(|(grammar, _)| grammar.clone()); None => (
tools_grammar_prompt.clone(),
tools_grammar_prompt.map(|(grammar, _)| grammar.clone()),
),
};
// apply chat template to flatten the request into a single input // apply chat template to flatten the request into a single input
let inputs = match infer.apply_chat_template(messages, grammar_with_prompt) { let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) {
Ok(inputs) => inputs, Ok(inputs) => inputs,
Err(err) => { Err(err) => {
metrics::increment_counter!("tgi_request_failure", "err" => "validation"); metrics::increment_counter!("tgi_request_failure", "err" => "validation");
@ -1075,7 +1114,7 @@ async fn chat_completions(
decoder_input_details: !stream, decoder_input_details: !stream,
seed, seed,
top_n_tokens: req.top_logprobs, top_n_tokens: req.top_logprobs,
grammar: typed_grammar, grammar,
}, },
}; };
@ -1320,7 +1359,8 @@ async fn tokenize(
.iter() .iter()
.zip(encoding.get_offsets()) .zip(encoding.get_offsets())
.map(|(&id, &(start, stop))| { .map(|(&id, &(start, stop))| {
let text: String = input.chars().skip(start).take(stop - start).collect(); let text: String =
String::from_utf8_lossy(&input.as_bytes()[start..stop]).to_string();
SimpleToken { SimpleToken {
id, id,
text, text,
@ -1358,34 +1398,34 @@ pub(crate) struct ComputeType(String);
/// Serving method /// Serving method
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub async fn run( pub async fn run(
master_shard_uds_path: String,
model_info: HubModelInfo, model_info: HubModelInfo,
shard_info: ShardInfo,
compat_return_full_text: bool, compat_return_full_text: bool,
max_concurrent_requests: usize, max_concurrent_requests: usize,
max_best_of: usize, max_best_of: usize,
max_stop_sequences: usize, max_stop_sequences: usize,
max_top_n_tokens: u32, max_top_n_tokens: u32,
max_input_length: usize, max_input_tokens: usize,
max_total_tokens: usize, max_total_tokens: usize,
waiting_served_ratio: f32, waiting_served_ratio: f32,
max_batch_prefill_tokens: u32, max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32, max_batch_total_tokens: Option<u32>,
max_waiting_tokens: usize, max_waiting_tokens: usize,
max_batch_size: Option<usize>, max_batch_size: Option<usize>,
client: ShardedClient,
tokenizer: Option<Tokenizer>, tokenizer: Option<Tokenizer>,
config: Option<Config>, config: Option<Config>,
validation_workers: usize, validation_workers: usize,
addr: SocketAddr, addr: SocketAddr,
allow_origin: Option<AllowOrigin>, allow_origin: Option<AllowOrigin>,
ngrok: bool, ngrok: bool,
ngrok_authtoken: Option<String>, _ngrok_authtoken: Option<String>,
ngrok_edge: Option<String>, _ngrok_edge: Option<String>,
tokenizer_config: HubTokenizerConfig, tokenizer_config: HubTokenizerConfig,
processor_config: HubProcessorConfig,
messages_api_enabled: bool, messages_api_enabled: bool,
grammar_support: bool, grammar_support: bool,
max_client_batch_size: usize, max_client_batch_size: usize,
) -> Result<(), axum::BoxError> { ) -> Result<(), WebServerError> {
// OpenAPI documentation // OpenAPI documentation
#[derive(OpenApi)] #[derive(OpenApi)]
#[openapi( #[openapi(
@ -1455,6 +1495,141 @@ pub async fn run(
struct ApiDoc; struct ApiDoc;
// Create state // Create state
// Open connection, get model info and warmup
let (scheduler, health_ext, shard_info, max_batch_total_tokens): (
Arc<dyn Scheduler + Send + Sync>,
HealthCheck,
ShardInfo,
u32,
) = {
// Helper function to check both v2 and v3
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
match max_supported_batch_total_tokens {
// Older models do not support automatic max-batch-total-tokens
None => {
let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(
16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)),
);
tracing::warn!("Model does not support automatic max batch total tokens");
Ok(max_batch_total_tokens)
}
// Flash attention models return their max supported total tokens
Some(max_supported_batch_total_tokens) => {
// Warn if user added his own max-batch-total-tokens as we will ignore it
if max_batch_total_tokens.is_some() {
tracing::warn!(
"`--max-batch-total-tokens` is deprecated for Flash \
Attention models."
);
tracing::warn!(
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
);
}
if max_total_tokens as u32 > max_supported_batch_total_tokens {
return Err(WebServerError::NotEnoughMemory(max_total_tokens));
}
Ok(max_supported_batch_total_tokens)
}
}
};
let generation_health = Arc::new(AtomicBool::new(false));
match v3::ShardedClient::connect_uds(master_shard_uds_path.clone()).await {
Ok(mut sharded_client) => {
// server is running on v3
// Clear the cache; useful if the webserver rebooted
sharded_client
.clear_cache(None)
.await
.map_err(WebServerError::Cache)?;
// Get info from the shard
let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?;
// Warmup model
tracing::info!("Warming up model");
let max_batch_total_tokens = check_max_batch_total_tokens(
sharded_client
.warmup(
max_input_tokens as u32,
max_batch_prefill_tokens,
max_total_tokens as u32,
max_batch_size,
)
.await
.map_err(WebServerError::Warmup)?,
)?;
let health_ext =
HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone());
let scheduler = Arc::new(SchedulerV3::new(
sharded_client,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
shard_info.requires_padding,
shard_info.window_size,
shard_info.speculate,
generation_health,
));
tracing::info!("Using scheduler V3");
(scheduler, health_ext, shard_info, max_batch_total_tokens)
}
Err(_) => {
let mut sharded_client = v2::ShardedClient::connect_uds(master_shard_uds_path)
.await
.map_err(WebServerError::Connection)?;
// server is running on v2
// Clear the cache; useful if the webserver rebooted
sharded_client
.clear_cache(None)
.await
.map_err(WebServerError::Cache)?;
// Get info from the shard
let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?;
// Warmup model
tracing::info!("Warming up model");
let max_batch_total_tokens = check_max_batch_total_tokens(
sharded_client
.warmup(
max_input_tokens as u32,
max_batch_prefill_tokens,
max_total_tokens as u32,
max_batch_size,
)
.await
.map_err(WebServerError::Warmup)?,
)?;
let health_ext =
HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone());
let scheduler = Arc::new(SchedulerV2::new(
sharded_client,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
shard_info.requires_padding,
shard_info.window_size,
shard_info.speculate,
generation_health,
));
tracing::info!("Using scheduler V2");
(scheduler, health_ext, shard_info, max_batch_total_tokens)
}
}
};
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
let validation = Validation::new( let validation = Validation::new(
validation_workers, validation_workers,
tokenizer, tokenizer,
@ -1462,26 +1637,17 @@ pub async fn run(
max_best_of, max_best_of,
max_stop_sequences, max_stop_sequences,
max_top_n_tokens, max_top_n_tokens,
max_input_length, max_input_tokens,
max_total_tokens, max_total_tokens,
grammar_support, grammar_support,
); );
let generation_health = Arc::new(AtomicBool::new(false));
let health_ext = Health::new(client.clone(), generation_health.clone());
let infer = Infer::new( let infer = Infer::new(
client, scheduler,
validation, validation,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
max_concurrent_requests, max_concurrent_requests,
shard_info.requires_padding,
shard_info.window_size,
shard_info.speculate,
generation_health,
tokenizer_config, tokenizer_config,
processor_config,
); );
// Duration buckets // Duration buckets
@ -1498,7 +1664,7 @@ pub async fn run(
// Input Length buckets // Input Length buckets
let input_length_matcher = Matcher::Full(String::from("tgi_request_input_length")); let input_length_matcher = Matcher::Full(String::from("tgi_request_input_length"));
let input_length_buckets: Vec<f64> = (0..100) let input_length_buckets: Vec<f64> = (0..100)
.map(|x| (max_input_length as f64 / 100.0) * (x + 1) as f64) .map(|x| (max_input_tokens as f64 / 100.0) * (x + 1) as f64)
.collect(); .collect();
// Generated tokens buckets // Generated tokens buckets
let generated_tokens_matcher = Matcher::Full(String::from("tgi_request_generated_tokens")); let generated_tokens_matcher = Matcher::Full(String::from("tgi_request_generated_tokens"));
@ -1552,7 +1718,7 @@ pub async fn run(
max_concurrent_requests, max_concurrent_requests,
max_best_of, max_best_of,
max_stop_sequences, max_stop_sequences,
max_input_length, max_input_tokens,
max_total_tokens, max_total_tokens,
waiting_served_ratio, waiting_served_ratio,
max_batch_total_tokens, max_batch_total_tokens,
@ -1566,9 +1732,9 @@ pub async fn run(
docker_label: option_env!("DOCKER_LABEL"), docker_label: option_env!("DOCKER_LABEL"),
}; };
// Define VertextApiDoc conditionally only if the "google" feature is enabled #[allow(unused_mut)] // mut is needed for conditional compilation
let doc = { let mut doc = ApiDoc::openapi();
// avoid `mut` if possible
#[cfg(feature = "google")] #[cfg(feature = "google")]
{ {
use crate::VertexInstance; use crate::VertexInstance;
@ -1578,16 +1744,46 @@ pub async fn run(
paths(vertex_compatibility), paths(vertex_compatibility),
components(schemas(VertexInstance, VertexRequest, VertexResponse)) components(schemas(VertexInstance, VertexRequest, VertexResponse))
)] )]
struct VertextApiDoc; struct VertexApiDoc;
// limiting mutability to the smallest scope necessary doc.merge(VertexApiDoc::openapi());
let mut doc = ApiDoc::openapi();
doc.merge(VertextApiDoc::openapi());
doc
} }
#[cfg(not(feature = "google"))]
ApiDoc::openapi() #[cfg(feature = "kserve")]
{
use crate::kserve::{
InferenceOutput, InferenceRequest, LiveResponse, MetadataServerResponse, OutputChunk,
ReadyResponse,
}; };
use crate::kserve::{
__path_kerve_server_metadata, __path_kserve_health_live, __path_kserve_health_ready,
__path_kserve_model_infer, __path_kserve_model_metadata,
__path_kserve_model_metadata_ready,
};
#[derive(OpenApi)]
#[openapi(
paths(
kserve_model_infer,
kserve_health_live,
kserve_health_ready,
kerve_server_metadata,
kserve_model_metadata,
kserve_model_metadata_ready,
),
components(schemas(
InferenceOutput,
InferenceRequest,
LiveResponse,
MetadataServerResponse,
OutputChunk,
ReadyResponse,
))
)]
struct KServeApiDoc;
doc.merge(KServeApiDoc::openapi());
}
// Configure Swagger UI // Configure Swagger UI
let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc); let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc);
@ -1637,6 +1833,27 @@ pub async fn run(
} }
} }
#[cfg(feature = "kserve")]
{
tracing::info!("Built with `kserve` feature");
app = app
.route(
"/v2/models/:model_name/versions/:model_version/infer",
post(kserve_model_infer),
)
.route(
"/v2/models/:model_name/versions/:model_version",
get(kserve_model_metadata),
)
.route("/v2/health/ready", get(kserve_health_ready))
.route("/v2/health/live", get(kserve_health_live))
.route("/v2", get(kerve_server_metadata))
.route(
"/v2/models/:model_name/versions/:model_version/ready",
get(kserve_model_metadata_ready),
);
}
// add layers after routes // add layers after routes
app = app app = app
.layer(Extension(info)) .layer(Extension(info))
@ -1648,49 +1865,14 @@ pub async fn run(
.layer(OtelAxumLayer::default()) .layer(OtelAxumLayer::default())
.layer(cors_layer); .layer(cors_layer);
tracing::info!("Connected");
if ngrok { if ngrok {
#[cfg(feature = "ngrok")] #[cfg(feature = "ngrok")]
{ {
use ngrok::config::TunnelBuilder; panic!("ngrok feature is not functional with axum=0.7 and hyper=1, waiting on https://github.com/ngrok/ngrok-rust/pull/137/files to re-enable.");
let _ = addr;
let authtoken =
ngrok_authtoken.expect("`ngrok-authtoken` must be set when using ngrok tunneling");
let edge = ngrok_edge.expect("`ngrok-edge` must be set when using ngrok tunneling");
let tunnel = ngrok::Session::builder()
.authtoken(authtoken)
.connect()
.await
.unwrap()
.labeled_tunnel()
.label("edge", edge);
let listener = tunnel.listen().await.unwrap();
// Run prom metrics and health locally too
tokio::spawn(
axum::Server::bind(&addr)
.serve(
Router::new()
.route("/health", get(health))
.route("/metrics", get(metrics))
.layer(Extension(health_ext))
.layer(Extension(prom_handle))
.into_make_service(),
)
//Wait until all requests are finished to shut down
.with_graceful_shutdown(shutdown_signal()),
);
// Run server // Run server
axum::Server::builder(listener)
.serve(app.into_make_service())
//Wait until all requests are finished to shut down
.with_graceful_shutdown(shutdown_signal())
.await?;
} }
#[cfg(not(feature = "ngrok"))] #[cfg(not(feature = "ngrok"))]
{ {
@ -1703,11 +1885,12 @@ pub async fn run(
} }
} else { } else {
// Run server // Run server
axum::Server::bind(&addr)
.serve(app.into_make_service()) let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
// Wait until all requests are finished to shut down axum::serve(listener, app)
.with_graceful_shutdown(shutdown_signal()) .with_graceful_shutdown(shutdown_signal())
.await?; .await
.map_err(|err| WebServerError::Axum(Box::new(err)))?;
} }
Ok(()) Ok(())
} }
@ -1740,17 +1923,6 @@ async fn shutdown_signal() {
opentelemetry::global::shutdown_tracer_provider(); opentelemetry::global::shutdown_tracer_provider();
} }
impl From<i32> for FinishReason {
fn from(finish_reason: i32) -> Self {
let finish_reason = text_generation_client::FinishReason::try_from(finish_reason).unwrap();
match finish_reason {
text_generation_client::FinishReason::Length => FinishReason::Length,
text_generation_client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
text_generation_client::FinishReason::StopSequence => FinishReason::StopSequence,
}
}
}
/// Convert to Axum supported formats /// Convert to Axum supported formats
impl From<InferError> for (StatusCode, Json<ErrorResponse>) { impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
fn from(err: InferError) -> Self { fn from(err: InferError) -> Self {
@ -1783,3 +1955,19 @@ impl From<InferError> for Event {
.unwrap() .unwrap()
} }
} }
#[derive(Debug, Error)]
pub enum WebServerError {
#[error("Unable to connect to the Python model shards: {0}")]
Connection(ClientError),
#[error("Unable to clear the Python model shards cache: {0}")]
Cache(ClientError),
#[error("Unable to get the Python model shards info: {0}")]
Info(ClientError),
#[error("Unable to warmup the Python model shards: {0}")]
Warmup(ClientError),
#[error("Not enough memory to handle `max_total_tokens={0}`")]
NotEnoughMemory(usize),
#[error("Axum error: {0}")]
Axum(#[from] axum::BoxError),
}

View File

@ -1,19 +1,16 @@
use crate::config::Config;
/// Payload validation logic /// Payload validation logic
use crate::config::Config;
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput}; use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
use crate::{GenerateParameters, GenerateRequest, GrammarType}; use crate::{GenerateParameters, GenerateRequest, GrammarType};
use base64::{engine::general_purpose::STANDARD, Engine};
use image::{io::Reader as ImageReader, ImageFormat};
use jsonschema::{Draft, JSONSchema}; use jsonschema::{Draft, JSONSchema};
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use serde_json::Value; use serde_json::Value;
use std::io::Cursor; use std::io::Cursor;
use text_generation_client::{ use text_generation_client::{Chunk, Image, InputChunk};
GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
};
use thiserror::Error; use thiserror::Error;
use tokenizers::tokenizer::Tokenizer; use tokenizers::tokenizer::Tokenizer;
// use tokenizers::TruncationDirection;
use base64::{engine::general_purpose::STANDARD, Engine};
use image::{io::Reader as ImageReader, ImageFormat};
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tokio::sync::oneshot; use tokio::sync::oneshot;
use tracing::{instrument, Span}; use tracing::{instrument, Span};
@ -89,7 +86,7 @@ impl Validation {
&self, &self,
inputs: String, inputs: String,
truncate: Option<usize>, truncate: Option<usize>,
) -> Result<Option<(tokenizers::Encoding, String)>, ValidationError> { ) -> Result<Option<(tokenizers::Encoding, Vec<InputChunk>)>, ValidationError> {
// If we have a fast tokenizer // If we have a fast tokenizer
if let Some(sender) = &self.sender { if let Some(sender) = &self.sender {
// Create response channel // Create response channel
@ -115,7 +112,7 @@ impl Validation {
inputs: String, inputs: String,
truncate: Option<usize>, truncate: Option<usize>,
max_new_tokens: Option<u32>, max_new_tokens: Option<u32>,
) -> Result<(String, usize, u32), ValidationError> { ) -> Result<(Vec<InputChunk>, usize, u32), ValidationError> {
// If we have a fast tokenizer // If we have a fast tokenizer
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? { if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
// Create response channel // Create response channel
@ -172,13 +169,13 @@ impl Validation {
// Validate MaxNewTokens // Validate MaxNewTokens
if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 { if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
input_length = input_length.saturating_sub(max_new_tokens as usize); input_length = input_length.saturating_sub(max_new_tokens as usize);
// return Err(ValidationError::MaxNewTokens(
// self.max_total_tokens - self.max_input_length,
// max_new_tokens,
// ));
} }
Ok((inputs, input_length, max_new_tokens)) Ok((
vec![Chunk::Text(inputs).into()],
input_length,
max_new_tokens,
))
} }
} }
@ -322,13 +319,13 @@ impl Validation {
// compiler and use that to build the FSM here. // compiler and use that to build the FSM here.
// Validate grammar and unpack the grammar and type for the proto message // Validate grammar and unpack the grammar and type for the proto message
let (grammar, grammar_type) = match grammar { let grammar = match grammar {
Some(grammar) => { Some(grammar) => {
// Ensure that grammar is not set if it's not supported // Ensure that grammar is not set if it's not supported
if self.disable_grammar_support { if self.disable_grammar_support {
return Err(ValidationError::Grammar); return Err(ValidationError::Grammar);
} }
match grammar { let valid_grammar = match grammar {
GrammarType::Json(json) => { GrammarType::Json(json) => {
let json = match json { let json = match json {
// if value is a string, we need to parse it again to make sure its // if value is a string, we need to parse it again to make sure its
@ -345,20 +342,20 @@ impl Validation {
.compile(&json) .compile(&json)
.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?; .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
(
// Serialize json to string // Serialize json to string
ValidGrammar::Json(
serde_json::to_string(&json) serde_json::to_string(&json)
.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?, .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?,
ProtoGrammarType::Json.into(),
) )
} }
GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()), GrammarType::Regex(regex) => ValidGrammar::Regex(regex),
};
Some(valid_grammar)
} }
} None => None,
None => (String::new(), ProtoGrammarType::None.into()),
}; };
let parameters = NextTokenChooserParameters { let parameters = ValidParameters {
temperature, temperature,
repetition_penalty, repetition_penalty,
frequency_penalty, frequency_penalty,
@ -369,9 +366,8 @@ impl Validation {
seed, seed,
watermark, watermark,
grammar, grammar,
grammar_type,
}; };
let stopping_parameters = StoppingCriteriaParameters { let stopping_parameters = ValidStoppingParameters {
max_new_tokens, max_new_tokens,
stop_sequences, stop_sequences,
ignore_eos_token: false, ignore_eos_token: false,
@ -453,6 +449,7 @@ fn format_from_mimetype(mimetype: &str) -> Option<ImageFormat> {
_ => None, _ => None,
} }
} }
fn format_to_mimetype(format: ImageFormat) -> String { fn format_to_mimetype(format: ImageFormat) -> String {
match format { match format {
ImageFormat::Png => "image/png", ImageFormat::Png => "image/png",
@ -465,7 +462,7 @@ fn format_to_mimetype(format: ImageFormat) -> String {
.to_string() .to_string()
} }
fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> { fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
if input.starts_with("![](http://") || input.starts_with("![](https://") { if input.starts_with("![](http://") || input.starts_with("![](https://") {
let url = &input["![](".len()..input.len() - 1]; let url = &input["![](".len()..input.len() - 1];
let data = reqwest::blocking::get(url)?.bytes()?; let data = reqwest::blocking::get(url)?.bytes()?;
@ -476,9 +473,7 @@ fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> {
let height: usize = img.height().try_into()?; let height: usize = img.height().try_into()?;
let width: usize = img.width().try_into()?; let width: usize = img.width().try_into()?;
let mimetype = format_to_mimetype(format); let mimetype = format_to_mimetype(format);
let encoded = STANDARD.encode(data); Ok((data.to_vec(), mimetype, height, width))
let data_uri = format!("![](data:{mimetype};base64,{encoded})");
Ok((data_uri, height, width))
} else if input.starts_with("![](data:") { } else if input.starts_with("![](data:") {
// Remove ![](....) // Remove ![](....)
let content = &input["![](data:".len()..input.len() - 1]; let content = &input["![](data:".len()..input.len() - 1];
@ -495,9 +490,9 @@ fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> {
let data = STANDARD.decode(content["base64,".len()..].as_bytes())?; let data = STANDARD.decode(content["base64,".len()..].as_bytes())?;
let img = if let Some(format) = format_from_mimetype(mimetype) { let img = if let Some(format) = format_from_mimetype(mimetype) {
ImageReader::with_format(Cursor::new(data), format).decode()? ImageReader::with_format(Cursor::new(&data), format).decode()?
} else { } else {
ImageReader::new(Cursor::new(data)) ImageReader::new(Cursor::new(&data))
.with_guessed_format() .with_guessed_format()
.map_err(|_io_error| ValidationError::InvalidImageContent(content.to_string()))? .map_err(|_io_error| ValidationError::InvalidImageContent(content.to_string()))?
.decode()? .decode()?
@ -505,7 +500,7 @@ fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> {
let height: usize = img.height().try_into()?; let height: usize = img.height().try_into()?;
let width: usize = img.width().try_into()?; let width: usize = img.width().try_into()?;
Ok((input.to_string(), height, width)) Ok((data, mimetype.to_string(), height, width))
} else { } else {
Err(ValidationError::InvalidImageContent(input.to_string())) Err(ValidationError::InvalidImageContent(input.to_string()))
} }
@ -513,113 +508,110 @@ fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> {
/// Get input length and optionally truncate it /// Get input length and optionally truncate it
fn prepare_input( fn prepare_input(
mut inputs: String, inputs: String,
_truncate: Option<usize>, _truncate: Option<usize>,
tokenizer: &Tokenizer, tokenizer: &Tokenizer,
config: &Option<Config>, config: &Option<Config>,
) -> Result<(tokenizers::Encoding, String), ValidationError> { ) -> Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError> {
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap()); static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
let tokenizer_query = match config { let (tokenizer_query, input_chunks) = match config {
Some(Config::LlavaNext(config)) => { Some(Config::LlavaNext(config)) => {
let mut modified_inputs = String::with_capacity(inputs.len()); let mut input_chunks = Vec::new();
let mut tokenizer_query = String::with_capacity(inputs.len()); let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0; let mut start = 0;
for chunk in RE.find_iter(&inputs) { for chunk in RE.find_iter(&inputs) {
let chunk_start = chunk.start(); let chunk_start = chunk.start();
let chunk_end = chunk.end(); let chunk_end = chunk.end();
if chunk_start != start { if chunk_start != start {
modified_inputs.push_str(&inputs[start..chunk_start]); input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
tokenizer_query.push_str(&inputs[start..chunk_start]); tokenizer_query.push_str(&inputs[start..chunk_start]);
} }
let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?; let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
let slots = config.get_number_of_features(height, width); let slots = config.get_number_of_features(height, width);
input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
tokenizer_query.push_str(&"<image>".repeat(slots)); tokenizer_query.push_str(&"<image>".repeat(slots));
modified_inputs.push_str(&image_uri);
start = chunk_end; start = chunk_end;
} }
if start != inputs.len() - 1 { if start != inputs.len() {
modified_inputs.push_str(&inputs[start..]); input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
tokenizer_query.push_str(&inputs[start..]); tokenizer_query.push_str(&inputs[start..]);
} }
inputs = modified_inputs; (tokenizer_query, input_chunks)
tokenizer_query
} }
Some(Config::Paligemma(config)) => { Some(Config::Paligemma(config)) => {
let mut modified_inputs = String::with_capacity(inputs.len()); let mut input_chunks = Vec::new();
let mut tokenizer_query = String::with_capacity(inputs.len()); let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0; let mut start = 0;
for chunk in RE.find_iter(&inputs) { for chunk in RE.find_iter(&inputs) {
let chunk_start = chunk.start(); let chunk_start = chunk.start();
let chunk_end = chunk.end(); let chunk_end = chunk.end();
if chunk_start != start { if chunk_start != start {
modified_inputs.push_str(&inputs[start..chunk_start]); input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
tokenizer_query.push_str(&inputs[start..chunk_start]); tokenizer_query.push_str(&inputs[start..chunk_start]);
} }
let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?; let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
let slots = config.get_number_of_features(height, width); let slots = config.get_number_of_features(height, width);
input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
tokenizer_query.push_str(&"<image>".repeat(slots)); tokenizer_query.push_str(&"<image>".repeat(slots));
modified_inputs.push_str(&image_uri);
start = chunk_end; start = chunk_end;
} }
if start != inputs.len() - 1 { if start != inputs.len() {
modified_inputs.push_str(&inputs[start..]); input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
tokenizer_query.push_str(&inputs[start..]); tokenizer_query.push_str(&inputs[start..]);
} }
inputs = modified_inputs; (tokenizer_query, input_chunks)
tokenizer_query
} }
Some(Config::Idefics2(config)) => { Some(Config::Idefics2(config)) => {
let mut modified_inputs = String::with_capacity(inputs.len()); let mut input_chunks = Vec::new();
let mut tokenizer_query = String::with_capacity(inputs.len()); let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0; let mut start = 0;
for chunk in RE.find_iter(&inputs) { for chunk in RE.find_iter(&inputs) {
let chunk_start = chunk.start(); let chunk_start = chunk.start();
let chunk_end = chunk.end(); let chunk_end = chunk.end();
if chunk_start != start { if chunk_start != start {
modified_inputs.push_str(&inputs[start..chunk_start]); input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
tokenizer_query.push_str(&inputs[start..chunk_start]); tokenizer_query.push_str(&inputs[start..chunk_start]);
} }
let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?; let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
let slots = config.get_number_of_features(height, width); let slots = config.get_number_of_features(height, width);
tokenizer_query.push_str("<fake_token_around_image>"); tokenizer_query.push_str("<fake_token_around_image>");
tokenizer_query.push_str(&"<image>".repeat(slots)); tokenizer_query.push_str(&"<image>".repeat(slots));
tokenizer_query.push_str("<fake_token_around_image>"); tokenizer_query.push_str("<fake_token_around_image>");
modified_inputs.push_str(&image_uri); input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
start = chunk_end; start = chunk_end;
} }
if start != inputs.len() - 1 { if start != inputs.len() {
modified_inputs.push_str(&inputs[start..]); input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
tokenizer_query.push_str(&inputs[start..]); tokenizer_query.push_str(&inputs[start..]);
} }
inputs = modified_inputs; (tokenizer_query, input_chunks)
tokenizer_query
} }
Some(Config::Idefics) => { Some(Config::Idefics) => {
let mut modified_inputs = String::with_capacity(inputs.len()); let mut input_chunks = Vec::new();
let mut tokenizer_query = String::with_capacity(inputs.len()); let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0; let mut start = 0;
for chunk in RE.find_iter(&inputs) { for chunk in RE.find_iter(&inputs) {
let chunk_start = chunk.start(); let chunk_start = chunk.start();
let chunk_end = chunk.end(); let chunk_end = chunk.end();
if chunk_start != start { if chunk_start != start {
modified_inputs.push_str(&inputs[start..chunk_start]); input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
tokenizer_query.push_str(&inputs[start..chunk_start]); tokenizer_query.push_str(&inputs[start..chunk_start]);
} }
let (image_uri, _height, _width) = fetch_image(&inputs[chunk_start..chunk_end])?; let (data, mimetype, _height, _width) =
fetch_image(&inputs[chunk_start..chunk_end])?;
let slots = 1; let slots = 1;
tokenizer_query.push_str(&"<image>".repeat(slots)); tokenizer_query.push_str(&"<image>".repeat(slots));
modified_inputs.push_str(&image_uri); input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
start = chunk_end; start = chunk_end;
} }
if start != inputs.len() - 1 { if start != inputs.len() {
modified_inputs.push_str(&inputs[start..]); input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
tokenizer_query.push_str(&inputs[start..]); tokenizer_query.push_str(&inputs[start..]);
} }
inputs = modified_inputs; (tokenizer_query, input_chunks)
tokenizer_query
} }
_ => inputs.clone(), _ => (inputs.clone(), vec![Chunk::Text(inputs).into()]),
}; };
// Get the number of tokens in the input // Get the number of tokens in the input
@ -627,23 +619,64 @@ fn prepare_input(
.encode(tokenizer_query, true) .encode(tokenizer_query, true)
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?; .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
Ok((encoding, inputs)) Ok((encoding, input_chunks))
} }
type TokenizerRequest = ( type TokenizerRequest = (
(String, Option<usize>), (String, Option<usize>),
oneshot::Sender<Result<(tokenizers::Encoding, String), ValidationError>>, oneshot::Sender<Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError>>,
Span, Span,
); );
#[derive(Debug, Clone)]
pub(crate) enum ValidGrammar {
Json(String),
Regex(String),
}
#[derive(Debug, Clone)]
pub(crate) struct ValidParameters {
/// / exponential scaling output probability distribution
pub temperature: f32,
/// / restricting to the k highest probability elements
pub top_k: u32,
/// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
pub top_p: f32,
/// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
pub typical_p: f32,
/// / apply sampling on the logits
pub do_sample: bool,
/// / random seed for sampling
pub seed: u64,
/// / repetition penalty
pub repetition_penalty: f32,
/// / frequency penalty
pub frequency_penalty: f32,
/// / token watermarking using "A Watermark for Large Language Models"
pub watermark: bool,
/// / grammar (applied if not empty)
pub grammar: Option<ValidGrammar>,
}
#[derive(Debug, Clone)]
pub(crate) struct ValidStoppingParameters {
/// / Maximum number of generated tokens
pub max_new_tokens: u32,
/// / Optional stopping sequences
pub stop_sequences: Vec<String>,
/// / Ignore end of sequence token
/// / used for benchmarking
pub ignore_eos_token: bool,
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(crate) struct ValidGenerateRequest { pub(crate) struct ValidGenerateRequest {
pub inputs: String, pub inputs: Vec<InputChunk>,
pub input_length: u32, pub input_length: u32,
pub truncate: u32, pub truncate: u32,
pub decoder_input_details: bool, pub decoder_input_details: bool,
pub parameters: NextTokenChooserParameters, pub parameters: ValidParameters,
pub stopping_parameters: StoppingCriteriaParameters, pub stopping_parameters: ValidStoppingParameters,
pub top_n_tokens: u32, pub top_n_tokens: u32,
} }
@ -714,6 +747,7 @@ pub enum ValidationError {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::config::{PaliTextConfig, Paligemma};
use crate::default_parameters; use crate::default_parameters;
use crate::tests::get_tokenizer; use crate::tests::get_tokenizer;
@ -964,4 +998,61 @@ mod tests {
assert_eq!(valid_request.top_n_tokens, 0); assert_eq!(valid_request.top_n_tokens, 0);
} }
static PIXEL_GIF: &str = "R0lGODdhAQABAIEAAP///wAAAAAAAAAAACwAAAAAAQABAAAIBAABBAQAOw==";
#[tokio::test]
async fn test_prepare_input_chunks() {
let pixel_data = STANDARD.decode(PIXEL_GIF).unwrap();
let tokenizer = Some(get_tokenizer().await);
let max_best_of = 2;
let max_stop_sequence = 3;
let max_top_n_tokens = 4;
let max_input_length = 5;
let max_total_tokens = 6;
let disable_grammar_support = true;
let workers = 1;
let config = Config::Paligemma(Paligemma {
text_config: PaliTextConfig {
num_image_tokens: 1,
},
});
let validation = Validation::new(
workers,
tokenizer,
Some(config),
max_best_of,
max_stop_sequence,
max_top_n_tokens,
max_input_length,
max_total_tokens,
disable_grammar_support,
);
let chunks = match validation
.tokenize(
format!("test![](data:image/gif;base64,{})", PIXEL_GIF),
None,
)
.await
{
Ok(Some((_encoding, chunks))) => chunks,
_ => panic!("Unexpected tokenization failure"),
};
assert!(
chunks
== vec![
Chunk::Text("test".to_string()).into(),
Chunk::Image(Image {
data: pixel_data.clone(),
mimetype: "image/gif".to_string()
})
.into()
],
"Failed to process images",
);
}
} }

View File

@ -1,5 +1,5 @@
[toolchain] [toolchain]
# Released on: 02 May, 2024 # Released on: June 13, 2024
# https://releases.rs/docs/1.78.0/ # https://releases.rs/docs/1.79.0/
channel = "1.78.0" channel = "1.79.0"
components = ["rustfmt", "clippy"] components = ["rustfmt", "clippy"]

View File

@ -10,18 +10,26 @@ unit-tests:
gen-server: gen-server:
# Compile protos # Compile protos
pip install grpcio-tools==1.51.1 mypy-protobuf==3.4.0 'types-protobuf>=3.20.4' --no-cache-dir pip install grpcio-tools==1.62.2 mypy-protobuf==3.6.0 'types-protobuf' --no-cache-dir
mkdir text_generation_server/pb || true mkdir text_generation_server/pb || true
python -m grpc_tools.protoc -I../proto --python_out=text_generation_server/pb \ python -m grpc_tools.protoc -I../proto/v3 --python_out=text_generation_server/pb \
--grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb ../proto/generate.proto --grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb ../proto/v3/generate.proto
find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \; find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
touch text_generation_server/pb/__init__.py touch text_generation_server/pb/__init__.py
install: gen-server install-server: gen-server
pip install pip --upgrade pip install pip --upgrade
pip install -r requirements_cuda.txt pip install -r requirements_cuda.txt
pip install -e ".[bnb, accelerate, quantize, peft, outlines]" pip install -e ".[bnb, accelerate, quantize, peft, outlines]"
install: install-cuda
echo "Installed server"
install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention
install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm
run-dev: run-dev:
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded

View File

@ -1,16 +1,12 @@
flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec
flash-attention: build-flash-attention:
# Clone flash attention if [ ! -d 'flash-attention' ]; then \
pip install -U packaging ninja --no-cache-dir pip install -U packaging ninja --no-cache-dir && \
git clone https://github.com/HazyResearch/flash-attention.git git clone https://github.com/HazyResearch/flash-attention.git; \
fi
build-flash-attention: flash-attention cd flash-attention && git fetch && git checkout $(flash_att_commit) && \
cd flash-attention && git fetch && git checkout $(flash_att_commit) MAX_JOBS=8 python setup.py build && cd csrc/layer_norm && python setup.py build && cd ../rotary && python setup.py build
cd flash-attention && python setup.py build
cd flash-attention/csrc/rotary && python setup.py build
cd flash-attention/csrc/layer_norm && python setup.py build
install-flash-attention: build-flash-attention install-flash-attention: build-flash-attention
pip uninstall flash_attn rotary_emb dropout_layer_norm -y || true cd flash-attention && git checkout $(flash_att_commit) && MAX_JOBS=8 python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install
cd flash-attention && python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install

View File

@ -1,29 +1,21 @@
flash_att_v2_commit_cuda := 23e8fa5a263d1c7122bc46a86ef32030ee7130f9 flash_att_v2_commit_cuda := v2.5.9.post1
flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6 flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6
build-flash-attention-v2-cuda:
flash-attention-v2-cuda: pip install -U packaging wheel
# Clone flash attention pip install flash-attn==$(flash_att_v2_commit_cuda)
pip install -U packaging ninja --no-cache-dir
git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2
build-flash-attention-v2-cuda: flash-attention-v2-cuda
cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_cuda)
cd flash-attention-v2 && git submodule update --init --recursive
cd flash-attention-v2 && python setup.py build
install-flash-attention-v2-cuda: build-flash-attention-v2-cuda install-flash-attention-v2-cuda: build-flash-attention-v2-cuda
cd flash-attention-v2 && git submodule update --init --recursive && python setup.py install echo "Flash v2 installed"
flash-attention-v2-rocm: build-flash-attention-v2-rocm:
# Clone flash attention if [ ! -d 'flash-attention-v2' ]; then \
pip install -U packaging ninja --no-cache-dir pip install -U packaging ninja --no-cache-dir && \
git clone https://github.com/ROCm/flash-attention.git flash-attention-v2 git clone https://github.com/ROCm/flash-attention.git flash-attention-v2 && \
cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm) && \
build-flash-attention-v2-rocm: flash-attention-v2-rocm git submodule update --init --recursive && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build; \
cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm) fi
cd flash-attention-v2 && git submodule update --init --recursive
cd flash-attention-v2 && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
install-flash-attention-v2-rocm: build-flash-attention-v2-rocm install-flash-attention-v2-rocm: build-flash-attention-v2-rocm
cd flash-attention-v2 && git submodule update --init --recursive && python setup.py install cd flash-attention-v2 && \
GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py install

View File

@ -1,25 +1,23 @@
vllm-cuda: commit_cuda := b5dfc61db88a81069e45b44f7cc99bd9e62a60fa
# Clone vllm commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921
pip install -U ninja packaging --no-cache-dir build-vllm-cuda:
git clone https://github.com/Narsil/vllm.git vllm if [ ! -d 'vllm' ]; then \
pip install -U ninja packaging --no-cache-dir && \
build-vllm-cuda: vllm-cuda git clone https://github.com/Narsil/vllm.git vllm; \
cd vllm && git fetch && git checkout b5dfc61db88a81069e45b44f7cc99bd9e62a60fa fi
cd vllm && python setup.py build cd vllm && git fetch && git checkout $(commit_cuda) && python setup.py build
install-vllm-cuda: build-vllm-cuda install-vllm-cuda: build-vllm-cuda
pip uninstall vllm -y || true cd vllm && git fetch && git checkout $(commit_cuda) && pip install -e .
cd vllm && python setup.py install
vllm-rocm: build-vllm-rocm:
# Clone vllm if [ ! -d 'vllm' ]; then \
pip install -U ninja packaging --no-cache-dir pip install -U ninja packaging --no-cache-dir && \
git clone https://github.com/fxmarty/rocm-vllm.git vllm git clone https://github.com/fxmarty/rocm-vllm.git vllm; \
fi
build-vllm-rocm: vllm-rocm cd vllm && git fetch && git checkout $(commit_rocm) && \
cd vllm && git fetch && git checkout ca6913b3c2ffacdcb7d15e914dc34adbc6c89479 PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
cd vllm && PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py install
install-vllm-rocm: build-vllm-rocm install-vllm-rocm: build-vllm-rocm
pip uninstall vllm -y || true cd vllm && git fetch && git checkout $(commit_rocm) && \
cd vllm && python setup.py install PYTORCH_ROCM_ARCH="gfx90a;gfx942" pip install -e .

20
server/marlin/COPYRIGHT Normal file
View File

@ -0,0 +1,20 @@
These kernels were vendored from VLLM. The Marlin kernels were developed
by Elias Frantar and extended by Neural Magic.
---
Copyright (C) Marlin.2024 Elias Frantar
Modified by Neural Magic
Copyright 2024 The vLLM team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@ -0,0 +1,44 @@
import torch
def gptq_marlin_gemm(
a: torch.Tensor,
b_q_weight: torch.Tensor,
b_scales: torch.Tensor,
g_idx: torch.Tensor,
perm: torch.Tensor,
workspace: torch.Tensor,
num_bits: int,
size_m: int,
size_n: int,
size_k: int,
is_k_full: bool,
) -> torch.Tensor:
"""
Matrix multiplication using Marlin kernels. This is an extension of
`marlin_gemm` that supports converted GPTQ kernels.
"""
...
def gptq_marlin_repack(
b_q_weight: torch.Tensor,
perm: torch.Tensor,
size_k: int,
size_n: int,
num_bits: int,
) -> torch.Tensor:
"""Repack GPTQ parameters for Marlin kernels."""
...
def marlin_gemm(
a: torch.Tensor,
b_q_weight: torch.Tensor,
b_scales: torch.Tensor,
workspace: torch.Tensor,
size_m: int,
size_n: int,
size_k: int,
) -> torch.Tensor:
"""
Matrix multiplication using Marlin kernels.
"""
...

View File

@ -0,0 +1,11 @@
#include <torch/extension.h>
#include "ext.hh"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("gptq_marlin_gemm", &gptq_marlin_gemm,
"Marlin gemm with GPTQ compatibility");
m.def("gptq_marlin_repack", &gptq_marlin_repack,
"Repack GPTQ parameters for Marlin");
m.def("marlin_gemm", &marlin_gemm, "Marlin gemm");
}

Some files were not shown because too many files have changed in this diff Show More