Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-11 20:34:54 +00:00

Merge branch 'main' into lora-internal
Commit 59575fe62a
.github/workflows/build.yaml (vendored, 6 changes)

@@ -156,6 +156,8 @@ jobs:
     needs: build-and-push
     runs-on: ["self-hosted", "${{ needs.build-and-push.outputs.runs_on }}", "multi-gpu"]
     if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
+    env:
+      PYTEST_FLAGS: ${{ github.ref == 'refs/heads/main' && '--release' || '' }}
     steps:
       - name: Checkout repository
         uses: actions/checkout@v4
@@ -178,6 +180,6 @@ jobs:
           export DOCKER_VOLUME=/mnt/cache
           export DOCKER_IMAGE=${{ needs.build-and-push.outputs.docker_image }}
           export DOCKER_DEVICES=${{ needs.build-and-push.outputs.docker_devices }}
-          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+          export HF_TOKEN=${{ secrets.HF_TOKEN }}
          echo $DOCKER_IMAGE
-          pytest -s -vv integration-tests
+          pytest -s -vv integration-tests ${PYTEST_FLAGS}
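The new `PYTEST_FLAGS` expression resolves to `--release` only for pushes to `main`, so the integration-test step above effectively runs one of two commands. A sketch of the resulting invocations (the `--release` option itself is defined in the `conftest.py` change later in this diff):

```bash
# On pushes to main: PYTEST_FLAGS expands to '--release', so release-marked tests run too.
pytest -s -vv integration-tests --release

# On any other ref: PYTEST_FLAGS is empty and release-marked tests are skipped.
pytest -s -vv integration-tests
```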
.github/workflows/client-tests.yaml (vendored, 2 changes)

@@ -22,5 +22,5 @@ jobs:
       - name: Run tests
         run: |
           pip install pytest pytest-asyncio
-          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+          export HF_TOKEN=${{ secrets.HF_TOKEN }}
           make python-client-tests
.github/workflows/integration_tests.yaml (vendored, 2 changes)

@@ -37,5 +37,5 @@ jobs:
           export DOCKER_VOLUME=/mnt/cache
           export DOCKER_IMAGE=${{ inputs.docker_image }}
           export DOCKER_DEVICES=${{ inputs.docker_devices }}
-          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+          export HF_TOKEN=${{ secrets.HF_TOKEN }}
           pytest -s -vv integration-tests
.github/workflows/load_test.yaml (vendored, 70 changes)

@@ -11,66 +11,24 @@ on:
       - 'main'

 jobs:
-  start-runner:
-    name: Start self-hosted EC2 runner
-    runs-on: ubuntu-latest
-    env:
-      AWS_REGION: eu-central-1
-      EC2_AMI_ID: ami-0ab09c07cfd194259
-      EC2_INSTANCE_TYPE: g5.12xlarge
-      EC2_SUBNET_ID: subnet-988fd9f2,subnet-6f56db13,subnet-6a039326
-      EC2_SECURITY_GROUP: sg-072f92ae3082936c6
-    outputs:
-      label: ${{ steps.start-ec2-runner.outputs.label }}
-      ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}
-    steps:
-      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@v1
-        with:
-          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
-          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
-          aws-region: ${{ env.AWS_REGION }}
-      - name: Start EC2 runner
-        id: start-ec2-runner
-        uses: philschmid/philschmid-ec2-github-runner@main
-        with:
-          mode: start
-          github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
-          ec2-image-id: ${{ env.EC2_AMI_ID }}
-          ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }}
-          subnet-id: ${{ env.EC2_SUBNET_ID }}
-          security-group-id: ${{ env.EC2_SECURITY_GROUP }}
-          aws-resource-tags: > # optional, requires additional permissions
-            [
-              {"Key": "Name", "Value": "ec2-tgi-github-runner"},
-              {"Key": "GitHubRepository", "Value": "${{ github.repository }}"}
-            ]
-
   load-tests:
     concurrency:
       group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
       cancel-in-progress: true
-    needs: start-runner # required to start the main job when the runner is ready
-    runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
+    runs-on: [self-hosted, nvidia-gpu , multi-gpu, 4-a10, ci]
     env:
       DOCKER_VOLUME: /cache
     steps:
       - name: Checkout repository
         uses: actions/checkout@v3

-      - name: Prepare disks
-        run: |
-          sudo mkfs -t ext4 /dev/nvme1n1
-          sudo mkdir ${{ env.DOCKER_VOLUME }}
-          sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }}
-
       - name: Install k6
         run: |
           curl https://github.com/grafana/k6/releases/download/v0.44.0/k6-v0.44.0-linux-amd64.tar.gz -L | tar xvz --strip-components 1

       - name: Start starcoder
         run: |
-          docker run --name tgi-starcoder --rm --gpus all -p 3000:80 -v ${{ env.DOCKER_VOLUME }}:/data -e HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }} --pull always -d ghcr.io/huggingface/text-generation-inference:latest --model-id bigcode/starcoder --num-shard 2 --max-batch-total-tokens 32768
+          docker run --name tgi-starcoder --rm --gpus all -p 3000:80 -v /mnt/cache:/data -e HF_TOKEN=${{ secrets.HF_TOKEN }} --pull always -d ghcr.io/huggingface/text-generation-inference:latest --model-id bigcode/starcoder --num-shard 2 --max-batch-total-tokens 32768
           sleep 10
           wget --timeout 10 --retry-on-http-error --waitretry=1 --tries=240 http://localhost:3000/health

@@ -82,27 +40,3 @@ jobs:
         if: ${{ always() }}
         run: |
           docker stop tgi-starcoder || true
-
-  stop-runner:
-    name: Stop self-hosted EC2 runner
-    needs:
-      - start-runner
-      - load-tests
-    runs-on: ubuntu-latest
-    env:
-      AWS_REGION: eu-central-1
-    if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs
-    steps:
-      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@v1
-        with:
-          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
-          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
-          aws-region: ${{ env.AWS_REGION }}
-      - name: Stop EC2 runner
-        uses: philschmid/philschmid-ec2-github-runner@main
-        with:
-          mode: stop
-          github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
-          label: ${{ needs.start-runner.outputs.label }}
-          ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}
.github/workflows/tests.yaml (vendored, 2 changes)

@@ -72,7 +72,7 @@ jobs:
       - name: Run server tests
         run: |
           pip install pytest
-          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HF_TOKEN }}
+          export HF_TOKEN=${{ secrets.HF_TOKEN }}
           pytest -s -vv server/tests
       - name: Pre-commit checks
         run: |
@@ -1,3 +1,5 @@
+ARG PLATFORM=xpu
+
 FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
 WORKDIR /usr/src

@@ -37,7 +39,8 @@ RUN cargo build --profile release-opt


 # Text Generation Inference base image for Intel
-FROM intel/intel-extension-for-pytorch:2.1.30-xpu as base
+
+FROM intel/intel-extension-for-pytorch:2.1.30-xpu as xpu

 USER root
 # libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it

@@ -49,7 +52,7 @@ RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dea
 RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
 | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list

-RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build
+RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build pciutils

 # Text Generation Inference base env
 ENV HUGGINGFACE_HUB_CACHE=/data \

@@ -59,7 +62,7 @@ ENV HUGGINGFACE_HUB_CACHE=/data \

 WORKDIR /usr/src
 RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl
-RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b group_rope origin/dev/gqa_rope
+RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b distributed origin/dev/distributed

 # Install server
 COPY proto proto

@@ -89,8 +92,84 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
 # Install launcher
 COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher

-# Final image
-FROM base

+# Text Generation Inference base image for Intel-cpu
+FROM ubuntu:22.04 as cpu
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+    curl \
+    ca-certificates \
+    make \
+    g++ \
+    git \
+    wget \
+    cmake
+
+ENV HUGGINGFACE_HUB_CACHE=/data \
+    HF_HUB_ENABLE_HF_TRANSFER=1 \
+    PORT=80
+
+ARG MAMBA_VERSION=23.1.0-1
+ARG PYTHON_VERSION='3.10.10'
+# Automatically set by buildx
+ARG TARGETPLATFORM
+ENV PATH /opt/conda/bin:$PATH
+
+# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
+# Install mamba
+# translating Docker's TARGETPLATFORM into mamba arches
+RUN case ${TARGETPLATFORM} in \
+    "linux/arm64") MAMBA_ARCH=aarch64 ;; \
+    *) MAMBA_ARCH=x86_64 ;; \
+    esac && \
+    curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
+RUN chmod +x ~/mambaforge.sh && \
+    bash ~/mambaforge.sh -b -p /opt/conda && \
+    rm ~/mambaforge.sh
+
+RUN conda install -c conda-forge gperftools mkl
+
+RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
+RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
+RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
+
+WORKDIR /usr/src
+
+RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout eda7a7c42df6f9a64e0de9c2b69304ee02f2c32a
+
+RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout ccl_torch_dev_0131
+
+RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install
+
+RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install .
+
+ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so:/opt/conda/lib/libiomp5.so
+ENV CCL_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
+ENV I_MPI_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
+ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric
+ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/lib
+ENV KMP_BLOCKTIME=1
+ENV KMP_TPAUSE=0
+ENV KMP_FORKJOIN_BARRIER_PATTERN=dist,dist
+ENV KMP_PLAIN_BARRIER_PATTERN=dist,dist
+ENV KMP_REDUCTION_BARRIER_PATTERN=dist,dist
+
+# Install server
+COPY proto proto
+COPY server server
+COPY server/Makefile server/Makefile
+RUN cd server && \
+    make gen-server && \
+    pip install -r requirements_intel.txt && \
+    pip install ".[accelerate, peft, outlines]" --no-cache-dir
+
+# Install benchmarker
+COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
+# Install router
+COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
+# Install launcher
+COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
+
+FROM ${PLATFORM} as final
 ENTRYPOINT ["text-generation-launcher"]
 CMD ["--json-output"]
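With the new `ARG PLATFORM=xpu` default and the final `FROM ${PLATFORM} as final` stage, one Dockerfile can now produce either the XPU or the CPU image. A minimal build sketch (the Dockerfile path and image tags are assumptions for illustration, not taken from this diff):

```bash
# Default build selects the xpu stage (PLATFORM defaults to xpu).
docker build -f Dockerfile_intel -t tgi-intel:xpu .

# Override the build argument to get the CPU variant instead.
docker build -f Dockerfile_intel --build-arg PLATFORM=cpu -t tgi-intel:cpu .
```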
@@ -105,14 +105,14 @@ The Swagger UI is also available at: [https://huggingface.github.io/text-generat

 ### Using a private or gated model

-You have the option to utilize the `HUGGING_FACE_HUB_TOKEN` environment variable for configuring the token employed by
+You have the option to utilize the `HF_TOKEN` environment variable for configuring the token employed by
 `text-generation-inference`. This allows you to gain access to protected resources.

 For example, if you want to serve the gated Llama V2 model variants:

 1. Go to https://huggingface.co/settings/tokens
 2. Copy your cli READ token
-3. Export `HUGGING_FACE_HUB_TOKEN=<your cli READ token>`
+3. Export `HF_TOKEN=<your cli READ token>`

 or with Docker:

@@ -121,7 +121,7 @@ model=meta-llama/Llama-2-7b-chat-hf
 volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
 token=<your cli READ token>

-docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0 --model-id $model
+docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0 --model-id $model
 ```

 ### A note on Shared Memory (shm)

@@ -153,7 +153,8 @@ this will impact performance.
 ### Distributed Tracing

 `text-generation-inference` is instrumented with distributed tracing using OpenTelemetry. You can use this feature
-by setting the address to an OTLP collector with the `--otlp-endpoint` argument.
+by setting the address to an OTLP collector with the `--otlp-endpoint` argument. The default service name can be
+overridden with the `--otlp-service-name` argument

 ### Architecture
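The README change above documents the new `--otlp-service-name` flag. A usage sketch (the collector address and service name are illustrative placeholders, not values from this diff):

```bash
# Launch TGI with tracing exported to an OTLP collector and a custom service name.
text-generation-launcher --model-id bigcode/starcoder \
    --otlp-endpoint http://my-otlp-collector:4317 \
    --otlp-service-name tgi.starcoder.production
```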
@@ -147,7 +147,9 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
     tracing::info!("Downloading tokenizer");

     // Parse Huggingface hub token
-    let auth_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok();
+    let auth_token = std::env::var("HF_TOKEN")
+        .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
+        .ok();

     // Download and instantiate tokenizer
     // We need to download it outside of the Tokio runtime
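The Rust change keeps the legacy variable as a fallback, so existing deployments keep working. A sketch of the resulting precedence (token values are placeholders):

```bash
# HF_TOKEN is read first; the legacy variable is only consulted when HF_TOKEN is unset.
export HUGGING_FACE_HUB_TOKEN=<legacy token>   # still honored on its own
export HF_TOKEN=<new token>                    # wins when both are set
```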
@@ -1,5 +1,5 @@
 from enum import Enum
-from pydantic import BaseModel, field_validator
+from pydantic import BaseModel, field_validator, ConfigDict
 from typing import Optional, List, Union, Any

 from text_generation.errors import ValidationError

@@ -452,5 +452,9 @@ class StreamResponse(BaseModel):

 # Inference API currently deployed model
 class DeployedModel(BaseModel):
+    # Disable warning for use of `model_` prefix in `model_id`. Be mindful about adding members
+    # with model_ prefixes, since this disables guardrails for colliding fields:
+    # https://github.com/pydantic/pydantic/issues/9177
+    model_config = ConfigDict(protected_namespaces=())
     model_id: str
     sha: str
@@ -70,6 +70,8 @@ Options:
           [env: JSON_OUTPUT=]
   --otlp-endpoint <OTLP_ENDPOINT>
           [env: OTLP_ENDPOINT=]
+  --otlp-service-name <OTLP_SERVICE_NAME>
+          [env: OTLP_SERVICE_NAME=]
   --cors-allow-origin <CORS_ALLOW_ORIGIN>
           [env: CORS_ALLOW_ORIGIN=]
   --ngrok

@@ -138,6 +140,8 @@ Serve's command line parameters on the TGI repository are these:
 │ --logger-level TEXT [default: INFO] │
 │ --json-output --no-json-output [default: no-json-output] │
 │ --otlp-endpoint TEXT [default: None] │
+│ --otlp-service-name TEXT [default: │
+│ text-generation-inference...│
 │ --help Show this message and exit. │
 ╰──────────────────────────────────────────────────────────────────────────────────────────────────────────╯
 ```
@@ -2,13 +2,13 @@

 If the model you wish to serve is behind gated access or the model repository on Hugging Face Hub is private, and you have access to the model, you can provide your Hugging Face Hub access token. You can generate and copy a read token from [Hugging Face Hub tokens page](https://huggingface.co/settings/tokens)

-If you're using the CLI, set the `HUGGING_FACE_HUB_TOKEN` environment variable. For example:
+If you're using the CLI, set the `HF_TOKEN` environment variable. For example:

 ```
-export HUGGING_FACE_HUB_TOKEN=<YOUR READ TOKEN>
+export HF_TOKEN=<YOUR READ TOKEN>
 ```

-If you would like to do it through Docker, you can provide your token by specifying `HUGGING_FACE_HUB_TOKEN` as shown below.
+If you would like to do it through Docker, you can provide your token by specifying `HF_TOKEN` as shown below.

 ```bash
 model=meta-llama/Llama-2-7b-chat-hf

@@ -17,7 +17,7 @@ token=<your READ token>

 docker run --gpus all \
     --shm-size 1g \
-    -e HUGGING_FACE_HUB_TOKEN=$token \
+    -e HF_TOKEN=$token \
     -p 8080:80 \
     -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0.4 \
     --model-id $model
@@ -336,6 +336,13 @@ Options:
 --otlp-endpoint <OTLP_ENDPOINT>
     [env: OTLP_ENDPOINT=]

+```
+## OTLP_SERVICE_NAME
+```shell
+--otlp-service-name <OTLP_SERVICE_NAME>
+    [env: OTLP_SERVICE_NAME=]
+    [default: text-generation-inference.router]
+
 ```
 ## CORS_ALLOW_ORIGIN
 ```shell
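The launcher docs above show the flag and its environment variable; the same configuration can be passed through the environment. A sketch (collector address and service name are illustrative; the default shown is the one documented above):

```bash
# Equivalent to passing --otlp-endpoint / --otlp-service-name on the command line.
export OTLP_ENDPOINT=http://my-otlp-collector:4317
export OTLP_SERVICE_NAME=tgi.router.staging   # overrides the default text-generation-inference.router
```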
@@ -1,42 +1,62 @@
-import sys
-import subprocess
-import contextlib
-import pytest
 import asyncio
-import os
-import docker
+import contextlib
 import json
 import math
+import os
+import random
+import re
 import shutil
+import subprocess
+import sys
 import tempfile
 import time
-import random
+from typing import Dict, List, Optional

-from docker.errors import NotFound
-from typing import Optional, List, Dict
-from syrupy.extensions.json import JSONSnapshotExtension
+import docker
+import pytest
 from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
+from docker.errors import NotFound
+from syrupy.extensions.json import JSONSnapshotExtension
 from text_generation import AsyncClient
 from text_generation.types import (
-    Response,
-    Details,
-    InputToken,
-    Token,
     BestOfSequence,
-    Grammar,
     ChatComplete,
     ChatCompletionChunk,
     ChatCompletionComplete,
     Completion,
+    Details,
+    Grammar,
+    InputToken,
+    Response,
+    Token,
 )

 DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
-HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
+HF_TOKEN = os.getenv("HF_TOKEN", None)
 DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")
 DOCKER_DEVICES = os.getenv("DOCKER_DEVICES")


+def pytest_addoption(parser):
+    parser.addoption(
+        "--release", action="store_true", default=False, help="run release tests"
+    )
+
+
+def pytest_configure(config):
+    config.addinivalue_line("markers", "release: mark test as a release-only test")
+
+
+def pytest_collection_modifyitems(config, items):
+    if config.getoption("--release"):
+        # --release given in cli: do not skip release tests
+        return
+    skip_release = pytest.mark.skip(reason="need --release option to run")
+    for item in items:
+        if "release" in item.keywords:
+            item.add_marker(skip_release)
+
+
 class ResponseComparator(JSONSnapshotExtension):
     rtol = 0.2
     ignore_logprob = False

@@ -447,8 +467,8 @@ def launcher(event_loop):
         if not use_flash_attention:
             env["USE_FLASH_ATTENTION"] = "false"

-        if HUGGING_FACE_HUB_TOKEN is not None:
-            env["HUGGING_FACE_HUB_TOKEN"] = HUGGING_FACE_HUB_TOKEN
+        if HF_TOKEN is not None:
+            env["HF_TOKEN"] = HF_TOKEN

         volumes = []
         if DOCKER_VOLUME:
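With these conftest.py hooks in place, tests marked `@pytest.mark.release` are skipped unless `--release` is passed. A local invocation sketch (the image tag and token are placeholders; `DOCKER_IMAGE` and `HF_TOKEN` are the variables the conftest reads):

```bash
export DOCKER_IMAGE=ghcr.io/huggingface/text-generation-inference:latest  # image under test
export HF_TOKEN=<your READ token>

pytest -s -vv integration-tests              # release-marked tests are skipped
pytest -s -vv integration-tests --release    # also run the release-marked tests
```

The test-file changes that follow add that marker to the individual model tests.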
@@ -13,6 +13,7 @@ async def bloom_560(bloom_560_handle):
     return bloom_560_handle.client


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_bloom_560m(bloom_560, response_snapshot):
     response = await bloom_560.generate(
@@ -27,6 +28,7 @@ async def test_bloom_560m(bloom_560, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_bloom_560m_all_params(bloom_560, response_snapshot):
     response = await bloom_560.generate(
@@ -49,6 +51,7 @@ async def test_bloom_560m_all_params(bloom_560, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_bloom_560m_load(bloom_560, generate_load, response_snapshot):
     responses = await generate_load(

@@ -13,6 +13,7 @@ async def bloom_560m_sharded(bloom_560m_sharded_handle):
     return bloom_560m_sharded_handle.client


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot):
     response = await bloom_560m_sharded.generate(
@@ -27,6 +28,7 @@ async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_bloom_560m_sharded_load(
     bloom_560m_sharded, generate_load, response_snapshot

@@ -26,6 +26,7 @@ async def flash_llama_completion(flash_llama_completion_handle):
     # method for it. Instead, we use the `requests` library to make the HTTP request directly.


+@pytest.mark.release
 def test_flash_llama_completion_single_prompt(
     flash_llama_completion, response_snapshot
 ):
@@ -46,6 +47,7 @@ def test_flash_llama_completion_single_prompt(
     assert response == response_snapshot


+@pytest.mark.release
 def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
     response = requests.post(
         f"{flash_llama_completion.base_url}/v1/completions",
@@ -68,6 +70,7 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn
     assert response == response_snapshot


+@pytest.mark.release
 async def test_flash_llama_completion_many_prompts_stream(
     flash_llama_completion, response_snapshot
 ):

@@ -17,6 +17,7 @@ async def flash_llama_awq(flash_llama_awq_handle):
     return flash_llama_awq_handle.client


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
     response = await flash_llama_awq.generate(
@@ -31,6 +32,7 @@ async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
     response = await flash_llama_awq.generate(
@@ -52,6 +54,7 @@ async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_flash_llama_awq_load(flash_llama_awq, generate_load, response_snapshot):
     responses = await generate_load(

@@ -17,6 +17,7 @@ async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded):
     return flash_llama_awq_handle_sharded.client


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot):
     response = await flash_llama_awq_sharded.generate(
@@ -31,6 +32,7 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapsho
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_flash_llama_awq_load_sharded(
     flash_llama_awq_sharded, generate_load, response_snapshot

@@ -13,6 +13,7 @@ async def flash_falcon(flash_falcon_handle):
     return flash_falcon_handle.client


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_falcon(flash_falcon, response_snapshot):
@@ -26,6 +27,7 @@ async def test_flash_falcon(flash_falcon, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_falcon_all_params(flash_falcon, response_snapshot):
@@ -49,6 +51,7 @@ async def test_flash_falcon_all_params(flash_falcon, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_falcon_load(flash_falcon, generate_load, response_snapshot):
@@ -13,6 +13,7 @@ async def flash_gemma(flash_gemma_handle):
     return flash_gemma_handle.client


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_gemma(flash_gemma, response_snapshot):
@@ -24,6 +25,7 @@ async def test_flash_gemma(flash_gemma, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
@@ -47,6 +49,7 @@ async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_gemma_load(flash_gemma, generate_load, response_snapshot):

@@ -13,6 +13,7 @@ async def flash_gemma_gptq(flash_gemma_gptq_handle):
     return flash_gemma_gptq_handle.client


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapshot):
@@ -24,6 +25,7 @@ async def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapsh
     assert response == ignore_logprob_response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_gemma_gptq_all_params(
@@ -49,6 +51,7 @@ async def test_flash_gemma_gptq_all_params(
     assert response == ignore_logprob_response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_gemma_gptq_load(

@@ -13,6 +13,7 @@ async def flash_gpt2(flash_gpt2_handle):
     return flash_gpt2_handle.client


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_flash_gpt2(flash_gpt2, response_snapshot):
     response = await flash_gpt2.generate(
@@ -25,6 +26,7 @@ async def test_flash_gpt2(flash_gpt2, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_flash_gpt2_load(flash_gpt2, generate_load, response_snapshot):
     responses = await generate_load(

@@ -21,6 +21,7 @@ async def flash_llama_exl2(flash_llama_exl2_handle):
     return flash_llama_exl2_handle.client


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot):
@@ -32,6 +33,7 @@ async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapsh
     assert response == ignore_logprob_response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_exl2_all_params(
@@ -58,6 +60,7 @@ async def test_flash_llama_exl2_all_params(
     assert response == ignore_logprob_response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_exl2_load(

@@ -13,6 +13,7 @@ async def flash_llama_gptq(flash_llama_gptq_handle):
     return flash_llama_gptq_handle.client


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot):
@@ -24,6 +25,7 @@ async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot):
@@ -46,6 +48,7 @@ async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_gptq_load(

@@ -15,6 +15,7 @@ async def flash_llama_gptq_marlin(flash_llama_gptq_marlin_handle):
     return flash_llama_gptq_marlin_handle.client


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapshot):
@@ -26,6 +27,7 @@ async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapsho
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_gptq_marlin_all_params(
@@ -50,6 +52,7 @@ async def test_flash_llama_gptq_marlin_all_params(
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_gptq_marlin_load(
@@ -15,6 +15,7 @@ async def flash_llama_marlin(flash_llama_marlin_handle):
     return flash_llama_marlin_handle.client


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):
@@ -26,6 +27,7 @@ async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_marlin_all_params(flash_llama_marlin, response_snapshot):
@@ -48,6 +50,7 @@ async def test_flash_llama_marlin_all_params(flash_llama_marlin, response_snapsh
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_marlin_load(

@@ -13,6 +13,7 @@ async def flash_neox(flash_neox_handle):
     return flash_neox_handle.client


+@pytest.mark.release
 @pytest.mark.skip
 @pytest.mark.asyncio
 async def test_flash_neox(flash_neox, response_snapshot):
@@ -26,6 +27,7 @@ async def test_flash_neox(flash_neox, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.skip
 @pytest.mark.asyncio
 async def test_flash_neox_load(flash_neox, generate_load, response_snapshot):

@@ -13,6 +13,7 @@ async def flash_neox_sharded(flash_neox_sharded_handle):
     return flash_neox_sharded_handle.client


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_flash_neox(flash_neox_sharded, response_snapshot):
     response = await flash_neox_sharded.generate(
@@ -25,6 +26,7 @@ async def test_flash_neox(flash_neox_sharded, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_flash_neox_load(flash_neox_sharded, generate_load, response_snapshot):
     responses = await generate_load(

@@ -34,6 +34,7 @@ def get_cow_beach():
     return f"data:image/png;base64,{encoded_string.decode('utf-8')}"


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
@@ -45,6 +46,7 @@ async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_pali_gemma_two_images(flash_pali_gemma, response_snapshot):

@@ -13,6 +13,7 @@ async def flash_phi(flash_phi_handle):
     return flash_phi_handle.client


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_flash_phi(flash_phi, response_snapshot):
     response = await flash_phi.generate(
@@ -24,6 +25,7 @@ async def test_flash_phi(flash_phi, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_flash_phi_all_params(flash_phi, response_snapshot):
     response = await flash_phi.generate(
@@ -47,6 +49,7 @@ async def test_flash_phi_all_params(flash_phi, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_flash_phi_load(flash_phi, generate_load, response_snapshot):
     responses = await generate_load(flash_phi, "Test request", max_new_tokens=10, n=4)

@@ -13,6 +13,7 @@ async def flash_qwen2(flash_qwen2_handle):
     return flash_qwen2_handle.client


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_flash_qwen2(flash_qwen2, response_snapshot):
     response = await flash_qwen2.generate(
@@ -24,6 +25,7 @@ async def test_flash_qwen2(flash_qwen2, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_flash_qwen2_all_params(flash_qwen2, response_snapshot):
     response = await flash_qwen2.generate(
@@ -46,6 +48,7 @@ async def test_flash_qwen2_all_params(flash_qwen2, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_flash_qwen2_load(flash_qwen2, generate_load, response_snapshot):
     responses = await generate_load(flash_qwen2, "Test request", max_new_tokens=10, n=4)
@@ -13,6 +13,7 @@ async def flash_santacoder(flash_santacoder_handle):
     return flash_santacoder_handle.client


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_flash_santacoder(flash_santacoder, response_snapshot):
     response = await flash_santacoder.generate(
@@ -23,6 +24,7 @@ async def test_flash_santacoder(flash_santacoder, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_flash_santacoder_load(
     flash_santacoder, generate_load, response_snapshot

@@ -13,6 +13,7 @@ async def flash_starcoder(flash_starcoder_handle):
     return flash_starcoder_handle.client


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_starcoder(flash_starcoder, response_snapshot):
@@ -24,6 +25,7 @@ async def test_flash_starcoder(flash_starcoder, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot):
@@ -40,6 +42,7 @@ async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_starcoder_load(flash_starcoder, generate_load, response_snapshot):

@@ -13,6 +13,7 @@ async def flash_starcoder2(flash_starcoder2_handle):
     return flash_starcoder2_handle.client


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_starcoder2(flash_starcoder2, response_snapshot):
@@ -24,6 +25,7 @@ async def test_flash_starcoder2(flash_starcoder2, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapshot):
@@ -40,6 +42,7 @@ async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapsh
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_starcoder2_load(

@@ -13,6 +13,7 @@ async def flash_starcoder_gptq(flash_starcoder_gptq_handle):
     return flash_starcoder_gptq_handle.client


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snapshot):
     response = await flash_starcoder_gptq.generate(
@@ -24,6 +25,7 @@ async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snap
     assert response == generous_response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_flash_starcoder_gptq_default_params(
     flash_starcoder_gptq, generous_response_snapshot
@@ -40,6 +42,7 @@ async def test_flash_starcoder_gptq_default_params(
     assert response == generous_response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_flash_starcoder_gptq_load(
     flash_starcoder_gptq, generate_load, generous_response_snapshot

@@ -21,6 +21,7 @@ async def non_flash_llama_grammar(non_flash_llama_grammar_handle):
     return non_flash_llama_grammar_handle.client


+@pytest.mark.release
 @pytest.mark.skip
 @pytest.mark.asyncio
 async def test_non_flash_llama_grammar_json(non_flash_llama_grammar, response_snapshot):

@@ -22,6 +22,7 @@ async def llama_grammar(llama_grammar_handle):
     return llama_grammar_handle.client


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot):

@@ -62,6 +63,7 @@ async def test_grammar_response_format_llama_json(llama_grammar, response_snapsh
     assert chat_completion == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_grammar_response_format_llama_error_if_tools_not_installed(
     llama_grammar,

@@ -45,6 +45,7 @@ async def test_idefics(idefics, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_idefics_two_images(idefics, response_snapshot):
@@ -60,6 +61,7 @@ async def test_idefics_two_images(idefics, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_idefics_load(idefics, generate_load, response_snapshot):
     chicken = get_chicken()

@@ -26,6 +26,7 @@ async def flash_llava_next(flash_llava_next_handle):
     return flash_llava_next_handle.client


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
|
||||||
async def test_flash_llava_next_simple(flash_llava_next, response_snapshot):
|
async def test_flash_llava_next_simple(flash_llava_next, response_snapshot):
|
||||||
@ -41,6 +42,7 @@ async def test_flash_llava_next_simple(flash_llava_next, response_snapshot):
|
|||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot):
|
async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot):
|
||||||
@ -64,6 +66,7 @@ async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot):
|
|||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_llava_next_load(
|
async def test_flash_llava_next_load(
|
||||||
|
@ -13,6 +13,7 @@ async def fused_kernel_mamba(fused_kernel_mamba_handle):
|
|||||||
return fused_kernel_mamba_handle.client
|
return fused_kernel_mamba_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_mamba(fused_kernel_mamba, response_snapshot):
|
async def test_mamba(fused_kernel_mamba, response_snapshot):
|
||||||
response = await fused_kernel_mamba.generate(
|
response = await fused_kernel_mamba.generate(
|
||||||
@ -24,6 +25,7 @@ async def test_mamba(fused_kernel_mamba, response_snapshot):
|
|||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
|
async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
|
||||||
response = await fused_kernel_mamba.generate(
|
response = await fused_kernel_mamba.generate(
|
||||||
@ -50,6 +52,7 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
|
|||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_mamba_load(
|
async def test_mamba_load(
|
||||||
fused_kernel_mamba, generate_load, generous_response_snapshot
|
fused_kernel_mamba, generate_load, generous_response_snapshot
|
||||||
|
@ -13,6 +13,7 @@ async def mpt_sharded(mpt_sharded_handle):
|
|||||||
return mpt_sharded_handle.client
|
return mpt_sharded_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_mpt(mpt_sharded, response_snapshot):
|
async def test_mpt(mpt_sharded, response_snapshot):
|
||||||
response = await mpt_sharded.generate(
|
response = await mpt_sharded.generate(
|
||||||
@ -29,6 +30,7 @@ async def test_mpt(mpt_sharded, response_snapshot):
|
|||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_mpt_load(mpt_sharded, generate_load, response_snapshot):
|
async def test_mpt_load(mpt_sharded, generate_load, response_snapshot):
|
||||||
responses = await generate_load(
|
responses = await generate_load(
|
||||||
|
@ -13,6 +13,7 @@ async def mt0_base(mt0_base_handle):
|
|||||||
return mt0_base_handle.client
|
return mt0_base_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_mt0_base(mt0_base, response_snapshot):
|
async def test_mt0_base(mt0_base, response_snapshot):
|
||||||
response = await mt0_base.generate(
|
response = await mt0_base.generate(
|
||||||
@ -27,6 +28,7 @@ async def test_mt0_base(mt0_base, response_snapshot):
|
|||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_mt0_base_all_params(mt0_base, response_snapshot):
|
async def test_mt0_base_all_params(mt0_base, response_snapshot):
|
||||||
response = await mt0_base.generate(
|
response = await mt0_base.generate(
|
||||||
@ -49,6 +51,7 @@ async def test_mt0_base_all_params(mt0_base, response_snapshot):
|
|||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_mt0_base_load(mt0_base, generate_load, response_snapshot):
|
async def test_mt0_base_load(mt0_base, generate_load, response_snapshot):
|
||||||
responses = await generate_load(
|
responses = await generate_load(
|
||||||
|
@ -15,6 +15,7 @@ async def neox(neox_handle):
|
|||||||
return neox_handle.client
|
return neox_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
@pytest.mark.skip
|
@pytest.mark.skip
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_neox(neox, response_snapshot):
|
async def test_neox(neox, response_snapshot):
|
||||||
@ -28,6 +29,7 @@ async def test_neox(neox, response_snapshot):
|
|||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
@pytest.mark.skip
|
@pytest.mark.skip
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_neox_load(neox, generate_load, response_snapshot):
|
async def test_neox_load(neox, generate_load, response_snapshot):
|
||||||
|
@ -15,6 +15,7 @@ async def neox_sharded(neox_sharded_handle):
|
|||||||
return neox_sharded_handle.client
|
return neox_sharded_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
@pytest.mark.skip
|
@pytest.mark.skip
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_neox(neox_sharded, response_snapshot):
|
async def test_neox(neox_sharded, response_snapshot):
|
||||||
@ -28,6 +29,7 @@ async def test_neox(neox_sharded, response_snapshot):
|
|||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
@pytest.mark.skip
|
@pytest.mark.skip
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_neox_load(neox_sharded, generate_load, response_snapshot):
|
async def test_neox_load(neox_sharded, generate_load, response_snapshot):
|
||||||
|
@ -13,6 +13,7 @@ async def t5_sharded(t5_sharded_handle):
|
|||||||
return t5_sharded_handle.client
|
return t5_sharded_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_t5_sharded(t5_sharded, response_snapshot):
|
async def test_t5_sharded(t5_sharded, response_snapshot):
|
||||||
response = await t5_sharded.generate(
|
response = await t5_sharded.generate(
|
||||||
@ -24,6 +25,7 @@ async def test_t5_sharded(t5_sharded, response_snapshot):
|
|||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_t5_sharded_load(t5_sharded, generate_load, response_snapshot):
|
async def test_t5_sharded_load(t5_sharded, generate_load, response_snapshot):
|
||||||
responses = await generate_load(
|
responses = await generate_load(
|
||||||
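Note: these hunks only add the `release` marker to the integration tests; the plumbing that opts such tests in or out is not shown in this section. Below is a minimal, hypothetical conftest.py sketch of how a marker like this is commonly gated behind a command-line option; the flag name and skip reason are assumptions for illustration, not values taken from this diff.

    # Hypothetical conftest.py wiring for the new `release` marker (assumed names).
    import pytest


    def pytest_addoption(parser):
        # Assumed option name; lets a full "release" run opt the marked tests in.
        parser.addoption(
            "--release", action="store_true", default=False, help="run release tests"
        )


    def pytest_collection_modifyitems(config, items):
        if config.getoption("--release"):
            return  # run everything, including release-marked tests
        skip_release = pytest.mark.skip(reason="need --release option to run")
        for item in items:
            if "release" in item.keywords:
                item.add_marker(skip_release)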
@ -413,6 +413,9 @@ struct Args {
     #[clap(long, env)]
     otlp_endpoint: Option<String>,

+    #[clap(default_value = "text-generation-inference.router", long, env)]
+    otlp_service_name: String,
+
     #[clap(long, env)]
     cors_allow_origin: Vec<String>,
     #[clap(long, env)]
@ -489,6 +492,7 @@ fn shard_manager(
     max_input_tokens: usize,
     lora_adapters: Option<String>,
     otlp_endpoint: Option<String>,
+    otlp_service_name: String,
     log_level: LevelFilter,
     status_sender: mpsc::Sender<ShardStatus>,
     shutdown: Arc<AtomicBool>,
@ -554,12 +558,16 @@ fn shard_manager(
         (None, Some(factor)) => Some((RopeScaling::Linear, factor)),
     };

-    // OpenTelemetry
+    // OpenTelemetry Endpoint
     if let Some(otlp_endpoint) = otlp_endpoint {
         shard_args.push("--otlp-endpoint".to_string());
         shard_args.push(otlp_endpoint);
     }

+    // OpenTelemetry Service Name
+    shard_args.push("--otlp-service-name".to_string());
+    shard_args.push(otlp_service_name);
+
     // In case we use sliding window, we may ignore the sliding in flash for some backends depending on the parameter.
     shard_args.push("--max-input-tokens".to_string());
     shard_args.push(max_input_tokens.to_string());
@ -598,7 +606,7 @@ fn shard_manager(

     // Parse Inference API token
     if let Ok(api_token) = env::var("HF_API_TOKEN") {
-        envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
+        envs.push(("HF_TOKEN".into(), api_token.into()))
     };

     // Detect rope scaling
@ -762,7 +770,10 @@ fn shutdown_shards(shutdown: Arc<AtomicBool>, shutdown_receiver: &mpsc::Receiver
 fn num_cuda_devices() -> Option<usize> {
     let devices = match env::var("CUDA_VISIBLE_DEVICES") {
         Ok(devices) => devices,
-        Err(_) => env::var("NVIDIA_VISIBLE_DEVICES").ok()?,
+        Err(_) => match env::var("NVIDIA_VISIBLE_DEVICES") {
+            Ok(devices) => devices,
+            Err(_) => env::var("ZE_AFFINITY_MASK").ok()?,
+        },
     };
     let n_devices = devices.split(',').count();
     Some(n_devices)
@ -835,9 +846,9 @@ fn find_num_shards(
     let num_shard = match (sharded, num_shard) {
         (Some(true), None) => {
            // try to default to the number of available GPUs
-            tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES");
+            tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES/ZE_AFFINITY_MASK");
            let n_devices = num_cuda_devices()
-                .expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES are not set");
+                .expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES/ZE_AFFINITY_MASK are not set");
            if n_devices <= 1 {
                return Err(LauncherError::NotEnoughCUDADevices(format!(
                    "`sharded` is true but only found {n_devices} CUDA devices"
@ -936,7 +947,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L

     // Parse Inference API token
     if let Ok(api_token) = env::var("HF_API_TOKEN") {
-        envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
+        envs.push(("HF_TOKEN".into(), api_token.into()))
     };

     // If args.weights_cache_override is some, pass it to the download process
@ -1046,6 +1057,7 @@ fn spawn_shards(
         let shutdown = shutdown.clone();
         let shutdown_sender = shutdown_sender.clone();
         let otlp_endpoint = args.otlp_endpoint.clone();
+        let otlp_service_name = args.otlp_service_name.clone();
         let quantize = args.quantize;
         let speculate = args.speculate;
         let dtype = args.dtype;
@ -1087,6 +1099,7 @@ fn spawn_shards(
                 max_input_tokens,
                 lora_adapters,
                 otlp_endpoint,
+                otlp_service_name,
                 max_log_level,
                 status_sender,
                 shutdown,
@ -1220,6 +1233,11 @@ fn spawn_webserver(
         router_args.push(otlp_endpoint);
     }

+    // OpenTelemetry
+    let otlp_service_name = args.otlp_service_name;
+    router_args.push("--otlp-service-name".to_string());
+    router_args.push(otlp_service_name);
+
     // CORS origins
     for origin in args.cors_allow_origin.into_iter() {
         router_args.push("--cors-allow-origin".to_string());
@ -1240,7 +1258,7 @@ fn spawn_webserver(

     // Parse Inference API token
     if let Ok(api_token) = env::var("HF_API_TOKEN") {
-        envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
+        envs.push(("HF_TOKEN".into(), api_token.into()))
     };

     // Parse Compute type
@ -576,7 +576,7 @@ impl ChatCompletion {
         };
         Self {
             id: String::new(),
-            object: "text_completion".into(),
+            object: "chat.completion".into(),
             created,
             model,
             system_fingerprint,
@ -688,7 +688,7 @@ impl ChatCompletionChunk {
         };
         Self {
             id: String::new(),
-            object: "text_completion".to_string(),
+            object: "chat.completion.chunk".to_string(),
             created,
             model,
             system_fingerprint,
@ -65,6 +65,8 @@ struct Args {
     json_output: bool,
     #[clap(long, env)]
     otlp_endpoint: Option<String>,
+    #[clap(default_value = "text-generation-inference.router", long, env)]
+    otlp_service_name: String,
     #[clap(long, env)]
     cors_allow_origin: Option<Vec<String>>,
     #[clap(long, env)]
@ -107,6 +109,7 @@ async fn main() -> Result<(), RouterError> {
         validation_workers,
         json_output,
         otlp_endpoint,
+        otlp_service_name,
         cors_allow_origin,
         ngrok,
         ngrok_authtoken,
@ -117,7 +120,7 @@ async fn main() -> Result<(), RouterError> {
     } = args;

     // Launch Tokio runtime
-    init_logging(otlp_endpoint, json_output);
+    init_logging(otlp_endpoint, otlp_service_name, json_output);

     // Validate args
     if max_input_tokens >= max_total_tokens {
@ -156,7 +159,9 @@ async fn main() -> Result<(), RouterError> {
     });

     // Parse Huggingface hub token
-    let authorization_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok();
+    let authorization_token = std::env::var("HF_TOKEN")
+        .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
+        .ok();

     // Tokenizer instance
     // This will only be used to validate payloads
@ -367,10 +372,11 @@ async fn main() -> Result<(), RouterError> {

 /// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
 /// - otlp_endpoint is an optional URL to an Open Telemetry collector
+/// - otlp_service_name service name to appear in APM
 /// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
 /// - LOG_FORMAT may be TEXT or JSON (default to TEXT)
 /// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms)
-fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
+fn init_logging(otlp_endpoint: Option<String>, otlp_service_name: String, json_output: bool) {
     let mut layers = Vec::new();

     // STDOUT/STDERR layer
@ -401,7 +407,7 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
                     trace::config()
                         .with_resource(Resource::new(vec![KeyValue::new(
                             "service.name",
-                            "text-generation-inference.router",
+                            otlp_service_name,
                         )]))
                         .with_sampler(Sampler::AlwaysOn),
                 )
1152
server/tests/utils/test_weights.py
Normal file
File diff suppressed because it is too large
@ -42,6 +42,7 @@ def serve(
     logger_level: str = "INFO",
     json_output: bool = False,
     otlp_endpoint: Optional[str] = None,
+    otlp_service_name: str = "text-generation-inference.server",
     max_input_tokens: Optional[int] = None,
 ):
     if sharded:
@ -76,7 +77,7 @@ def serve(

     # Setup OpenTelemetry distributed tracing
     if otlp_endpoint is not None:
-        setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)
+        setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint)

     lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)

@ -7,7 +7,7 @@ if SYSTEM == "cuda":
     from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
 elif SYSTEM == "rocm":
     from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
-elif SYSTEM == "xpu":
-    from .xpu import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
+elif SYSTEM == "ipex":
+    from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
 else:
     raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
@ -1,5 +1,6 @@
 import intel_extension_for_pytorch as ipex
 import torch
+from text_generation_server.models.flash_causal_lm import BLOCK_SIZE

 SUPPORTS_WINDOWING = False

@ -56,8 +57,6 @@ def paged_attention(
     input_lengths: torch.Tensor,
     max_s: int,
 ):
-    query = query.contiguous()
-    block_size = value_cache.shape[3]
     return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
         out,
         query,
@ -67,7 +66,7 @@ def paged_attention(
         softmax_scale,
         block_tables,
         input_lengths,
-        block_size,
+        BLOCK_SIZE,
         max_s,
         None,
     )
@ -82,18 +82,20 @@ elif SYSTEM == "rocm":

             return super().forward(hidden_states), residual

-elif SYSTEM == "xpu":
+elif SYSTEM == "ipex":
     import intel_extension_for_pytorch as ipex

     class FastLayerNorm(nn.LayerNorm):
         def forward(self, hidden_states, residual=None):
-            res_out = hidden_states
             out = ipex.llm.functional.add_layer_norm(
-                residual, hidden_states, self.weight, self.bias, self.eps, True
+                residual,
+                hidden_states,
+                self.weight,
+                self.bias,
+                self.eps,
+                residual is not None,
             )
-            if residual is not None:
-                res_out = residual
-            return out, res_out
+            return out, residual if residual is not None else hidden_states


 class FastRMSNorm(nn.Module):
@ -109,19 +111,16 @@ class FastRMSNorm(nn.Module):
         return cls(weight, eps)

     def forward(self, hidden_states, residual=None):
-        if SYSTEM == "xpu":
-            residual_out = hidden_states
+        if SYSTEM == "ipex":
             out = ipex.llm.functional.add_rms_norm(
                 residual,
                 hidden_states,
                 self.weight,
                 None,
                 self.variance_epsilon,
-                True,
+                residual is not None,
             )
-            if residual is not None:
-                residual_out = residual
-            return out, residual_out
+            return out, residual if residual is not None else hidden_states
         elif hidden_states.shape[-1] > 8192:
             if residual is not None:
                 hidden_states += residual
@ -9,7 +9,7 @@ if SYSTEM == "cuda":
     import rotary_emb
 elif SYSTEM == "rocm":
     from vllm._C import ops
-elif SYSTEM == "xpu":
+elif SYSTEM == "ipex":
     import intel_extension_for_pytorch as ipex


@ -69,7 +69,7 @@ class PositionRotaryEmbedding(nn.Module):

             # Inplace operation, updating query and key.
             ops.rotary_embedding(query, key, head_size, cos, sin, True)
-        elif SYSTEM == "xpu":
+        elif SYSTEM == "ipex":
             ipex.llm.functional.rotary_embedding(
                 query, key, sin, cos, query.size(-1), True
             )
@ -3,6 +3,10 @@ from torch.nn import functional as F
 from typing import Iterable, List
 from text_generation_server.layers.linear import get_linear, FastLinear
 from text_generation_server.layers.exl2 import Exl2Weight
+from text_generation_server.utils.import_utils import SYSTEM
+
+if SYSTEM == "ipex":
+    import intel_extension_for_pytorch as ipex


 class LayerConcat(torch.nn.Module):
@ -96,10 +100,14 @@ class TensorParallelHead(SuperLayer):
                 local_out = gather_input.T

             torch.mm(input, self.linear.weight.T, out=local_out)
-
-            torch.distributed.all_gather_into_tensor(
-                world_out, gather_input, group=self.process_group
-            )
+            if SYSTEM == "ipex":
+                ipex.distributed.all_gather_into_tensor(
+                    world_out, gather_input, group=self.process_group
+                )
+            else:
+                torch.distributed.all_gather_into_tensor(
+                    world_out, gather_input, group=self.process_group
+                )

             if input.shape[0] == 1:
                 return world_out
@ -109,7 +117,10 @@ class TensorParallelHead(SuperLayer):
         world_output = [
             torch.empty_like(output) for _ in range(self.process_group.size())
         ]
-        torch.distributed.all_gather(world_output, output, group=self.process_group)
+        if SYSTEM == "ipex":
+            ipex.distributed.all_gather(world_output, output, group=self.process_group)
+        else:
+            torch.distributed.all_gather(world_output, output, group=self.process_group)
         world_output = torch.cat(world_output, dim=-1)
         return world_output

@ -206,7 +217,10 @@ class TensorParallelRowLinear(SuperLayer):
     def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
         out = super().forward(input)
         if self.process_group.size() > 1 and reduce:
-            torch.distributed.all_reduce(out, group=self.process_group)
+            if SYSTEM == "ipex":
+                ipex.distributed.all_reduce(out, group=self.process_group)
+            else:
+                torch.distributed.all_reduce(out, group=self.process_group)
         return out


@ -243,5 +257,8 @@ class TensorParallelEmbedding(torch.nn.Module):
         )
         out = torch.nn.functional.embedding(input, self.weight)
         if self.reduce and self.process_group.size() > 1:
-            torch.distributed.all_reduce(out, group=self.process_group)
+            if SYSTEM == "ipex":
+                ipex.distributed.all_reduce(out, group=self.process_group)
+            else:
+                torch.distributed.all_reduce(out, group=self.process_group)
         return out
@ -22,7 +22,7 @@ from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple, Any
 from text_generation_server.utils.import_utils import SYSTEM

-if SYSTEM != "xpu":
+if SYSTEM != "ipex":
     from vllm.model_executor.layers.fused_moe import fused_moe

 from text_generation_server.layers.attention import (

@ -26,7 +26,7 @@ import numpy as np
 from torch import nn
 from text_generation_server.utils.import_utils import SYSTEM

-if SYSTEM != "xpu":
+if SYSTEM != "ipex":
     from vllm.model_executor.layers.fused_moe import fused_moe
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
@ -847,26 +847,43 @@ class FlashCausalLM(Model):
         empty_cache()

         element_size = torch.tensor([], dtype=dtype).element_size()
-        if SYSTEM == "xpu":
+        if SYSTEM == "ipex" and device.type == "xpu":
             x = 1
         else:
             x = BLOCK_SIZE // element_size

-        self.kv_cache = [
-            (
-                torch.empty(
-                    (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
-                    dtype=dtype,
-                    device=device,
-                ),
-                torch.empty(
-                    (num_blocks, num_heads, head_size, BLOCK_SIZE),
-                    dtype=dtype,
-                    device=device,
-                ),
-            )
-            for _ in range(num_layers)
-        ]
+        if SYSTEM == "ipex" and device == torch.device("cpu"):
+            self.kv_cache = [
+                (
+                    torch.empty(
+                        (num_blocks, num_heads, BLOCK_SIZE, head_size),
+                        dtype=dtype,
+                        device=device,
+                    ),
+                    torch.empty(
+                        (num_blocks, num_heads, BLOCK_SIZE, head_size),
+                        dtype=dtype,
+                        device=device,
+                    ),
+                )
+                for _ in range(num_layers)
+            ]
+        else:
+            self.kv_cache = [
+                (
+                    torch.empty(
+                        (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
+                        dtype=dtype,
+                        device=device,
+                    ),
+                    torch.empty(
+                        (num_blocks, num_heads, head_size, BLOCK_SIZE),
+                        dtype=dtype,
+                        device=device,
+                    ),
+                )
+                for _ in range(num_layers)
+            ]

     def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
         input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
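For orientation, a small back-of-the-envelope sketch of what the two cache layouts above cost per layer. The sizes are illustrative assumptions, not values from this diff; the point is that both layouts hold the same number of elements.

    # Illustrative only: per-layer KV-cache footprint for the two layouts above.
    num_blocks, num_heads, head_size, BLOCK_SIZE = 1024, 32, 128, 16  # assumed sizes
    element_size = 2  # bytes per element for fp16/bf16
    x = BLOCK_SIZE // element_size

    # CPU (ipex) branch: K and V are both (num_blocks, num_heads, BLOCK_SIZE, head_size).
    cpu_elems = 2 * num_blocks * num_heads * BLOCK_SIZE * head_size
    # Other branch: K is (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
    # V is (num_blocks, num_heads, head_size, BLOCK_SIZE).
    gpu_elems = (
        num_blocks * num_heads * (head_size // x) * BLOCK_SIZE * x
        + num_blocks * num_heads * head_size * BLOCK_SIZE
    )

    assert cpu_elems == gpu_elems  # different layouts, identical footprint
    print(f"per-layer KV cache: {cpu_elems * element_size / 2**20:.0f} MiB")  # 256 MiB here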
@ -34,9 +34,13 @@ class FlashGPT2(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "xpu":
-            device = torch.device(f"xpu:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
+            else:
+                device = torch.device("cpu")
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashGPT2 is only available on GPU")

@ -48,9 +48,13 @@ class FlashLlama(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "xpu":
-            device = torch.device(f"xpu:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
+            else:
+                device = torch.device("cpu")
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")

@ -50,9 +50,13 @@ class BaseFlashMistral(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "xpu":
-            device = torch.device(f"xpu:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
+            else:
+                device = torch.device("cpu")
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashMistral is only available on GPU")

@ -33,9 +33,13 @@ class FlashNeoXSharded(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "xpu":
-            device = torch.device(f"xpu:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
+            else:
+                device = torch.device("cpu")
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashNeoX is only available on GPU")

@ -34,9 +34,13 @@ class FlashRWSharded(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "xpu":
-            device = torch.device(f"xpu:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
+            else:
+                device = torch.device("cpu")
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashRW is only available on GPU")

@ -37,9 +37,13 @@ class FlashSantacoderSharded(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "xpu":
-            device = torch.device(f"xpu:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
+            else:
+                device = torch.device("cpu")
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
@ -54,10 +54,8 @@ class UDSOpenTelemetryAioServerInterceptor(OpenTelemetryAioServerInterceptor):
         )


-def setup_tracing(shard: int, otlp_endpoint: str):
-    resource = Resource.create(
-        attributes={"service.name": f"text-generation-inference.server-{shard}"}
-    )
+def setup_tracing(otlp_service_name: str, otlp_endpoint: str):
+    resource = Resource.create(attributes={"service.name": otlp_service_name})
     span_exporter = OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True)
     span_processor = BatchSpanProcessor(span_exporter)
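The tracing setup no longer derives the service name from the shard rank; the caller passes it in explicitly, as the cli.py hunk earlier in this diff shows. A minimal sketch of that call site follows; the import path and the environment-variable name used for the endpoint are assumptions for illustration.

    # Sketch of the new call site; mirrors the cli.py change above.
    import os

    from text_generation_server.tracing import setup_tracing  # assumed module path

    otlp_endpoint = os.getenv("OTLP_ENDPOINT")  # assumed env var, for illustration
    otlp_service_name = "text-generation-inference.server"  # default from this diff

    if otlp_endpoint is not None:
        setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint)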
@ -3,6 +3,7 @@ import torch

 from datetime import timedelta
 from loguru import logger
+from text_generation_server.utils.import_utils import SYSTEM

 # Tensor Parallelism settings
 RANK = int(os.getenv("RANK", "0"))
@ -57,14 +58,7 @@ def initialize_torch_distributed():
         options.is_high_priority_stream = True
         options._timeout = timedelta(seconds=60)
     else:
-        try:
-            import oneccl_bindings_for_pytorch
-
-            backend = "ccl"
-            if os.getenv("CCL_WORKER_COUNT", None) is None:
-                os.environ["CCL_WORKER_COUNT"] = str(1)
-        except ImportError:
-            backend = "gloo"
+        backend = "gloo"
         options = None

     if WORLD_SIZE == 1:
@ -75,13 +69,24 @@ def initialize_torch_distributed():

     if not torch.distributed.is_initialized():
         # Call the init process.
-        torch.distributed.init_process_group(
-            backend=backend,
-            world_size=WORLD_SIZE,
-            rank=RANK,
-            timeout=timedelta(seconds=60),
-            pg_options=options,
-        )
+        if SYSTEM == "ipex":
+            import intel_extension_for_pytorch as ipex
+
+            ipex.distributed.init_process_group(
+                backend="ccl",
+                world_size=WORLD_SIZE,
+                rank=RANK,
+                timeout=timedelta(seconds=60),
+                pg_options=options,
+            )
+        else:
+            torch.distributed.init_process_group(
+                backend=backend,
+                world_size=WORLD_SIZE,
+                rank=RANK,
+                timeout=timedelta(seconds=60),
+                pg_options=options,
+            )
     else:
         logger.warning("torch.distributed is already initialized.")

@ -1,14 +1,14 @@
 import torch
 from loguru import logger
+import subprocess


-def is_xpu_available():
+def is_ipex_available():
     try:
         import intel_extension_for_pytorch
     except ImportError:
         return False
-    return hasattr(torch, "xpu") and torch.xpu.is_available()
+    return True


 def get_cuda_free_memory(device, memory_fraction):
@ -19,11 +19,28 @@ def get_cuda_free_memory(device, memory_fraction):


 def get_xpu_free_memory(device, memory_fraction):
-    total_gpu_memory = torch.xpu.get_device_properties(device).total_memory
-    free_memory = int(total_gpu_memory * 0.5)
+    total_memory = torch.xpu.get_device_properties(device).total_memory
+    device_id = device.index
+    query = f"xpu-smi dump -d {device_id} -m 18 -n 1"
+    output = subprocess.check_output(query.split()).decode("utf-8").split("\n")
+    used_memory = float(output[1].split(",")[-1]) * 1024 * 1024
+    free_memory = int(total_memory * 0.95 - used_memory)
     return free_memory


+def get_cpu_free_memory(device, memory_fraction):
+    import psutil
+    from text_generation_server.utils.dist import WORLD_SIZE
+
+    mem = psutil.virtual_memory()
+    free_memory = int(mem.available * 0.95 / WORLD_SIZE)
+    return free_memory
+
+
+def noop(*args, **kwargs):
+    pass
+
+
 SYSTEM = None
 if torch.version.hip is not None:
     SYSTEM = "rocm"
@ -35,18 +52,20 @@ elif torch.version.cuda is not None and torch.cuda.is_available():
     empty_cache = torch.cuda.empty_cache
     synchronize = torch.cuda.synchronize
     get_free_memory = get_cuda_free_memory
-elif is_xpu_available():
-    SYSTEM = "xpu"
-    empty_cache = torch.xpu.empty_cache
-    synchronize = torch.xpu.synchronize
-    get_free_memory = get_xpu_free_memory
+elif is_ipex_available():
+    SYSTEM = "ipex"
+    if hasattr(torch, "xpu") and torch.xpu.is_available():
+        empty_cache = torch.xpu.empty_cache
+        synchronize = torch.xpu.synchronize
+        get_free_memory = get_xpu_free_memory
+    else:
+        empty_cache = noop
+        synchronize = noop
+        get_free_memory = get_cpu_free_memory
 else:
     SYSTEM = "cpu"

-    def noop(*args, **kwargs):
-        pass
-
     empty_cache = noop
     synchronize = noop
-    get_free_memory = noop
+    get_free_memory = get_cpu_free_memory
 logger.info(f"Detected system {SYSTEM}")
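A standalone sketch of what the new CPU fallback budgets per rank; it mirrors get_cpu_free_memory above, assumes psutil is installed, and reads WORLD_SIZE from the environment the way utils.dist does. The printed figure depends entirely on the host.

    # Rough standalone equivalent of get_cpu_free_memory above; illustrative only.
    import os

    import psutil

    WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))

    mem = psutil.virtual_memory()
    free_memory = int(mem.available * 0.95 / WORLD_SIZE)
    print(f"~{free_memory / 2**30:.1f} GiB of host RAM budgeted for this rank")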