Merge branch 'main' into lora-internal

commit 59575fe62a
Authored by drbh on 2024-06-25 12:23:04 -04:00, committed by GitHub
GPG signature: no known key found for this signature in database (GPG Key ID: B5690EEEBB952194)
67 changed files with 1605 additions and 216 deletions

View File

@@ -156,6 +156,8 @@ jobs:
needs: build-and-push
runs-on: ["self-hosted", "${{ needs.build-and-push.outputs.runs_on }}", "multi-gpu"]
if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
+env:
+PYTEST_FLAGS: ${{ github.ref == 'refs/heads/main' && '--release' || '' }}
steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -178,6 +180,6 @@ jobs:
export DOCKER_VOLUME=/mnt/cache
export DOCKER_IMAGE=${{ needs.build-and-push.outputs.docker_image }}
export DOCKER_DEVICES=${{ needs.build-and-push.outputs.docker_devices }}
-export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+export HF_TOKEN=${{ secrets.HF_TOKEN }}
echo $DOCKER_IMAGE
-pytest -s -vv integration-tests
+pytest -s -vv integration-tests ${PYTEST_FLAGS}
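The added `PYTEST_FLAGS` expression appends `--release` only when the workflow runs on `main` and expands to an empty string otherwise, so the two resulting commands are:

```shell
# On refs/heads/main: release-marked tests are included
pytest -s -vv integration-tests --release

# On any other ref (e.g. a pull request): PYTEST_FLAGS is empty, release tests are skipped
pytest -s -vv integration-tests
```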

View File

@@ -22,5 +22,5 @@ jobs:
- name: Run tests
run: |
pip install pytest pytest-asyncio
-export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+export HF_TOKEN=${{ secrets.HF_TOKEN }}
make python-client-tests

View File

@@ -37,5 +37,5 @@ jobs:
export DOCKER_VOLUME=/mnt/cache
export DOCKER_IMAGE=${{ inputs.docker_image }}
export DOCKER_DEVICES=${{ inputs.docker_devices }}
-export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+export HF_TOKEN=${{ secrets.HF_TOKEN }}
pytest -s -vv integration-tests

View File

@@ -11,66 +11,24 @@ on:
- 'main'
jobs:
-start-runner:
-name: Start self-hosted EC2 runner
-runs-on: ubuntu-latest
-env:
-AWS_REGION: eu-central-1
-EC2_AMI_ID: ami-0ab09c07cfd194259
-EC2_INSTANCE_TYPE: g5.12xlarge
-EC2_SUBNET_ID: subnet-988fd9f2,subnet-6f56db13,subnet-6a039326
-EC2_SECURITY_GROUP: sg-072f92ae3082936c6
-outputs:
-label: ${{ steps.start-ec2-runner.outputs.label }}
-ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}
-steps:
-- name: Configure AWS credentials
-uses: aws-actions/configure-aws-credentials@v1
-with:
-aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
-aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
-aws-region: ${{ env.AWS_REGION }}
-- name: Start EC2 runner
-id: start-ec2-runner
-uses: philschmid/philschmid-ec2-github-runner@main
-with:
-mode: start
-github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
-ec2-image-id: ${{ env.EC2_AMI_ID }}
-ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }}
-subnet-id: ${{ env.EC2_SUBNET_ID }}
-security-group-id: ${{ env.EC2_SECURITY_GROUP }}
-aws-resource-tags: > # optional, requires additional permissions
-[
-{"Key": "Name", "Value": "ec2-tgi-github-runner"},
-{"Key": "GitHubRepository", "Value": "${{ github.repository }}"}
-]
load-tests:
concurrency:
group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
-needs: start-runner # required to start the main job when the runner is ready
-runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
+runs-on: [self-hosted, nvidia-gpu , multi-gpu, 4-a10, ci]
env:
DOCKER_VOLUME: /cache
steps:
- name: Checkout repository
uses: actions/checkout@v3
-- name: Prepare disks
-run: |
-sudo mkfs -t ext4 /dev/nvme1n1
-sudo mkdir ${{ env.DOCKER_VOLUME }}
-sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }}
- name: Install k6
run: |
curl https://github.com/grafana/k6/releases/download/v0.44.0/k6-v0.44.0-linux-amd64.tar.gz -L | tar xvz --strip-components 1
- name: Start starcoder
run: |
-docker run --name tgi-starcoder --rm --gpus all -p 3000:80 -v ${{ env.DOCKER_VOLUME }}:/data -e HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }} --pull always -d ghcr.io/huggingface/text-generation-inference:latest --model-id bigcode/starcoder --num-shard 2 --max-batch-total-tokens 32768
+docker run --name tgi-starcoder --rm --gpus all -p 3000:80 -v /mnt/cache:/data -e HF_TOKEN=${{ secrets.HF_TOKEN }} --pull always -d ghcr.io/huggingface/text-generation-inference:latest --model-id bigcode/starcoder --num-shard 2 --max-batch-total-tokens 32768
sleep 10
wget --timeout 10 --retry-on-http-error --waitretry=1 --tries=240 http://localhost:3000/health
@@ -82,27 +40,3 @@ jobs:
if: ${{ always() }}
run: |
docker stop tgi-starcoder || true
-stop-runner:
-name: Stop self-hosted EC2 runner
-needs:
-- start-runner
-- load-tests
-runs-on: ubuntu-latest
-env:
-AWS_REGION: eu-central-1
-if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs
-steps:
-- name: Configure AWS credentials
-uses: aws-actions/configure-aws-credentials@v1
-with:
-aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
-aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
-aws-region: ${{ env.AWS_REGION }}
-- name: Stop EC2 runner
-uses: philschmid/philschmid-ec2-github-runner@main
-with:
-mode: stop
-github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
-label: ${{ needs.start-runner.outputs.label }}
-ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}

View File

@@ -72,7 +72,7 @@ jobs:
- name: Run server tests
run: |
pip install pytest
-export HUGGING_FACE_HUB_TOKEN=${{ secrets.HF_TOKEN }}
+export HF_TOKEN=${{ secrets.HF_TOKEN }}
pytest -s -vv server/tests
- name: Pre-commit checks
run: |

View File

@@ -1,3 +1,5 @@
+ARG PLATFORM=xpu
FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
WORKDIR /usr/src
@@ -37,7 +39,8 @@ RUN cargo build --profile release-opt
# Text Generation Inference base image for Intel
-FROM intel/intel-extension-for-pytorch:2.1.30-xpu as base
+FROM intel/intel-extension-for-pytorch:2.1.30-xpu as xpu
USER root
# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it
@@ -49,7 +52,7 @@ RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dea
RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
| gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list
-RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build
+RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build pciutils
# Text Generation Inference base env
ENV HUGGINGFACE_HUB_CACHE=/data \
@@ -59,7 +62,7 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
WORKDIR /usr/src
RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl
-RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b group_rope origin/dev/gqa_rope
+RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b distributed origin/dev/distributed
# Install server
COPY proto proto
@@ -89,8 +92,84 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
# Install launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
-# Final image
-FROM base
+# Text Generation Inference base image for Intel-cpu
+FROM ubuntu:22.04 as cpu
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+curl \
+ca-certificates \
+make \
+g++ \
+git \
+wget \
+cmake
+ENV HUGGINGFACE_HUB_CACHE=/data \
+HF_HUB_ENABLE_HF_TRANSFER=1 \
+PORT=80
+ARG MAMBA_VERSION=23.1.0-1
+ARG PYTHON_VERSION='3.10.10'
+# Automatically set by buildx
+ARG TARGETPLATFORM
+ENV PATH /opt/conda/bin:$PATH
+# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
+# Install mamba
+# translating Docker's TARGETPLATFORM into mamba arches
+RUN case ${TARGETPLATFORM} in \
+"linux/arm64") MAMBA_ARCH=aarch64 ;; \
+*) MAMBA_ARCH=x86_64 ;; \
+esac && \
+curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
+RUN chmod +x ~/mambaforge.sh && \
+bash ~/mambaforge.sh -b -p /opt/conda && \
+rm ~/mambaforge.sh
+RUN conda install -c conda-forge gperftools mkl
+RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
+RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
+RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
+WORKDIR /usr/src
+RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout eda7a7c42df6f9a64e0de9c2b69304ee02f2c32a
+RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout ccl_torch_dev_0131
+RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install
+RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install .
+ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so:/opt/conda/lib/libiomp5.so
+ENV CCL_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
+ENV I_MPI_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
+ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric
+ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/lib
+ENV KMP_BLOCKTIME=1
+ENV KMP_TPAUSE=0
+ENV KMP_FORKJOIN_BARRIER_PATTERN=dist,dist
+ENV KMP_PLAIN_BARRIER_PATTERN=dist,dist
+ENV KMP_REDUCTION_BARRIER_PATTERN=dist,dist
+# Install server
+COPY proto proto
+COPY server server
+COPY server/Makefile server/Makefile
+RUN cd server && \
+make gen-server && \
+pip install -r requirements_intel.txt && \
+pip install ".[accelerate, peft, outlines]" --no-cache-dir
+# Install benchmarker
+COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
+# Install router
+COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
+# Install launcher
+COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
+FROM ${PLATFORM} as final
ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"]

View File

@@ -105,14 +105,14 @@ The Swagger UI is also available at: [https://huggingface.github.io/text-generat
### Using a private or gated model
-You have the option to utilize the `HUGGING_FACE_HUB_TOKEN` environment variable for configuring the token employed by
+You have the option to utilize the `HF_TOKEN` environment variable for configuring the token employed by
`text-generation-inference`. This allows you to gain access to protected resources.
For example, if you want to serve the gated Llama V2 model variants:
1. Go to https://huggingface.co/settings/tokens
2. Copy your cli READ token
-3. Export `HUGGING_FACE_HUB_TOKEN=<your cli READ token>`
+3. Export `HF_TOKEN=<your cli READ token>`
or with Docker:
@@ -121,7 +121,7 @@ model=meta-llama/Llama-2-7b-chat-hf
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
token=<your cli READ token>
-docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0 --model-id $model
+docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0 --model-id $model
```
### A note on Shared Memory (shm)
@@ -153,7 +153,8 @@ this will impact performance.
### Distributed Tracing
`text-generation-inference` is instrumented with distributed tracing using OpenTelemetry. You can use this feature
-by setting the address to an OTLP collector with the `--otlp-endpoint` argument.
+by setting the address to an OTLP collector with the `--otlp-endpoint` argument. The default service name can be
+overridden with the `--otlp-service-name` argument
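For illustration, a minimal sketch of launching with both tracing flags; the collector address, service name value, and model id are placeholders, not taken from this diff:

```shell
text-generation-launcher --model-id bigcode/starcoder \
    --otlp-endpoint http://otel-collector:4317 \
    --otlp-service-name tgi-staging-router
```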
### Architecture

View File

@@ -147,7 +147,9 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing::info!("Downloading tokenizer");
// Parse Huggingface hub token
-let auth_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok();
+let auth_token = std::env::var("HF_TOKEN")
+    .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
+    .ok();
// Download and instantiate tokenizer
// We need to download it outside of the Tokio runtime

View File

@@ -1,5 +1,5 @@
from enum import Enum
-from pydantic import BaseModel, field_validator
+from pydantic import BaseModel, field_validator, ConfigDict
from typing import Optional, List, Union, Any
from text_generation.errors import ValidationError
@@ -452,5 +452,9 @@ class StreamResponse(BaseModel):
# Inference API currently deployed model
class DeployedModel(BaseModel):
+# Disable warning for use of `model_` prefix in `model_id`. Be mindful about adding members
+# with model_ prefixes, since this disables guardrails for colliding fields:
+# https://github.com/pydantic/pydantic/issues/9177
+model_config = ConfigDict(protected_namespaces=())
model_id: str
sha: str

View File

@@ -70,6 +70,8 @@ Options:
[env: JSON_OUTPUT=]
--otlp-endpoint <OTLP_ENDPOINT>
[env: OTLP_ENDPOINT=]
+--otlp-service-name <OTLP_SERVICE_NAME>
+[env: OTLP_SERVICE_NAME=]
--cors-allow-origin <CORS_ALLOW_ORIGIN>
[env: CORS_ALLOW_ORIGIN=]
--ngrok
@@ -138,6 +140,8 @@ Serve's command line parameters on the TGI repository are these:
│ --logger-level TEXT [default: INFO] │
│ --json-output --no-json-output [default: no-json-output] │
│ --otlp-endpoint TEXT [default: None] │
+│ --otlp-service-name TEXT [default: │
+│ text-generation-inference...│
│ --help Show this message and exit. │
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────╯
```

View File

@@ -2,13 +2,13 @@
If the model you wish to serve is behind gated access or the model repository on Hugging Face Hub is private, and you have access to the model, you can provide your Hugging Face Hub access token. You can generate and copy a read token from [Hugging Face Hub tokens page](https://huggingface.co/settings/tokens)
-If you're using the CLI, set the `HUGGING_FACE_HUB_TOKEN` environment variable. For example:
+If you're using the CLI, set the `HF_TOKEN` environment variable. For example:
```
-export HUGGING_FACE_HUB_TOKEN=<YOUR READ TOKEN>
+export HF_TOKEN=<YOUR READ TOKEN>
```
-If you would like to do it through Docker, you can provide your token by specifying `HUGGING_FACE_HUB_TOKEN` as shown below.
+If you would like to do it through Docker, you can provide your token by specifying `HF_TOKEN` as shown below.
```bash
model=meta-llama/Llama-2-7b-chat-hf
@@ -17,7 +17,7 @@ token=<your READ token>
docker run --gpus all \
--shm-size 1g \
--e HUGGING_FACE_HUB_TOKEN=$token \
+-e HF_TOKEN=$token \
-p 8080:80 \
-v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0.4 \
--model-id $model

View File

@@ -336,6 +336,13 @@ Options:
--otlp-endpoint <OTLP_ENDPOINT>
[env: OTLP_ENDPOINT=]
+```
+## OTLP_SERVICE_NAME
+```shell
+--otlp-service-name <OTLP_SERVICE_NAME>
+[env: OTLP_SERVICE_NAME=]
+[default: text-generation-inference.router]
```
## CORS_ALLOW_ORIGIN
```shell

View File

@@ -1,42 +1,62 @@
-import sys
-import subprocess
-import contextlib
-import pytest
import asyncio
-import os
+import contextlib
-import docker
import json
import math
+import os
+import random
+import re
import shutil
+import subprocess
+import sys
import tempfile
import time
-import random
+from typing import Dict, List, Optional
-from docker.errors import NotFound
+import docker
-from typing import Optional, List, Dict
+import pytest
-from syrupy.extensions.json import JSONSnapshotExtension
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
+from docker.errors import NotFound
+from syrupy.extensions.json import JSONSnapshotExtension
from text_generation import AsyncClient
from text_generation.types import (
-Response,
-Details,
-InputToken,
-Token,
BestOfSequence,
-Grammar,
ChatComplete,
ChatCompletionChunk,
ChatCompletionComplete,
Completion,
+Details,
+Grammar,
+InputToken,
+Response,
+Token,
)
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
-HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
+HF_TOKEN = os.getenv("HF_TOKEN", None)
DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")
DOCKER_DEVICES = os.getenv("DOCKER_DEVICES")
+def pytest_addoption(parser):
+    parser.addoption(
+        "--release", action="store_true", default=False, help="run release tests"
+    )
+def pytest_configure(config):
+    config.addinivalue_line("markers", "release: mark test as a release-only test")
+def pytest_collection_modifyitems(config, items):
+    if config.getoption("--release"):
+        # --release given in cli: do not skip release tests
+        return
+    skip_release = pytest.mark.skip(reason="need --release option to run")
+    for item in items:
+        if "release" in item.keywords:
+            item.add_marker(skip_release)
class ResponseComparator(JSONSnapshotExtension):
rtol = 0.2
ignore_logprob = False
@@ -447,8 +467,8 @@ def launcher(event_loop):
if not use_flash_attention:
env["USE_FLASH_ATTENTION"] = "false"
-if HUGGING_FACE_HUB_TOKEN is not None:
-env["HUGGING_FACE_HUB_TOKEN"] = HUGGING_FACE_HUB_TOKEN
+if HF_TOKEN is not None:
+env["HF_TOKEN"] = HF_TOKEN
volumes = []
if DOCKER_VOLUME:
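The new hooks above gate tests marked `@pytest.mark.release` behind a `--release` flag, which CI passes on `main` via `PYTEST_FLAGS`. A quick local illustration (the `-m` marker filter is standard pytest, not something added by this commit):

```shell
# Without --release, tests marked @pytest.mark.release are collected but skipped
pytest -s -vv integration-tests

# With --release they run; combine with -m to run only the release-marked subset
pytest -s -vv integration-tests --release -m release
```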

View File

@@ -13,6 +13,7 @@ async def bloom_560(bloom_560_handle):
return bloom_560_handle.client
+@pytest.mark.release
@pytest.mark.asyncio
async def test_bloom_560m(bloom_560, response_snapshot):
response = await bloom_560.generate(
@@ -27,6 +28,7 @@ async def test_bloom_560m(bloom_560, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
async def test_bloom_560m_all_params(bloom_560, response_snapshot):
response = await bloom_560.generate(
@@ -49,6 +51,7 @@ async def test_bloom_560m_all_params(bloom_560, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
async def test_bloom_560m_load(bloom_560, generate_load, response_snapshot):
responses = await generate_load(

View File

@@ -13,6 +13,7 @@ async def bloom_560m_sharded(bloom_560m_sharded_handle):
return bloom_560m_sharded_handle.client
+@pytest.mark.release
@pytest.mark.asyncio
async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot):
response = await bloom_560m_sharded.generate(
@@ -27,6 +28,7 @@ async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
async def test_bloom_560m_sharded_load(
bloom_560m_sharded, generate_load, response_snapshot

View File

@@ -26,6 +26,7 @@ async def flash_llama_completion(flash_llama_completion_handle):
# method for it. Instead, we use the `requests` library to make the HTTP request directly.
+@pytest.mark.release
def test_flash_llama_completion_single_prompt(
flash_llama_completion, response_snapshot
):
@@ -46,6 +47,7 @@ def test_flash_llama_completion_single_prompt(
assert response == response_snapshot
+@pytest.mark.release
def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
response = requests.post(
f"{flash_llama_completion.base_url}/v1/completions",
@@ -68,6 +70,7 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn
assert response == response_snapshot
+@pytest.mark.release
async def test_flash_llama_completion_many_prompts_stream(
flash_llama_completion, response_snapshot
):

View File

@@ -17,6 +17,7 @@ async def flash_llama_awq(flash_llama_awq_handle):
return flash_llama_awq_handle.client
+@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
response = await flash_llama_awq.generate(
@@ -31,6 +32,7 @@ async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
response = await flash_llama_awq.generate(
@@ -52,6 +54,7 @@ async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_llama_awq_load(flash_llama_awq, generate_load, response_snapshot):
responses = await generate_load(

View File

@@ -17,6 +17,7 @@ async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded):
return flash_llama_awq_handle_sharded.client
+@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot):
response = await flash_llama_awq_sharded.generate(
@@ -31,6 +32,7 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapsho
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_llama_awq_load_sharded(
flash_llama_awq_sharded, generate_load, response_snapshot

View File

@@ -13,6 +13,7 @@ async def flash_falcon(flash_falcon_handle):
return flash_falcon_handle.client
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_falcon(flash_falcon, response_snapshot):
@@ -26,6 +27,7 @@ async def test_flash_falcon(flash_falcon, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_falcon_all_params(flash_falcon, response_snapshot):
@@ -49,6 +51,7 @@ async def test_flash_falcon_all_params(flash_falcon, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_falcon_load(flash_falcon, generate_load, response_snapshot):

View File

@@ -13,6 +13,7 @@ async def flash_gemma(flash_gemma_handle):
return flash_gemma_handle.client
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma(flash_gemma, response_snapshot):
@@ -24,6 +25,7 @@ async def test_flash_gemma(flash_gemma, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
@@ -47,6 +49,7 @@ async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_load(flash_gemma, generate_load, response_snapshot):

View File

@@ -13,6 +13,7 @@ async def flash_gemma_gptq(flash_gemma_gptq_handle):
return flash_gemma_gptq_handle.client
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapshot):
@@ -24,6 +25,7 @@ async def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapsh
assert response == ignore_logprob_response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_gptq_all_params(
@@ -49,6 +51,7 @@ async def test_flash_gemma_gptq_all_params(
assert response == ignore_logprob_response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_gptq_load(

View File

@@ -13,6 +13,7 @@ async def flash_gpt2(flash_gpt2_handle):
return flash_gpt2_handle.client
+@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_gpt2(flash_gpt2, response_snapshot):
response = await flash_gpt2.generate(
@@ -25,6 +26,7 @@ async def test_flash_gpt2(flash_gpt2, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_gpt2_load(flash_gpt2, generate_load, response_snapshot):
responses = await generate_load(

View File

@@ -21,6 +21,7 @@ async def flash_llama_exl2(flash_llama_exl2_handle):
return flash_llama_exl2_handle.client
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot):
@@ -32,6 +33,7 @@ async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapsh
assert response == ignore_logprob_response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2_all_params(
@@ -58,6 +60,7 @@ async def test_flash_llama_exl2_all_params(
assert response == ignore_logprob_response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2_load(

View File

@@ -13,6 +13,7 @@ async def flash_llama_gptq(flash_llama_gptq_handle):
return flash_llama_gptq_handle.client
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot):
@@ -24,6 +25,7 @@ async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot):
@@ -46,6 +48,7 @@ async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_load(

View File

@@ -15,6 +15,7 @@ async def flash_llama_gptq_marlin(flash_llama_gptq_marlin_handle):
return flash_llama_gptq_marlin_handle.client
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapshot):
@@ -26,6 +27,7 @@ async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapsho
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_marlin_all_params(
@@ -50,6 +52,7 @@ async def test_flash_llama_gptq_marlin_all_params(
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_marlin_load(

View File

@@ -15,6 +15,7 @@ async def flash_llama_marlin(flash_llama_marlin_handle):
return flash_llama_marlin_handle.client
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):
@@ -26,6 +27,7 @@ async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin_all_params(flash_llama_marlin, response_snapshot):
@@ -48,6 +50,7 @@ async def test_flash_llama_marlin_all_params(flash_llama_marlin, response_snapsh
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin_load(

View File

@@ -13,6 +13,7 @@ async def flash_neox(flash_neox_handle):
return flash_neox_handle.client
+@pytest.mark.release
@pytest.mark.skip
@pytest.mark.asyncio
async def test_flash_neox(flash_neox, response_snapshot):
@@ -26,6 +27,7 @@ async def test_flash_neox(flash_neox, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.skip
@pytest.mark.asyncio
async def test_flash_neox_load(flash_neox, generate_load, response_snapshot):

View File

@@ -13,6 +13,7 @@ async def flash_neox_sharded(flash_neox_sharded_handle):
return flash_neox_sharded_handle.client
+@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_neox(flash_neox_sharded, response_snapshot):
response = await flash_neox_sharded.generate(
@@ -25,6 +26,7 @@ async def test_flash_neox(flash_neox_sharded, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_neox_load(flash_neox_sharded, generate_load, response_snapshot):
responses = await generate_load(

View File

@@ -34,6 +34,7 @@ def get_cow_beach():
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
@@ -45,6 +46,7 @@ async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_pali_gemma_two_images(flash_pali_gemma, response_snapshot):

View File

@@ -13,6 +13,7 @@ async def flash_phi(flash_phi_handle):
return flash_phi_handle.client
+@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_phi(flash_phi, response_snapshot):
response = await flash_phi.generate(
@@ -24,6 +25,7 @@ async def test_flash_phi(flash_phi, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_phi_all_params(flash_phi, response_snapshot):
response = await flash_phi.generate(
@@ -47,6 +49,7 @@ async def test_flash_phi_all_params(flash_phi, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_phi_load(flash_phi, generate_load, response_snapshot):
responses = await generate_load(flash_phi, "Test request", max_new_tokens=10, n=4)

View File

@@ -13,6 +13,7 @@ async def flash_qwen2(flash_qwen2_handle):
return flash_qwen2_handle.client
+@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_qwen2(flash_qwen2, response_snapshot):
response = await flash_qwen2.generate(
@@ -24,6 +25,7 @@ async def test_flash_qwen2(flash_qwen2, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_qwen2_all_params(flash_qwen2, response_snapshot):
response = await flash_qwen2.generate(
@@ -46,6 +48,7 @@ async def test_flash_qwen2_all_params(flash_qwen2, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_qwen2_load(flash_qwen2, generate_load, response_snapshot):
responses = await generate_load(flash_qwen2, "Test request", max_new_tokens=10, n=4)

View File

@@ -13,6 +13,7 @@ async def flash_santacoder(flash_santacoder_handle):
return flash_santacoder_handle.client
+@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_santacoder(flash_santacoder, response_snapshot):
response = await flash_santacoder.generate(
@@ -23,6 +24,7 @@ async def test_flash_santacoder(flash_santacoder, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_santacoder_load(
flash_santacoder, generate_load, response_snapshot

View File

@@ -13,6 +13,7 @@ async def flash_starcoder(flash_starcoder_handle):
return flash_starcoder_handle.client
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder(flash_starcoder, response_snapshot):
@@ -24,6 +25,7 @@ async def test_flash_starcoder(flash_starcoder, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot):
@@ -40,6 +42,7 @@ async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder_load(flash_starcoder, generate_load, response_snapshot):

View File

@@ -13,6 +13,7 @@ async def flash_starcoder2(flash_starcoder2_handle):
return flash_starcoder2_handle.client
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder2(flash_starcoder2, response_snapshot):
@@ -24,6 +25,7 @@ async def test_flash_starcoder2(flash_starcoder2, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapshot):
@@ -40,6 +42,7 @@ async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapsh
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder2_load(

View File

@@ -13,6 +13,7 @@ async def flash_starcoder_gptq(flash_starcoder_gptq_handle):
return flash_starcoder_gptq_handle.client
+@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snapshot):
response = await flash_starcoder_gptq.generate(
@@ -24,6 +25,7 @@ async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snap
assert response == generous_response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_starcoder_gptq_default_params(
flash_starcoder_gptq, generous_response_snapshot
@@ -40,6 +42,7 @@ async def test_flash_starcoder_gptq_default_params(
assert response == generous_response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_starcoder_gptq_load(
flash_starcoder_gptq, generate_load, generous_response_snapshot

View File

@@ -21,6 +21,7 @@ async def non_flash_llama_grammar(non_flash_llama_grammar_handle):
return non_flash_llama_grammar_handle.client
+@pytest.mark.release
@pytest.mark.skip
@pytest.mark.asyncio
async def test_non_flash_llama_grammar_json(non_flash_llama_grammar, response_snapshot):

View File

@@ -22,6 +22,7 @@ async def llama_grammar(llama_grammar_handle):
return llama_grammar_handle.client
+@pytest.mark.release
@pytest.mark.asyncio
async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot):
@@ -62,6 +63,7 @@ async def test_grammar_response_format_llama_json(llama_grammar, response_snapsh
assert chat_completion == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
async def test_grammar_response_format_llama_error_if_tools_not_installed(
llama_grammar,

View File

@@ -45,6 +45,7 @@ async def test_idefics(idefics, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_idefics_two_images(idefics, response_snapshot):
@@ -60,6 +61,7 @@ async def test_idefics_two_images(idefics, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
async def test_idefics_load(idefics, generate_load, response_snapshot):
chicken = get_chicken()

View File

@@ -26,6 +26,7 @@ async def flash_llava_next(flash_llava_next_handle):
return flash_llava_next_handle.client
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llava_next_simple(flash_llava_next, response_snapshot):
@@ -41,6 +42,7 @@ async def test_flash_llava_next_simple(flash_llava_next, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot):
@@ -64,6 +66,7 @@ async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llava_next_load(

View File

@@ -13,6 +13,7 @@ async def fused_kernel_mamba(fused_kernel_mamba_handle):
return fused_kernel_mamba_handle.client
+@pytest.mark.release
@pytest.mark.asyncio
async def test_mamba(fused_kernel_mamba, response_snapshot):
response = await fused_kernel_mamba.generate(
@@ -24,6 +25,7 @@ async def test_mamba(fused_kernel_mamba, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
response = await fused_kernel_mamba.generate(
@@ -50,6 +52,7 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
async def test_mamba_load(
fused_kernel_mamba, generate_load, generous_response_snapshot

View File

@@ -13,6 +13,7 @@ async def mpt_sharded(mpt_sharded_handle):
return mpt_sharded_handle.client
+@pytest.mark.release
@pytest.mark.asyncio
async def test_mpt(mpt_sharded, response_snapshot):
response = await mpt_sharded.generate(
@@ -29,6 +30,7 @@ async def test_mpt(mpt_sharded, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
async def test_mpt_load(mpt_sharded, generate_load, response_snapshot):
responses = await generate_load(

View File

@@ -13,6 +13,7 @@ async def mt0_base(mt0_base_handle):
return mt0_base_handle.client
+@pytest.mark.release
@pytest.mark.asyncio
async def test_mt0_base(mt0_base, response_snapshot):
response = await mt0_base.generate(
@@ -27,6 +28,7 @@ async def test_mt0_base(mt0_base, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
async def test_mt0_base_all_params(mt0_base, response_snapshot):
response = await mt0_base.generate(
@@ -49,6 +51,7 @@ async def test_mt0_base_all_params(mt0_base, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
async def test_mt0_base_load(mt0_base, generate_load, response_snapshot):
responses = await generate_load(

View File

@@ -15,6 +15,7 @@ async def neox(neox_handle):
return neox_handle.client
+@pytest.mark.release
@pytest.mark.skip
@pytest.mark.asyncio
async def test_neox(neox, response_snapshot):
@@ -28,6 +29,7 @@ async def test_neox(neox, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.skip
@pytest.mark.asyncio
async def test_neox_load(neox, generate_load, response_snapshot):

View File

@@ -15,6 +15,7 @@ async def neox_sharded(neox_sharded_handle):
return neox_sharded_handle.client
+@pytest.mark.release
@pytest.mark.skip
@pytest.mark.asyncio
async def test_neox(neox_sharded, response_snapshot):
@@ -28,6 +29,7 @@ async def test_neox(neox_sharded, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.skip
@pytest.mark.asyncio
async def test_neox_load(neox_sharded, generate_load, response_snapshot):

View File

@@ -13,6 +13,7 @@ async def t5_sharded(t5_sharded_handle):
return t5_sharded_handle.client
+@pytest.mark.release
@pytest.mark.asyncio
async def test_t5_sharded(t5_sharded, response_snapshot):
response = await t5_sharded.generate(
@@ -24,6 +25,7 @@ async def test_t5_sharded(t5_sharded, response_snapshot):
assert response == response_snapshot
+@pytest.mark.release
@pytest.mark.asyncio
async def test_t5_sharded_load(t5_sharded, generate_load, response_snapshot):
responses = await generate_load(

View File

@@ -413,6 +413,9 @@ struct Args {
#[clap(long, env)]
otlp_endpoint: Option<String>,
+#[clap(default_value = "text-generation-inference.router", long, env)]
+otlp_service_name: String,
#[clap(long, env)]
cors_allow_origin: Vec<String>,
#[clap(long, env)]
@@ -489,6 +492,7 @@ fn shard_manager(
max_input_tokens: usize,
lora_adapters: Option<String>,
otlp_endpoint: Option<String>,
+otlp_service_name: String,
log_level: LevelFilter,
status_sender: mpsc::Sender<ShardStatus>,
shutdown: Arc<AtomicBool>,
@@ -554,12 +558,16 @@ fn shard_manager(
(None, Some(factor)) => Some((RopeScaling::Linear, factor)),
};
-// OpenTelemetry
+// OpenTelemetry Endpoint
if let Some(otlp_endpoint) = otlp_endpoint {
shard_args.push("--otlp-endpoint".to_string());
shard_args.push(otlp_endpoint);
}
+// OpenTelemetry Service Name
+shard_args.push("--otlp-service-name".to_string());
+shard_args.push(otlp_service_name);
// In case we use sliding window, we may ignore the sliding in flash for some backends depending on the parameter.
shard_args.push("--max-input-tokens".to_string());
shard_args.push(max_input_tokens.to_string());
@@ -598,7 +606,7 @@ fn shard_manager(
// Parse Inference API token
if let Ok(api_token) = env::var("HF_API_TOKEN") {
-envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
+envs.push(("HF_TOKEN".into(), api_token.into()))
};
// Detect rope scaling
@@ -762,7 +770,10 @@ fn shutdown_shards(shutdown: Arc<AtomicBool>, shutdown_receiver: &mpsc::Receiver
fn num_cuda_devices() -> Option<usize> {
let devices = match env::var("CUDA_VISIBLE_DEVICES") {
Ok(devices) => devices,
-Err(_) => env::var("NVIDIA_VISIBLE_DEVICES").ok()?,
+Err(_) => match env::var("NVIDIA_VISIBLE_DEVICES") {
+    Ok(devices) => devices,
+    Err(_) => env::var("ZE_AFFINITY_MASK").ok()?,
+},
};
let n_devices = devices.split(',').count();
Some(n_devices)
@ -835,9 +846,9 @@ fn find_num_shards(
let num_shard = match (sharded, num_shard) { let num_shard = match (sharded, num_shard) {
(Some(true), None) => { (Some(true), None) => {
// try to default to the number of available GPUs // try to default to the number of available GPUs
tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES"); tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES/ZE_AFFINITY_MASK");
let n_devices = num_cuda_devices() let n_devices = num_cuda_devices()
.expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES are not set"); .expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES/ZE_AFFINITY_MASK are not set");
if n_devices <= 1 { if n_devices <= 1 {
return Err(LauncherError::NotEnoughCUDADevices(format!( return Err(LauncherError::NotEnoughCUDADevices(format!(
"`sharded` is true but only found {n_devices} CUDA devices" "`sharded` is true but only found {n_devices} CUDA devices"
@ -936,7 +947,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
// Parse Inference API token // Parse Inference API token
if let Ok(api_token) = env::var("HF_API_TOKEN") { if let Ok(api_token) = env::var("HF_API_TOKEN") {
envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) envs.push(("HF_TOKEN".into(), api_token.into()))
}; };
// If args.weights_cache_override is some, pass it to the download process // If args.weights_cache_override is some, pass it to the download process
@ -1046,6 +1057,7 @@ fn spawn_shards(
let shutdown = shutdown.clone(); let shutdown = shutdown.clone();
let shutdown_sender = shutdown_sender.clone(); let shutdown_sender = shutdown_sender.clone();
let otlp_endpoint = args.otlp_endpoint.clone(); let otlp_endpoint = args.otlp_endpoint.clone();
let otlp_service_name = args.otlp_service_name.clone();
let quantize = args.quantize; let quantize = args.quantize;
let speculate = args.speculate; let speculate = args.speculate;
let dtype = args.dtype; let dtype = args.dtype;
@ -1087,6 +1099,7 @@ fn spawn_shards(
max_input_tokens, max_input_tokens,
lora_adapters, lora_adapters,
otlp_endpoint, otlp_endpoint,
otlp_service_name,
max_log_level, max_log_level,
status_sender, status_sender,
shutdown, shutdown,
@ -1220,6 +1233,11 @@ fn spawn_webserver(
router_args.push(otlp_endpoint); router_args.push(otlp_endpoint);
} }
// OpenTelemetry
let otlp_service_name = args.otlp_service_name;
router_args.push("--otlp-service-name".to_string());
router_args.push(otlp_service_name);
// CORS origins // CORS origins
for origin in args.cors_allow_origin.into_iter() { for origin in args.cors_allow_origin.into_iter() {
router_args.push("--cors-allow-origin".to_string()); router_args.push("--cors-allow-origin".to_string());
@ -1240,7 +1258,7 @@ fn spawn_webserver(
// Parse Inference API token // Parse Inference API token
if let Ok(api_token) = env::var("HF_API_TOKEN") { if let Ok(api_token) = env::var("HF_API_TOKEN") {
envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) envs.push(("HF_TOKEN".into(), api_token.into()))
}; };
// Parse Compute type // Parse Compute type
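Besides the OpenTelemetry service-name plumbing and the HF_TOKEN rename, the launcher now also counts Intel devices exposed through ZE_AFFINITY_MASK when inferring the number of shards. A rough Python equivalent of the fallback chain, shown only to illustrate the logic of the Rust code above:

    import os
    from typing import Optional

    def num_visible_devices() -> Optional[int]:
        # Mirror of the launcher fallback: CUDA first, then NVIDIA, then the Intel XPU affinity mask.
        for var in ("CUDA_VISIBLE_DEVICES", "NVIDIA_VISIBLE_DEVICES", "ZE_AFFINITY_MASK"):
            devices = os.environ.get(var)
            if devices is not None:
                return len(devices.split(","))
        return None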


@ -576,7 +576,7 @@ impl ChatCompletion {
}; };
Self { Self {
id: String::new(), id: String::new(),
object: "text_completion".into(), object: "chat.completion".into(),
created, created,
model, model,
system_fingerprint, system_fingerprint,
@ -688,7 +688,7 @@ impl ChatCompletionChunk {
}; };
Self { Self {
id: String::new(), id: String::new(),
object: "text_completion".to_string(), object: "chat.completion.chunk".to_string(),
created, created,
model, model,
system_fingerprint, system_fingerprint,
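After this fix, non-streaming chat responses report object as chat.completion and streamed chunks report chat.completion.chunk, which is what OpenAI-compatible clients expect. A minimal client-side check, assuming a router running locally on its default port and serving the OpenAI-compatible /v1/chat/completions route:

    import requests

    resp = requests.post(
        "http://localhost:3000/v1/chat/completions",  # assumed local router address
        json={
            "model": "tgi",  # the model field is accepted for OpenAI compatibility
            "messages": [{"role": "user", "content": "Say hello"}],
            "stream": False,
        },
    )
    payload = resp.json()
    # This field used to be reported as "text_completion".
    assert payload["object"] == "chat.completion"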


@ -65,6 +65,8 @@ struct Args {
json_output: bool, json_output: bool,
#[clap(long, env)] #[clap(long, env)]
otlp_endpoint: Option<String>, otlp_endpoint: Option<String>,
#[clap(default_value = "text-generation-inference.router", long, env)]
otlp_service_name: String,
#[clap(long, env)] #[clap(long, env)]
cors_allow_origin: Option<Vec<String>>, cors_allow_origin: Option<Vec<String>>,
#[clap(long, env)] #[clap(long, env)]
@ -107,6 +109,7 @@ async fn main() -> Result<(), RouterError> {
validation_workers, validation_workers,
json_output, json_output,
otlp_endpoint, otlp_endpoint,
otlp_service_name,
cors_allow_origin, cors_allow_origin,
ngrok, ngrok,
ngrok_authtoken, ngrok_authtoken,
@ -117,7 +120,7 @@ async fn main() -> Result<(), RouterError> {
} = args; } = args;
// Launch Tokio runtime // Launch Tokio runtime
init_logging(otlp_endpoint, json_output); init_logging(otlp_endpoint, otlp_service_name, json_output);
// Validate args // Validate args
if max_input_tokens >= max_total_tokens { if max_input_tokens >= max_total_tokens {
@ -156,7 +159,9 @@ async fn main() -> Result<(), RouterError> {
}); });
// Parse Huggingface hub token // Parse Huggingface hub token
let authorization_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok(); let authorization_token = std::env::var("HF_TOKEN")
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
.ok();
// Tokenizer instance // Tokenizer instance
// This will only be used to validate payloads // This will only be used to validate payloads
@ -367,10 +372,11 @@ async fn main() -> Result<(), RouterError> {
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT: /// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
/// - otlp_endpoint is an optional URL to an Open Telemetry collector /// - otlp_endpoint is an optional URL to an Open Telemetry collector
/// - otlp_service_name service name to appear in APM
/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO) /// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
/// - LOG_FORMAT may be TEXT or JSON (default to TEXT) /// - LOG_FORMAT may be TEXT or JSON (default to TEXT)
/// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms) /// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms)
fn init_logging(otlp_endpoint: Option<String>, json_output: bool) { fn init_logging(otlp_endpoint: Option<String>, otlp_service_name: String, json_output: bool) {
let mut layers = Vec::new(); let mut layers = Vec::new();
// STDOUT/STDERR layer // STDOUT/STDERR layer
@ -401,7 +407,7 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
trace::config() trace::config()
.with_resource(Resource::new(vec![KeyValue::new( .with_resource(Resource::new(vec![KeyValue::new(
"service.name", "service.name",
"text-generation-inference.router", otlp_service_name,
)])) )]))
.with_sampler(Sampler::AlwaysOn), .with_sampler(Sampler::AlwaysOn),
) )
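The router hunks above thread otlp_service_name through to init_logging and switch the hub token lookup to prefer HF_TOKEN, keeping HUGGING_FACE_HUB_TOKEN only as a legacy fallback. The token lookup, mirrored in Python for anyone scripting around the server:

    import os

    # Prefer the new variable, fall back to the legacy one for compatibility.
    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    if token is None:
        print("No Hugging Face token configured; gated models cannot be downloaded.")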

File diff suppressed because it is too large


@ -42,6 +42,7 @@ def serve(
logger_level: str = "INFO", logger_level: str = "INFO",
json_output: bool = False, json_output: bool = False,
otlp_endpoint: Optional[str] = None, otlp_endpoint: Optional[str] = None,
otlp_service_name: str = "text-generation-inference.server",
max_input_tokens: Optional[int] = None, max_input_tokens: Optional[int] = None,
): ):
if sharded: if sharded:
@ -76,7 +77,7 @@ def serve(
# Setup OpenTelemetry distributed tracing # Setup OpenTelemetry distributed tracing
if otlp_endpoint is not None: if otlp_endpoint is not None:
setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint) setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint)
lora_adapter_ids = os.getenv("LORA_ADAPTERS", None) lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)


@ -7,7 +7,7 @@ if SYSTEM == "cuda":
from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif SYSTEM == "rocm": elif SYSTEM == "rocm":
from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif SYSTEM == "xpu": elif SYSTEM == "ipex":
from .xpu import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
else: else:
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention") raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")


@ -1,5 +1,6 @@
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex
import torch import torch
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
SUPPORTS_WINDOWING = False SUPPORTS_WINDOWING = False
@ -56,8 +57,6 @@ def paged_attention(
input_lengths: torch.Tensor, input_lengths: torch.Tensor,
max_s: int, max_s: int,
): ):
query = query.contiguous()
block_size = value_cache.shape[3]
return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
out, out,
query, query,
@ -67,7 +66,7 @@ def paged_attention(
softmax_scale, softmax_scale,
block_tables, block_tables,
input_lengths, input_lengths,
block_size, BLOCK_SIZE,
max_s, max_s,
None, None,
) )
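Passing the shared BLOCK_SIZE constant instead of value_cache.shape[3] matters because the CPU cache layout introduced later in this PR stores head_size, not the block size, in the last dimension. A small shape sketch with made-up dimensions showing why the old derivation would break on that layout:

    import torch

    num_blocks, num_heads, head_size, BLOCK_SIZE = 8, 4, 64, 16

    # CUDA/XPU-style value cache: the block size really is the fourth dimension.
    gpu_value_cache = torch.empty(num_blocks, num_heads, head_size, BLOCK_SIZE)
    assert gpu_value_cache.shape[3] == BLOCK_SIZE

    # CPU (ipex) layout used further down: the fourth dimension is head_size.
    cpu_value_cache = torch.empty(num_blocks, num_heads, BLOCK_SIZE, head_size)
    assert cpu_value_cache.shape[3] == head_size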


@ -82,18 +82,20 @@ elif SYSTEM == "rocm":
return super().forward(hidden_states), residual return super().forward(hidden_states), residual
elif SYSTEM == "xpu": elif SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex
class FastLayerNorm(nn.LayerNorm): class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None): def forward(self, hidden_states, residual=None):
res_out = hidden_states
out = ipex.llm.functional.add_layer_norm( out = ipex.llm.functional.add_layer_norm(
residual, hidden_states, self.weight, self.bias, self.eps, True residual,
hidden_states,
self.weight,
self.bias,
self.eps,
residual is not None,
) )
if residual is not None: return out, residual if residual is not None else hidden_states
res_out = residual
return out, res_out
class FastRMSNorm(nn.Module): class FastRMSNorm(nn.Module):
@ -109,19 +111,16 @@ class FastRMSNorm(nn.Module):
return cls(weight, eps) return cls(weight, eps)
def forward(self, hidden_states, residual=None): def forward(self, hidden_states, residual=None):
if SYSTEM == "xpu": if SYSTEM == "ipex":
residual_out = hidden_states
out = ipex.llm.functional.add_rms_norm( out = ipex.llm.functional.add_rms_norm(
residual, residual,
hidden_states, hidden_states,
self.weight, self.weight,
None, None,
self.variance_epsilon, self.variance_epsilon,
True, residual is not None,
) )
if residual is not None: return out, residual if residual is not None else hidden_states
residual_out = residual
return out, residual_out
elif hidden_states.shape[-1] > 8192: elif hidden_states.shape[-1] > 8192:
if residual is not None: if residual is not None:
hidden_states += residual hidden_states += residual
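The ipex fused norms now receive residual is not None as the add-residual flag and hand back the residual stream explicitly instead of staging it in a temporary. A plain-PyTorch sketch of the usual add-then-RMSNorm contract these fused kernels implement (assuming, as on the CUDA path, that the returned residual is the post-addition activation; the exact in-place behaviour of the ipex kernel is not visible in this diff):

    import torch

    def rms_norm_with_residual(hidden_states, weight, eps, residual=None):
        # Optionally fold in the residual first, then RMS-normalise.
        if residual is not None:
            hidden_states = hidden_states + residual
        new_residual = hidden_states  # what the next layer receives as its residual
        variance = hidden_states.float().pow(2).mean(-1, keepdim=True)
        normed = hidden_states * torch.rsqrt(variance + eps)
        return normed.to(weight.dtype) * weight, new_residual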


@ -9,7 +9,7 @@ if SYSTEM == "cuda":
import rotary_emb import rotary_emb
elif SYSTEM == "rocm": elif SYSTEM == "rocm":
from vllm._C import ops from vllm._C import ops
elif SYSTEM == "xpu": elif SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex
@ -69,7 +69,7 @@ class PositionRotaryEmbedding(nn.Module):
# Inplace operation, updating query and key. # Inplace operation, updating query and key.
ops.rotary_embedding(query, key, head_size, cos, sin, True) ops.rotary_embedding(query, key, head_size, cos, sin, True)
elif SYSTEM == "xpu": elif SYSTEM == "ipex":
ipex.llm.functional.rotary_embedding( ipex.llm.functional.rotary_embedding(
query, key, sin, cos, query.size(-1), True query, key, sin, cos, query.size(-1), True
) )


@ -3,6 +3,10 @@ from torch.nn import functional as F
from typing import Iterable, List from typing import Iterable, List
from text_generation_server.layers.linear import get_linear, FastLinear from text_generation_server.layers.linear import get_linear, FastLinear
from text_generation_server.layers.exl2 import Exl2Weight from text_generation_server.layers.exl2 import Exl2Weight
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex
class LayerConcat(torch.nn.Module): class LayerConcat(torch.nn.Module):
@ -96,10 +100,14 @@ class TensorParallelHead(SuperLayer):
local_out = gather_input.T local_out = gather_input.T
torch.mm(input, self.linear.weight.T, out=local_out) torch.mm(input, self.linear.weight.T, out=local_out)
if SYSTEM == "ipex":
torch.distributed.all_gather_into_tensor( ipex.distributed.all_gather_into_tensor(
world_out, gather_input, group=self.process_group world_out, gather_input, group=self.process_group
) )
else:
torch.distributed.all_gather_into_tensor(
world_out, gather_input, group=self.process_group
)
if input.shape[0] == 1: if input.shape[0] == 1:
return world_out return world_out
@ -109,7 +117,10 @@ class TensorParallelHead(SuperLayer):
world_output = [ world_output = [
torch.empty_like(output) for _ in range(self.process_group.size()) torch.empty_like(output) for _ in range(self.process_group.size())
] ]
torch.distributed.all_gather(world_output, output, group=self.process_group) if SYSTEM == "ipex":
ipex.distributed.all_gather(world_output, output, group=self.process_group)
else:
torch.distributed.all_gather(world_output, output, group=self.process_group)
world_output = torch.cat(world_output, dim=-1) world_output = torch.cat(world_output, dim=-1)
return world_output return world_output
@ -206,7 +217,10 @@ class TensorParallelRowLinear(SuperLayer):
def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor: def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
out = super().forward(input) out = super().forward(input)
if self.process_group.size() > 1 and reduce: if self.process_group.size() > 1 and reduce:
torch.distributed.all_reduce(out, group=self.process_group) if SYSTEM == "ipex":
ipex.distributed.all_reduce(out, group=self.process_group)
else:
torch.distributed.all_reduce(out, group=self.process_group)
return out return out
@ -243,5 +257,8 @@ class TensorParallelEmbedding(torch.nn.Module):
) )
out = torch.nn.functional.embedding(input, self.weight) out = torch.nn.functional.embedding(input, self.weight)
if self.reduce and self.process_group.size() > 1: if self.reduce and self.process_group.size() > 1:
torch.distributed.all_reduce(out, group=self.process_group) if SYSTEM == "ipex":
ipex.distributed.all_reduce(out, group=self.process_group)
else:
torch.distributed.all_reduce(out, group=self.process_group)
return out return out
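The same ipex-versus-torch branch now guards three collectives in this file (all_gather_into_tensor, all_gather, all_reduce). Purely as a sketch of the pattern, not code from this PR, the dispatch could be expressed once:

    import torch
    import torch.distributed

    try:
        import intel_extension_for_pytorch as ipex
        _HAS_IPEX = True
    except ImportError:
        _HAS_IPEX = False

    def all_reduce(tensor, group, use_ipex: bool):
        # Route the collective through ipex.distributed on the ipex backend,
        # otherwise use the regular torch.distributed implementation.
        if use_ipex and _HAS_IPEX:
            ipex.distributed.all_reduce(tensor, group=group)
        else:
            torch.distributed.all_reduce(tensor, group=group)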


@ -22,7 +22,7 @@ from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple, Any from typing import Optional, List, Tuple, Any
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM != "xpu": if SYSTEM != "ipex":
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (


@ -26,7 +26,7 @@ import numpy as np
from torch import nn from torch import nn
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM != "xpu": if SYSTEM != "ipex":
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig


@ -847,26 +847,43 @@ class FlashCausalLM(Model):
empty_cache() empty_cache()
element_size = torch.tensor([], dtype=dtype).element_size() element_size = torch.tensor([], dtype=dtype).element_size()
if SYSTEM == "xpu": if SYSTEM == "ipex" and device.type == "xpu":
x = 1 x = 1
else: else:
x = BLOCK_SIZE // element_size x = BLOCK_SIZE // element_size
self.kv_cache = [ if SYSTEM == "ipex" and device == torch.device("cpu"):
( self.kv_cache = [
torch.empty( (
(num_blocks, num_heads, head_size // x, BLOCK_SIZE, x), torch.empty(
dtype=dtype, (num_blocks, num_heads, BLOCK_SIZE, head_size),
device=device, dtype=dtype,
), device=device,
torch.empty( ),
(num_blocks, num_heads, head_size, BLOCK_SIZE), torch.empty(
dtype=dtype, (num_blocks, num_heads, BLOCK_SIZE, head_size),
device=device, dtype=dtype,
), device=device,
) ),
for _ in range(num_layers) )
] for _ in range(num_layers)
]
else:
self.kv_cache = [
(
torch.empty(
(num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
dtype=dtype,
device=device,
),
torch.empty(
(num_blocks, num_heads, head_size, BLOCK_SIZE),
dtype=dtype,
device=device,
),
)
for _ in range(num_layers)
]
def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
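The warmup now allocates a different KV-cache layout when the server runs on CPU through ipex. The shape selection, condensed into a helper purely as a sketch (the function name is ours, the shapes are the ones used above):

    import torch

    def kv_cache_shapes(num_blocks, num_heads, head_size, block_size, dtype, device, system):
        element_size = torch.tensor([], dtype=dtype).element_size()
        if system == "ipex" and device.type == "cpu":
            # CPU/ipex layout: keys and values share the same shape.
            key_shape = value_shape = (num_blocks, num_heads, block_size, head_size)
        else:
            # CUDA layout tiles keys by x = block_size // element_size; x is 1 on XPU.
            x = 1 if (system == "ipex" and device.type == "xpu") else block_size // element_size
            key_shape = (num_blocks, num_heads, head_size // x, block_size, x)
            value_shape = (num_blocks, num_heads, head_size, block_size)
        return key_shape, value_shape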


@ -34,9 +34,13 @@ class FlashGPT2(FlashCausalLM):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "xpu": elif SYSTEM == "ipex":
device = torch.device(f"xpu:{rank}") if hasattr(torch, "xpu") and torch.xpu.is_available():
dtype = torch.float16 if dtype is None else dtype device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else: else:
raise NotImplementedError("FlashGPT2 is only available on GPU") raise NotImplementedError("FlashGPT2 is only available on GPU")


@ -48,9 +48,13 @@ class FlashLlama(FlashCausalLM):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "xpu": elif SYSTEM == "ipex":
device = torch.device(f"xpu:{rank}") if hasattr(torch, "xpu") and torch.xpu.is_available():
dtype = torch.float16 if dtype is None else dtype device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else: else:
raise NotImplementedError("FlashLlama is only available on GPU") raise NotImplementedError("FlashLlama is only available on GPU")


@ -50,9 +50,13 @@ class BaseFlashMistral(FlashCausalLM):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "xpu": elif SYSTEM == "ipex":
device = torch.device(f"xpu:{rank}") if hasattr(torch, "xpu") and torch.xpu.is_available():
dtype = torch.float16 if dtype is None else dtype device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else: else:
raise NotImplementedError("FlashMistral is only available on GPU") raise NotImplementedError("FlashMistral is only available on GPU")


@ -33,9 +33,13 @@ class FlashNeoXSharded(FlashCausalLM):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "xpu": elif SYSTEM == "ipex":
device = torch.device(f"xpu:{rank}") if hasattr(torch, "xpu") and torch.xpu.is_available():
dtype = torch.float16 if dtype is None else dtype device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else: else:
raise NotImplementedError("FlashNeoX is only available on GPU") raise NotImplementedError("FlashNeoX is only available on GPU")


@ -34,9 +34,13 @@ class FlashRWSharded(FlashCausalLM):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "xpu": elif SYSTEM == "ipex":
device = torch.device(f"xpu:{rank}") if hasattr(torch, "xpu") and torch.xpu.is_available():
dtype = torch.float16 if dtype is None else dtype device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else: else:
raise NotImplementedError("FlashRW is only available on GPU") raise NotImplementedError("FlashRW is only available on GPU")


@ -37,9 +37,13 @@ class FlashSantacoderSharded(FlashCausalLM):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "xpu": elif SYSTEM == "ipex":
device = torch.device(f"xpu:{rank}") if hasattr(torch, "xpu") and torch.xpu.is_available():
dtype = torch.float16 if dtype is None else dtype device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else: else:
raise NotImplementedError("FlashSantacoderSharded is only available on GPU") raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
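The same device and dtype selection block is duplicated into each of the flash_* models above. Its logic, condensed into one helper purely as a sketch (the function name is ours):

    import torch
    from typing import Optional

    def pick_device_and_dtype(rank: int, system: str, dtype: Optional[torch.dtype], model_name: str):
        if torch.cuda.is_available():
            return torch.device(f"cuda:{rank}"), dtype or torch.float16
        if system == "ipex":
            if hasattr(torch, "xpu") and torch.xpu.is_available():
                return torch.device(f"xpu:{rank}"), dtype or torch.float16
            # ipex without an XPU falls back to CPU with bfloat16.
            return torch.device("cpu"), dtype or torch.bfloat16
        raise NotImplementedError(f"{model_name} is only available on GPU")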


@ -54,10 +54,8 @@ class UDSOpenTelemetryAioServerInterceptor(OpenTelemetryAioServerInterceptor):
) )
def setup_tracing(shard: int, otlp_endpoint: str): def setup_tracing(otlp_service_name: str, otlp_endpoint: str):
resource = Resource.create( resource = Resource.create(attributes={"service.name": otlp_service_name})
attributes={"service.name": f"text-generation-inference.server-{shard}"}
)
span_exporter = OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True) span_exporter = OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True)
span_processor = BatchSpanProcessor(span_exporter) span_processor = BatchSpanProcessor(span_exporter)
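setup_tracing now takes the service name directly instead of deriving text-generation-inference.server-{shard} from the shard rank, so the launcher, router and Python server can all report under one configurable name. For reference, the OpenTelemetry calls it builds on look roughly like this standalone sketch (the endpoint value is an assumption for illustration):

    from opentelemetry import trace
    from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
    from opentelemetry.sdk.resources import Resource
    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.sdk.trace.export import BatchSpanProcessor

    otlp_service_name = "text-generation-inference.server"  # default from the CLI above
    otlp_endpoint = "http://localhost:4317"  # assumed collector address

    resource = Resource.create(attributes={"service.name": otlp_service_name})
    provider = TracerProvider(resource=resource)
    provider.add_span_processor(
        BatchSpanProcessor(OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True))
    )
    trace.set_tracer_provider(provider)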


@ -3,6 +3,7 @@ import torch
from datetime import timedelta from datetime import timedelta
from loguru import logger from loguru import logger
from text_generation_server.utils.import_utils import SYSTEM
# Tensor Parallelism settings # Tensor Parallelism settings
RANK = int(os.getenv("RANK", "0")) RANK = int(os.getenv("RANK", "0"))
@ -57,14 +58,7 @@ def initialize_torch_distributed():
options.is_high_priority_stream = True options.is_high_priority_stream = True
options._timeout = timedelta(seconds=60) options._timeout = timedelta(seconds=60)
else: else:
try: backend = "gloo"
import oneccl_bindings_for_pytorch
backend = "ccl"
if os.getenv("CCL_WORKER_COUNT", None) is None:
os.environ["CCL_WORKER_COUNT"] = str(1)
except ImportError:
backend = "gloo"
options = None options = None
if WORLD_SIZE == 1: if WORLD_SIZE == 1:
@ -75,13 +69,24 @@ def initialize_torch_distributed():
if not torch.distributed.is_initialized(): if not torch.distributed.is_initialized():
# Call the init process. # Call the init process.
torch.distributed.init_process_group( if SYSTEM == "ipex":
backend=backend, import intel_extension_for_pytorch as ipex
world_size=WORLD_SIZE,
rank=RANK, ipex.distributed.init_process_group(
timeout=timedelta(seconds=60), backend="ccl",
pg_options=options, world_size=WORLD_SIZE,
) rank=RANK,
timeout=timedelta(seconds=60),
pg_options=options,
)
else:
torch.distributed.init_process_group(
backend=backend,
world_size=WORLD_SIZE,
rank=RANK,
timeout=timedelta(seconds=60),
pg_options=options,
)
else: else:
logger.warning("torch.distributed is already initialized.") logger.warning("torch.distributed is already initialized.")


@ -1,14 +1,14 @@
import torch import torch
from loguru import logger from loguru import logger
import subprocess
def is_xpu_available(): def is_ipex_available():
try: try:
import intel_extension_for_pytorch import intel_extension_for_pytorch
except ImportError: except ImportError:
return False return False
return True
return hasattr(torch, "xpu") and torch.xpu.is_available()
def get_cuda_free_memory(device, memory_fraction): def get_cuda_free_memory(device, memory_fraction):
@ -19,11 +19,28 @@ def get_cuda_free_memory(device, memory_fraction):
def get_xpu_free_memory(device, memory_fraction): def get_xpu_free_memory(device, memory_fraction):
total_gpu_memory = torch.xpu.get_device_properties(device).total_memory total_memory = torch.xpu.get_device_properties(device).total_memory
free_memory = int(total_gpu_memory * 0.5) device_id = device.index
query = f"xpu-smi dump -d {device_id} -m 18 -n 1"
output = subprocess.check_output(query.split()).decode("utf-8").split("\n")
used_memory = float(output[1].split(",")[-1]) * 1024 * 1024
free_memory = int(total_memory * 0.95 - used_memory)
return free_memory return free_memory
def get_cpu_free_memory(device, memory_fraction):
import psutil
from text_generation_server.utils.dist import WORLD_SIZE
mem = psutil.virtual_memory()
free_memory = int(mem.available * 0.95 / WORLD_SIZE)
return free_memory
def noop(*args, **kwargs):
pass
SYSTEM = None SYSTEM = None
if torch.version.hip is not None: if torch.version.hip is not None:
SYSTEM = "rocm" SYSTEM = "rocm"
@ -35,18 +52,20 @@ elif torch.version.cuda is not None and torch.cuda.is_available():
empty_cache = torch.cuda.empty_cache empty_cache = torch.cuda.empty_cache
synchronize = torch.cuda.synchronize synchronize = torch.cuda.synchronize
get_free_memory = get_cuda_free_memory get_free_memory = get_cuda_free_memory
elif is_xpu_available(): elif is_ipex_available():
SYSTEM = "xpu" SYSTEM = "ipex"
empty_cache = torch.xpu.empty_cache if hasattr(torch, "xpu") and torch.xpu.is_available():
synchronize = torch.xpu.synchronize empty_cache = torch.xpu.empty_cache
get_free_memory = get_xpu_free_memory synchronize = torch.xpu.synchronize
get_free_memory = get_xpu_free_memory
else:
empty_cache = noop
synchronize = noop
get_free_memory = get_cpu_free_memory
else: else:
SYSTEM = "cpu" SYSTEM = "cpu"
def noop(*args, **kwargs):
pass
empty_cache = noop empty_cache = noop
synchronize = noop synchronize = noop
get_free_memory = noop get_free_memory = get_cpu_free_memory
logger.info(f"Detected system {SYSTEM}") logger.info(f"Detected system {SYSTEM}")