Merge branch 'main' into gpt_awq_4

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Commit 7c6230c59a by Wang, Yi A, 2024-10-14 20:28:15 -07:00
45 changed files with 1403 additions and 344 deletions

View File

@ -21,9 +21,11 @@ jobs:
build-and-push:
outputs:
docker_image: ${{ steps.final.outputs.docker_image }}
+docker_volume: ${{ steps.final.outputs.docker_volume }}
docker_devices: ${{ steps.final.outputs.docker_devices }}
runs_on: ${{ steps.final.outputs.runs_on }}
label: ${{ steps.final.outputs.label }}
+extra_pytest: ${{ steps.final.outputs.extra_pytest }}
concurrency:
group: ${{ github.workflow }}-build-and-push-image-${{ inputs.hardware }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
@ -44,32 +46,39 @@ jobs:
cuda)
export dockerfile="Dockerfile"
export label_extension=""
+export docker_volume="/mnt/cache"
export docker_devices=""
export runs_on="aws-g6-12xl-plus-priv-cache"
export platform=""
+export extra_pytest=""
;;
rocm)
export dockerfile="Dockerfile_amd"
export label_extension="-rocm"
export docker_devices="/dev/kfd,/dev/dri"
-# TODO Re-enable when they pass.
-# export runs_on="amd-gpu-tgi"
-export runs_on="ubuntu-latest"
+export docker_volume="/mnt"
+export runs_on="amd-gpu-runners"
export platform=""
+export extra_pytest="-k test_flash_gemma_gptq_load"
;;
intel-xpu)
export dockerfile="Dockerfile_intel"
export label_extension="-intel-xpu"
export docker_devices=""
+export docker_volume="/mnt/cache"
export runs_on="ubuntu-latest"
export platform="xpu"
+export extra_pytest=""
;;
intel-cpu)
export dockerfile="Dockerfile_intel"
export label_extension="-intel-cpu"
-export docker_devices=""
-export runs_on="ubuntu-latest"
+export docker_devices="none"
+export docker_volume="/mnt/cache"
+# export runs_on="ubuntu-latest"
+export runs_on="aws-highmemory-32-plus-priv"
export platform="cpu"
+export extra_pytest="-k test_flash_gemma_simple"
;;
esac
echo $dockerfile
@ -81,8 +90,10 @@ jobs:
echo "DOCKERFILE=${dockerfile}" >> $GITHUB_ENV
echo "LABEL=${label_extension}" >> $GITHUB_ENV
echo "PLATFORM=${platform}" >> $GITHUB_ENV
+echo "DOCKER_VOLUME=${docker_volume}" >> $GITHUB_ENV
echo "DOCKER_DEVICES=${docker_devices}" >> $GITHUB_ENV
echo "RUNS_ON=${runs_on}" >> $GITHUB_ENV
+echo "EXTRA_PYTEST=${extra_pytest}" >> $GITHUB_ENV
echo REGISTRY_MIRROR=$REGISTRY_MIRROR >> $GITHUB_ENV
- name: Initialize Docker Buildx
uses: docker/setup-buildx-action@v3
@ -157,16 +168,18 @@ jobs:
run: |
echo "docker_image=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL }}" >> "$GITHUB_OUTPUT"
echo "docker_devices=${{ env.DOCKER_DEVICES }}" >> "$GITHUB_OUTPUT"
+echo "docker_volume=${{ env.DOCKER_VOLUME }}" >> "$GITHUB_OUTPUT"
echo "runs_on=${{ env.RUNS_ON }}" >> "$GITHUB_OUTPUT"
echo "label=${{ env.LABEL }}" >> "$GITHUB_OUTPUT"
+echo "extra_pytest=${{ env.EXTRA_PYTEST }}" >> "$GITHUB_OUTPUT"
integration_tests:
concurrency:
group: ${{ github.workflow }}-${{ github.job }}-${{ needs.build-and-push.outputs.label }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
needs: build-and-push
+if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
runs-on:
group: ${{ needs.build-and-push.outputs.runs_on }}
-if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
env:
PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '--release' }}
steps:
@ -177,15 +190,16 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
-python-version: "3.10"
+python-version: "3.11"
- name: Install
run: |
make install-integration-tests
- name: Run tests
run: |
-export DOCKER_VOLUME=/mnt/cache
+export DOCKER_VOLUME=${{ needs.build-and-push.outputs.docker_volume }}
export DOCKER_IMAGE=${{ needs.build-and-push.outputs.docker_image }}
export DOCKER_DEVICES=${{ needs.build-and-push.outputs.docker_devices }}
+export EXTRA_PYTEST="${{ needs.build-and-push.outputs.extra_pytest }}"
export HF_TOKEN=${{ secrets.HF_TOKEN }}
echo $DOCKER_IMAGE
-pytest -s -vv integration-tests ${PYTEST_FLAGS}
+pytest -s -vv integration-tests ${PYTEST_FLAGS} ${EXTRA_PYTEST}
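The new `docker_volume` and `extra_pytest` outputs let each hardware target pick its own cache volume and restrict the integration run to a subset of tests (for example `-k test_flash_gemma_simple` on `intel-cpu`, `-k test_flash_gemma_gptq_load` on ROCm). A minimal Python sketch of how such a `-k` filter composes with the existing flags; the environment variable names mirror the workflow, the rest is illustrative:

```python
# Illustrative only: how EXTRA_PYTEST narrows the integration-test invocation.
# An empty value adds no extra arguments; "-k <expr>" keeps only matching tests.
import os
import shlex
import subprocess

pytest_flags = os.environ.get("PYTEST_FLAGS", "--release")
extra_pytest = os.environ.get("EXTRA_PYTEST", "")  # e.g. "-k test_flash_gemma_simple"

cmd = [
    "pytest", "-s", "-vv", "integration-tests",
    *shlex.split(pytest_flags),
    *shlex.split(extra_pytest),
]
print(" ".join(cmd))
subprocess.run(cmd, check=False)
```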

View File

@ -1,5 +1,5 @@
# Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef
WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@ -32,6 +32,7 @@ RUN cargo chef cook --profile release-opt --recipe-path recipe.json
ARG GIT_SHA
ARG DOCKER_LABEL
+COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
@ -39,7 +40,7 @@ COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
-RUN cargo build --profile release-opt
+RUN cargo build --profile release-opt --frozen
# Python builder
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile

View File

@ -1,5 +1,5 @@
# Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef
WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@ -31,6 +31,7 @@ RUN cargo chef cook --profile release-opt --recipe-path recipe.json
ARG GIT_SHA
ARG DOCKER_LABEL
+COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
@ -38,7 +39,7 @@ COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
-RUN cargo build --profile release-opt
+RUN cargo build --profile release-opt --frozen
# Text Generation Inference base image for RoCm
FROM rocm/dev-ubuntu-22.04:6.2 AS base

View File

@ -1,6 +1,6 @@
ARG PLATFORM=xpu
-FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef
WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@ -32,6 +32,7 @@ RUN cargo chef cook --profile release-opt --recipe-path recipe.json
ARG GIT_SHA
ARG DOCKER_LABEL
+COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
@ -39,7 +40,7 @@ COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
-RUN cargo build --profile release-opt
+RUN cargo build --profile release-opt --frozen
# Text Generation Inference base image for Intel
@ -52,7 +53,7 @@ ARG MAMBA_VERSION=23.1.0-1
ARG PYTHON_VERSION='3.11.10'
# Automatically set by buildx
ARG TARGETPLATFORM
-ENV PATH /opt/conda/bin:$PATH
+ENV PATH=/opt/conda/bin:$PATH
# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
# Install mamba
@ -111,6 +112,8 @@ ENV PATH=/opt/conda/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/
ENV CCL_ZE_IPC_EXCHANGE=sockets
ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest
ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include
+ENV TORCH_LLM_ALLREDUCE=1
+ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
@ -176,15 +179,17 @@ RUN conda install -c conda-forge gperftools mkl
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.5.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.20.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
RUN pip install triton py-libnuma
WORKDIR /usr/src
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout f86e93e4890dc2c989024d148d415c9aa8a1649f
RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout v2.4.0+cpu+rc0
-RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install
-RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install .
+RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install
+RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install .
ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so
ENV CCL_ROOT=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch

View File

@ -120,7 +120,7 @@ curl localhost:3000/v1/chat/completions \
**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
-**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.2.0-rocm --model-id $model` instead of the command above.
+**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1-rocm --model-id $model` instead of the command above.
To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
```
@ -150,7 +150,7 @@ model=meta-llama/Meta-Llama-3.1-8B-Instruct
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
token=<your cli READ token>
-docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0 --model-id $model
+docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1 --model-id $model
```
### A note on Shared Memory (shm) ### A note on Shared Memory (shm)

View File

@ -27,3 +27,6 @@ asyncio_mode = "auto"
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
+[tool.isort]
+profile = "black"

View File

@ -2114,12 +2114,18 @@
"ToolType": { "ToolType": {
"oneOf": [ "oneOf": [
{ {
"type": "object", "type": "string",
"default": null, "description": "Means the model can pick between generating a message or calling one or more tools.",
"nullable": true "enum": [
"auto"
]
}, },
{ {
"type": "string" "type": "string",
"description": "Means the model will not call any tool and instead generates a message.",
"enum": [
"none"
]
}, },
{ {
"type": "object", "type": "object",
@ -2131,13 +2137,10 @@
"$ref": "#/components/schemas/FunctionName" "$ref": "#/components/schemas/FunctionName"
} }
} }
},
{
"type": "object",
"default": null,
"nullable": true
} }
] ],
"description": "Controls which (if any) tool is called by the model.",
"example": "auto"
}, },
"Url": { "Url": {
"type": "object", "type": "object",
@ -2183,4 +2186,4 @@
"description": "Hugging Face Text Generation Inference API" "description": "Hugging Face Text Generation Inference API"
} }
] ]
} }

View File

@ -3,6 +3,8 @@
title: Text Generation Inference
- local: quicktour
title: Quick Tour
+- local: supported_models
+title: Supported Models
- local: installation_nvidia
title: Using TGI with Nvidia GPUs
- local: installation_amd
@ -15,8 +17,7 @@
title: Using TGI with Intel GPUs
- local: installation
title: Installation from source
-- local: supported_models
-title: Supported Models and Hardware
- local: architecture
title: Internal Architecture
- local: usage_statistics

View File

@ -19,6 +19,6 @@ docker run --gpus all \
--shm-size 1g \
-e HF_TOKEN=$token \
-p 8080:80 \
--v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0.4 \
+-v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1 \
--model-id $model
```

View File

@ -19,7 +19,7 @@ bitsandbytes is a library used to apply 8-bit and 4-bit quantization to models.
In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below 👇
```bash
-docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize bitsandbytes
+docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1 --model-id $model --quantize bitsandbytes
```
4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load.
@ -27,7 +27,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingf
In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below 👇
```bash
-docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize bitsandbytes-nf4
+docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1 --model-id $model --quantize bitsandbytes-nf4
```
You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes).
@ -48,7 +48,7 @@ $$({\hat{W}_{l}}^{*} = argmin_{\hat{W_{l}}} ||W_{l}X-\hat{W}_{l}X||^{2}_{2})$$
TGI allows you to both run an already GPTQ quantized model (see available models [here](https://huggingface.co/models?search=gptq)) or quantize a model of your choice using quantization script. You can run a quantized model by simply passing --quantize like below 👇
```bash
-docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize gptq
+docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1 --model-id $model --quantize gptq
```
Note that TGI's GPTQ implementation doesn't use [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) under the hood. However, models quantized using AutoGPTQ or Optimum can still be served by TGI.

View File

@ -17,8 +17,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
<Tip>
-If you want to serve gated or private models, which provide
-controlled access to sensitive or proprietary content, refer to
+If you want to serve gated or private models, please refer to
[this guide](https://huggingface.co/docs/text-generation-inference/en/basic_tutorials/gated_model_access)
for detailed instructions.
@ -97,7 +96,7 @@ curl 127.0.0.1:8080/generate \
To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.
```bash
-docker run ghcr.io/huggingface/text-generation-inference:2.2.0 --help
+docker run ghcr.io/huggingface/text-generation-inference:2.3.1 --help
```
</Tip>

View File

@ -1,9 +1,7 @@
-# Supported Models and Hardware
+# Supported Models
-Text Generation Inference enables serving optimized models on specific hardware for the highest performance. The following sections list which models (VLMs & LLMs) are supported.
+Text Generation Inference enables serving optimized models. The following sections list which models (VLMs & LLMs) are supported.
-## Supported Models
- [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2)
- [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal)
@ -38,6 +36,7 @@ Text Generation Inference enables serving optimized models on specific hardware
- [Mllama](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) (Multimodal)
If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models:
```python

View File

@ -978,11 +978,11 @@
"nixpkgs": "nixpkgs_6" "nixpkgs": "nixpkgs_6"
}, },
"locked": { "locked": {
"lastModified": 1728029332, "lastModified": 1728381423,
"narHash": "sha256-j0RX3a67lvi2PC5w6J5DHTxM+l96J/OV5sAf34IUfUo=", "narHash": "sha256-gpHy1WtlA8ZTd8XmxsdCoDd4Z7DE7co37lH7P+nsADA=",
"owner": "huggingface", "owner": "huggingface",
"repo": "text-generation-inference-nix", "repo": "text-generation-inference-nix",
"rev": "98049f853346ca780b81fee730715c90d33ac2b4", "rev": "93123736c97e9f7bfe825bfaf3d7de0fc9a21a1e",
"type": "github" "type": "github"
}, },
"original": { "original": {

View File

@ -492,6 +492,7 @@ def launcher(event_loop):
try:
container = client.containers.get(container_name)
container.stop()
+container.remove()
container.wait()
except NotFound:
pass
@ -514,13 +515,28 @@ def launcher(event_loop):
volumes = [f"{DOCKER_VOLUME}:/data"]
if DOCKER_DEVICES:
-devices = DOCKER_DEVICES.split(",")
+if DOCKER_DEVICES.lower() == "none":
+devices = []
+else:
+devices = DOCKER_DEVICES.strip().split(",")
visible = os.getenv("ROCR_VISIBLE_DEVICES")
if visible:
env["ROCR_VISIBLE_DEVICES"] = visible
device_requests = []
+if not devices:
+devices = None
+elif devices == ["nvidia.com/gpu=all"]:
+devices = None
+device_requests = [
+docker.types.DeviceRequest(
+driver="cdi",
+# count=gpu_count,
+device_ids=[f"nvidia.com/gpu={i}"],
+)
+for i in range(gpu_count)
+]
else:
-devices = []
+devices = None
device_requests = [
docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]])
]
@ -540,21 +556,26 @@ def launcher(event_loop):
shm_size="1G",
)
-yield ContainerLauncherHandle(client, container.name, port)
-if not use_flash_attention:
-del env["USE_FLASH_ATTENTION"]
-try:
-container.stop()
-container.wait()
-except NotFound:
-pass
-container_output = container.logs().decode("utf-8")
-print(container_output, file=sys.stderr)
-container.remove()
+try:
+yield ContainerLauncherHandle(client, container.name, port)
+if not use_flash_attention:
+del env["USE_FLASH_ATTENTION"]
+try:
+container.stop()
+container.wait()
+except NotFound:
+pass
+container_output = container.logs().decode("utf-8")
+print(container_output, file=sys.stderr)
+finally:
+try:
+container.remove()
+except Exception:
+pass
if DOCKER_IMAGE is not None:
return docker_launcher
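For reference, the device handling added to the `launcher` fixture can be summarized as the sketch below, assuming the Docker SDK for Python (`docker` package); `gpu_count` and the environment value are placeholders:

```python
import docker

def resolve_devices(docker_devices, gpu_count):
    """Sketch of the fixture's device selection; not the fixture itself."""
    devices, device_requests = None, []
    if docker_devices:
        if docker_devices.lower() == "none":
            devices = []  # e.g. intel-cpu runner: no device passthrough at all
        else:
            devices = docker_devices.strip().split(",")  # e.g. "/dev/kfd,/dev/dri" for ROCm
        if not devices:
            devices = None
        elif devices == ["nvidia.com/gpu=all"]:
            # CDI naming: request each GPU individually instead of raw device paths
            devices = None
            device_requests = [
                docker.types.DeviceRequest(driver="cdi", device_ids=[f"nvidia.com/gpu={i}"])
                for i in range(gpu_count)
            ]
    else:
        # Default CUDA path: generic GPU capability request
        device_requests = [
            docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]])
        ]
    return devices, device_requests

# devices, device_requests = resolve_devices(os.environ.get("DOCKER_DEVICES", ""), 2)
# client.containers.run(..., devices=devices, device_requests=device_requests)
```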

View File

@ -0,0 +1,104 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1824,
"logprob": -12.296875,
"text": "What"
},
{
"id": 349,
"logprob": -0.97216797,
"text": "is"
},
{
"id": 3534,
"logprob": -10.1796875,
"text": "deep"
},
{
"id": 5168,
"logprob": -0.9658203,
"text": "learning"
},
{
"id": 28804,
"logprob": -0.44384766,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -0.50878906,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.8876953,
"special": false,
"text": "\n"
},
{
"id": 23229,
"logprob": -0.15124512,
"special": false,
"text": "Deep"
},
{
"id": 5168,
"logprob": -0.030288696,
"special": false,
"text": " learning"
},
{
"id": 349,
"logprob": -0.16687012,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.17858887,
"special": false,
"text": " a"
},
{
"id": 19804,
"logprob": -0.8046875,
"special": false,
"text": " subset"
},
{
"id": 302,
"logprob": -0.007205963,
"special": false,
"text": " of"
},
{
"id": 5599,
"logprob": -0.090026855,
"special": false,
"text": " machine"
},
{
"id": 5168,
"logprob": -0.0030670166,
"special": false,
"text": " learning"
}
],
"top_tokens": null
},
"generated_text": "\n\nDeep learning is a subset of machine learning"
}

View File

@ -0,0 +1,99 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 349,
"logprob": -13.921875,
"text": "is"
},
{
"id": 3534,
"logprob": -11.2265625,
"text": "deep"
},
{
"id": 5168,
"logprob": -2.3886719,
"text": "learning"
},
{
"id": 28804,
"logprob": -4.7109375,
"text": "?"
}
],
"seed": 0,
"tokens": [
{
"id": 13,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 23229,
"logprob": -0.5229492,
"special": false,
"text": "Deep"
},
{
"id": 17504,
"logprob": 0.0,
"special": false,
"text": " Learning"
},
{
"id": 349,
"logprob": -0.5151367,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": 0.0,
"special": false,
"text": " a"
},
{
"id": 19804,
"logprob": 0.0,
"special": false,
"text": " subset"
},
{
"id": 302,
"logprob": 0.0,
"special": false,
"text": " of"
},
{
"id": 13253,
"logprob": -1.3359375,
"special": false,
"text": " Machine"
},
{
"id": 17504,
"logprob": 0.0,
"special": false,
"text": " Learning"
},
{
"id": 28725,
"logprob": 0.0,
"special": false,
"text": ","
}
],
"top_tokens": null
},
"generated_text": "What is deep learning?\nDeep Learning is a subset of Machine Learning,"
}

View File

@ -0,0 +1,418 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1824,
"logprob": -12.296875,
"text": "What"
},
{
"id": 349,
"logprob": -0.97216797,
"text": "is"
},
{
"id": 3534,
"logprob": -10.1796875,
"text": "deep"
},
{
"id": 5168,
"logprob": -0.9658203,
"text": "learning"
},
{
"id": 28804,
"logprob": -0.44384766,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -0.50878906,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.8876953,
"special": false,
"text": "\n"
},
{
"id": 23229,
"logprob": -0.15136719,
"special": false,
"text": "Deep"
},
{
"id": 5168,
"logprob": -0.030273438,
"special": false,
"text": " learning"
},
{
"id": 349,
"logprob": -0.1665039,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.1776123,
"special": false,
"text": " a"
},
{
"id": 19804,
"logprob": -0.8076172,
"special": false,
"text": " subset"
},
{
"id": 302,
"logprob": -0.007183075,
"special": false,
"text": " of"
},
{
"id": 5599,
"logprob": -0.090148926,
"special": false,
"text": " machine"
},
{
"id": 5168,
"logprob": -0.0030670166,
"special": false,
"text": " learning"
}
],
"top_tokens": null
},
"generated_text": "\n\nDeep learning is a subset of machine learning"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1824,
"logprob": -12.34375,
"text": "What"
},
{
"id": 349,
"logprob": -0.96728516,
"text": "is"
},
{
"id": 3534,
"logprob": -10.1796875,
"text": "deep"
},
{
"id": 5168,
"logprob": -0.97265625,
"text": "learning"
},
{
"id": 28804,
"logprob": -0.44189453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -0.51220703,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.87402344,
"special": false,
"text": "\n"
},
{
"id": 23229,
"logprob": -0.15039062,
"special": false,
"text": "Deep"
},
{
"id": 5168,
"logprob": -0.030288696,
"special": false,
"text": " learning"
},
{
"id": 349,
"logprob": -0.1652832,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.17858887,
"special": false,
"text": " a"
},
{
"id": 19804,
"logprob": -0.81103516,
"special": false,
"text": " subset"
},
{
"id": 302,
"logprob": -0.007183075,
"special": false,
"text": " of"
},
{
"id": 5599,
"logprob": -0.08880615,
"special": false,
"text": " machine"
},
{
"id": 5168,
"logprob": -0.0030612946,
"special": false,
"text": " learning"
}
],
"top_tokens": null
},
"generated_text": "\n\nDeep learning is a subset of machine learning"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1824,
"logprob": -12.34375,
"text": "What"
},
{
"id": 349,
"logprob": -0.96728516,
"text": "is"
},
{
"id": 3534,
"logprob": -10.1796875,
"text": "deep"
},
{
"id": 5168,
"logprob": -0.97265625,
"text": "learning"
},
{
"id": 28804,
"logprob": -0.44189453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -0.51220703,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.87402344,
"special": false,
"text": "\n"
},
{
"id": 23229,
"logprob": -0.15039062,
"special": false,
"text": "Deep"
},
{
"id": 5168,
"logprob": -0.030288696,
"special": false,
"text": " learning"
},
{
"id": 349,
"logprob": -0.1652832,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.17858887,
"special": false,
"text": " a"
},
{
"id": 19804,
"logprob": -0.81103516,
"special": false,
"text": " subset"
},
{
"id": 302,
"logprob": -0.007183075,
"special": false,
"text": " of"
},
{
"id": 5599,
"logprob": -0.08880615,
"special": false,
"text": " machine"
},
{
"id": 5168,
"logprob": -0.0030612946,
"special": false,
"text": " learning"
}
],
"top_tokens": null
},
"generated_text": "\n\nDeep learning is a subset of machine learning"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1824,
"logprob": -12.34375,
"text": "What"
},
{
"id": 349,
"logprob": -0.96728516,
"text": "is"
},
{
"id": 3534,
"logprob": -10.1796875,
"text": "deep"
},
{
"id": 5168,
"logprob": -0.97265625,
"text": "learning"
},
{
"id": 28804,
"logprob": -0.44189453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -0.51220703,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.87402344,
"special": false,
"text": "\n"
},
{
"id": 23229,
"logprob": -0.15039062,
"special": false,
"text": "Deep"
},
{
"id": 5168,
"logprob": -0.030288696,
"special": false,
"text": " learning"
},
{
"id": 349,
"logprob": -0.1652832,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.17858887,
"special": false,
"text": " a"
},
{
"id": 19804,
"logprob": -0.81103516,
"special": false,
"text": " subset"
},
{
"id": 302,
"logprob": -0.007183075,
"special": false,
"text": " of"
},
{
"id": 5599,
"logprob": -0.08880615,
"special": false,
"text": " machine"
},
{
"id": 5168,
"logprob": -0.0030612946,
"special": false,
"text": " learning"
}
],
"top_tokens": null
},
"generated_text": "\n\nDeep learning is a subset of machine learning"
}
]

View File

@ -1,38 +1,26 @@
{
"choices": [
{
-"finish_reason": "eos_token",
+"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
-"content": null,
+"content": "I am an AI assistant",
"name": null,
"role": "assistant",
-"tool_calls": [
-{
-"function": {
-"arguments": {
-"error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options."
-},
-"description": null,
-"name": "notify_error"
-},
-"id": 0,
-"type": "function"
-}
-]
+"tool_calls": null
},
"usage": null
}
],
-"created": 1712852597,
+"created": 1728497062,
"id": "",
-"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+"model": "meta-llama/Llama-3.1-8B-Instruct",
-"object": "text_completion",
+"object": "chat.completion",
-"system_fingerprint": "1.4.5-native",
+"system_fingerprint": "2.3.2-dev0-native",
"usage": {
-"completion_tokens": 39,
+"completion_tokens": 23,
-"prompt_tokens": 496,
+"prompt_tokens": 604,
-"total_tokens": 535
+"total_tokens": 627
}
}

View File

@ -0,0 +1,20 @@
{
"choices": [
{
"delta": {
"content": " assistant",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1728497531,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.3.2-dev0-native",
"usage": null
}

View File

@ -0,0 +1,20 @@
{
"choices": [
{
"delta": {
"content": " fans",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1728497461,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.3.2-dev0-native",
"usage": null
}

View File

@ -16,7 +16,7 @@ async def flash_gemma(flash_gemma_handle):
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
-async def test_flash_gemma(flash_gemma, response_snapshot):
+async def test_flash_gemma_simple(flash_gemma, response_snapshot):
response = await flash_gemma.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)

View File

@ -15,7 +15,7 @@ async def flash_llama(flash_llama_handle):
@pytest.mark.asyncio
@pytest.mark.private
-async def test_flash_llama(flash_llama, response_snapshot):
+async def test_flash_llama_simple(flash_llama, response_snapshot):
response = await flash_llama.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)

View File

@ -0,0 +1,73 @@
import pytest
@pytest.fixture(scope="module")
def flash_mixtral_awq_handle(launcher):
with launcher("casperhansen/mixtral-instruct-awq", num_shard=2) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_mixtral_awq(flash_mixtral_awq_handle):
await flash_mixtral_awq_handle.health(300)
return flash_mixtral_awq_handle.client
@pytest.mark.asyncio
async def test_flash_mixtral_awq(flash_mixtral_awq, response_snapshot):
response = await flash_mixtral_awq.generate(
"What is deep learning?", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert (
response.generated_text == "\n\nDeep learning is a subset of machine learning"
)
assert response == response_snapshot
@pytest.mark.asyncio
async def test_flash_mixtral_awq_all_params(flash_mixtral_awq, response_snapshot):
response = await flash_mixtral_awq.generate(
"What is deep learning?",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["test"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 10
assert (
response.generated_text
== "What is deep learning?\nDeep Learning is a subset of Machine Learning,"
)
assert response == response_snapshot
@pytest.mark.asyncio
async def test_flash_mixtral_awq_load(
flash_mixtral_awq, generate_load, response_snapshot
):
responses = await generate_load(
flash_mixtral_awq, "What is deep learning?", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert responses[0].details.generated_tokens == 10
assert (
responses[0].generated_text
== "\n\nDeep learning is a subset of machine learning"
)
assert all(
[r.generated_text == responses[0].generated_text for r in responses]
), f"{[r.generated_text for r in responses]}"
assert responses == response_snapshot

View File

@ -207,11 +207,20 @@ async def test_flash_llama_grammar_tools_stream(
)
count = 0
+tool_calls_generated = ""
+last_response = None
async for response in responses:
count += 1
+tool_calls_generated += response.choices[0].delta.tool_calls.function.arguments
+last_response = response
+assert response.choices[0].delta.content is None
+assert (
+tool_calls_generated
+== '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "Paris, France"}}<|eot_id|>'
+)
assert count == 28
-assert response == response_snapshot
+assert last_response == response_snapshot
@pytest.mark.asyncio
@ -227,18 +236,94 @@ async def test_flash_llama_grammar_tools_insufficient_information(
messages=[
{
"role": "system",
-"content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
+"content": "You're a helpful assistant! Answer the users question best you can.",
+},
+{
+"role": "user",
+"content": "Who are you?",
+},
+],
+stream=False,
+)
+assert responses.choices[0].message.tool_calls is None
+assert responses.choices[0].message.content == "I am an AI assistant"
+assert responses == response_snapshot
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_grammar_tools_insufficient_information_stream(
+flash_llama_grammar_tools, response_snapshot
+):
+responses = await flash_llama_grammar_tools.chat(
+max_tokens=100,
+seed=24,
+tools=tools,
+tool_choice="auto",
+messages=[
+{
+"role": "system",
+"content": "You're a helpful assistant! Answer the users question best you can.",
+},
+{
+"role": "user",
+"content": "Who are you?",
+},
+],
+stream=True,
+)
+count = 0
+content_generated = ""
+last_response = None
+async for response in responses:
+count += 1
+content_generated += response.choices[0].delta.content
+last_response = response
+assert response.choices[0].delta.tool_calls is None
+assert count == 5
+assert content_generated == "I am an AI assistant"
+assert last_response == response_snapshot
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_grammar_tools_sea_creatures_stream(
+flash_llama_grammar_tools, response_snapshot
+):
+responses = await flash_llama_grammar_tools.chat(
+max_tokens=100,
+seed=24,
+tools=tools,
+tool_choice="auto",
+messages=[
+{
+"role": "system",
+"content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.",
},
{
"role": "user",
"content": "Tell me a story about 3 sea creatures",
},
],
-stream=False,
+stream=True,
)
-assert responses.choices[0].message.content is None
+count = 0
+content_generated = ""
+last_response = None
+async for response in responses:
+count += 1
+content_generated += response.choices[0].delta.content
+last_response = response
+assert response.choices[0].delta.tool_calls is None
+assert count == 62
assert (
-responses.choices[0].message.tool_calls[0]["function"]["name"] == "notify_error"
+content_generated
+== "Once upon a time, in the ocean, there lived three sea creatures. There was a wise old octopus named Bob, a mischievous seagull named Sam, and a gentle sea turtle named Luna. They all lived together in a beautiful coral reef, surrounded by colorful fish and swaying sea fans"
)
-assert responses == response_snapshot
+assert last_response == response_snapshot
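The new streaming assertions can also be reproduced against a running server through TGI's OpenAI-compatible Messages API. A hedged sketch using the `openai` client; the base URL, model alias and exact chunk counts are assumptions, and the repository's own integration-test client (used above) remains the reference:

```python
# Hedged sketch: stream a chat completion from a local TGI instance and watch
# whether chunks carry plain content or tool_calls, as the tests above assert.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8080/v1", api_key="-")

stream = client.chat.completions.create(
    model="tgi",  # TGI serves whatever model it was launched with
    messages=[
        {"role": "system", "content": "You're a helpful assistant! Answer the users question best you can."},
        {"role": "user", "content": "Who are you?"},
    ],
    # tools=[...] and tool_choice="auto" could be added here, as in the tests above
    max_tokens=100,
    stream=True,
)

for chunk in stream:
    delta = chunk.choices[0].delta
    # With tool_choice="auto" a chunk carries either delta.content or
    # delta.tool_calls, not both; here we only print plain content.
    if delta.content:
        print(delta.content, end="")
```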

View File

@ -13,3 +13,6 @@ pytest = "^7.4.0"
pytest-asyncio = "^0.21.1"
docker = "^7"
numpy = "^1.20"
+[tool.isort]
+profile = "black"

View File

@ -944,17 +944,19 @@ fn shard_manager(
}
});
// We read stdin in another thread as it seems that lines() can block in some cases
-thread::spawn(move || {
-let mut stdin = io::stdin(); // We get `Stdin` here.
-loop {
-let mut buffer = vec![0; 4096];
-if let Ok(n) = stdin.read(&mut buffer) {
-if n > 0 {
-let _ = pstdin.write_all(&buffer[..n]);
-}
-}
-}
-});
+if LevelFilter::current() >= tracing::Level::DEBUG {
+thread::spawn(move || {
+let mut stdin = io::stdin(); // We get `Stdin` here.
+loop {
+let mut buffer = vec![0; 4096];
+if let Ok(n) = stdin.read(&mut buffer) {
+if n > 0 {
+let _ = pstdin.write_all(&buffer[..n]);
+}
+}
+}
+});
+}
let mut ready = false;
let start_time = Instant::now();

View File

@ -1,5 +1,7 @@
{
mkShell,
+black,
+isort,
openssl,
pkg-config,
protobuf,
@ -14,6 +16,8 @@
mkShell {
buildInputs =
[
+black
+isort
openssl.dev
pkg-config
(rust-bin.stable.latest.default.override {

View File

@ -355,6 +355,8 @@ pub enum InferError {
MissingTemplateVariable(String),
#[error("Tool error: {0}")]
ToolError(String),
+#[error("Stream event serialization error")]
+StreamSerializationError(String),
}
impl InferError {
@ -368,6 +370,7 @@ impl InferError {
InferError::TemplateError(_) => "template_error",
InferError::MissingTemplateVariable(_) => "missing_template_variable",
InferError::ToolError(_) => "tool_error",
+InferError::StreamSerializationError(_) => "stream_serialization_error",
}
}
}

View File

@ -31,32 +31,29 @@ impl ToolGrammar {
let mut tools = tools.clone();
-// add the notify_error function to the tools
-let notify_error = Tool {
+// add the no_tool function to the tools
+let no_tool = Tool {
r#type: "function".to_string(),
function: FunctionDefinition {
-name: "notify_error".to_string(),
-description: Some("Notify an error or issue".to_string()),
+name: "no_tool".to_string(),
+description: Some("Open ened response with no specific tool selected".to_string()),
arguments: json!({
"type": "object",
"properties": {
-"error": {
+"content": {
"type": "string",
-"description": "The error or issue to notify"
+"description": "The response content",
}
},
-"required": ["error"]
+"required": ["content"]
}),
},
};
-tools.push(notify_error);
+tools.push(no_tool);
// if tools are provided and no tool_choice we default to the OneOf
let tools_to_use = match tool_choice {
-ToolType::FunctionName(name) => {
-vec![Self::find_tool_by_name(&tools, &name)?]
-}
-ToolType::Function { function } => {
+ToolType::Function(function) => {
vec![Self::find_tool_by_name(&tools, &function.name)?]
}
ToolType::OneOf => tools.clone(),
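For clarity, the fallback tool injected above corresponds to the JSON fragment below (transcribed from the `json!` block; field names follow the diff). When the grammar-constrained model selects it instead of a user-supplied tool, the raw generation looks roughly like the string asserted in the streaming test earlier in this commit:

```python
# Equivalent JSON view of the injected fallback tool; purely illustrative.
no_tool = {
    "type": "function",
    "function": {
        "name": "no_tool",
        "description": "Open ened response with no specific tool selected",
        "arguments": {
            "type": "object",
            "properties": {
                "content": {
                    "type": "string",
                    "description": "The response content",
                }
            },
            "required": ["content"],
        },
    },
}

# Roughly what a "no tool" generation looks like on the wire (see the tool_calls
# streaming test in this commit for the analogous get_current_weather case):
# {"function": {"_name": "no_tool", "content": "I am an AI assistant"}}
```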

View File

@ -957,12 +957,18 @@ pub fn default_tool_prompt() -> String {
}
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)]
-#[serde(untagged)]
+#[schema(example = "auto")]
+/// Controls which (if any) tool is called by the model.
pub enum ToolType {
+/// Means the model can pick between generating a message or calling one or more tools.
+#[schema(rename = "auto")]
OneOf,
-FunctionName(String),
-Function { function: FunctionName },
+/// Means the model will not call any tool and instead generates a message.
+#[schema(rename = "none")]
NoTool,
+/// Forces the model to call a specific tool.
+#[schema(rename = "function")]
+Function(FunctionName),
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)]
@ -977,6 +983,7 @@ pub struct ToolChoice(pub Option<ToolType>);
#[derive(Deserialize)]
#[serde(untagged)]
enum ToolTypeDeserializer {
+Null,
String(String),
ToolType(ToolType),
}
@ -984,10 +991,11 @@ enum ToolTypeDeserializer {
impl From<ToolTypeDeserializer> for ToolChoice {
fn from(value: ToolTypeDeserializer) -> Self {
match value {
+ToolTypeDeserializer::Null => ToolChoice(None),
ToolTypeDeserializer::String(s) => match s.as_str() {
"none" => ToolChoice(Some(ToolType::NoTool)),
"auto" => ToolChoice(Some(ToolType::OneOf)),
-_ => ToolChoice(Some(ToolType::FunctionName(s))),
+_ => ToolChoice(Some(ToolType::Function(FunctionName { name: s }))),
},
ToolTypeDeserializer::ToolType(tool_type) => ToolChoice(Some(tool_type)),
}
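A small summary of how the request-side `tool_choice` values map onto these variants after the change; it restates the deserializer above rather than adding any new behavior:

```python
# Summary of the tool_choice mapping implemented by the deserializer above.
examples = [
    (None, "ToolChoice(None)  -- Null variant, nothing forced"),
    ("auto", "ToolType::OneOf   -- pick between answering and calling a tool"),
    ("none", "ToolType::NoTool  -- answer without calling any tool"),
    ("get_current_weather", "ToolType::Function(FunctionName { name }) -- bare string is a function name"),
]
for value, meaning in examples:
    print(f"tool_choice={value!r:24} -> {meaning}")
```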

View File

@ -42,6 +42,7 @@ use hf_hub::{Cache, Repo, RepoType};
use http::header::AUTHORIZATION;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use pyo3::types::IntoPyDict;
+use regex::Regex;
use serde_json::Value;
use std::convert::Infallible;
use std::fs::File;
@ -452,12 +453,20 @@ async fn generate_stream(
Sse<impl Stream<Item = Result<Event, Infallible>>>,
) {
let span = tracing::Span::current();
-let on_message_callback = |stream_token: StreamResponse| {
-let event = Event::default();
-event.json_data(stream_token).unwrap()
-};
let (headers, response_stream) =
-generate_stream_internal(infer, compute_type, Json(req), on_message_callback, span).await;
+generate_stream_internal(infer, compute_type, Json(req), span).await;
+let response_stream = async_stream::stream! {
+let mut response_stream = Box::pin(response_stream);
+while let Some(raw_event) = response_stream.next().await {
+yield Ok(raw_event.map_or_else(Event::from, |token| {
+Event::default()
+.json_data(token)
+.unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into())
+}));
+}
+};
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
(headers, sse)
}
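With this change, `generate_stream_internal` yields `Result<StreamResponse, InferError>` items and serialization failures are turned into SSE error events instead of panicking. A hedged client-side sketch of consuming that stream; the URL, payload and field names are assumptions based on the documented `/generate_stream` API:

```python
# Hedged sketch: read the SSE stream from a local TGI /generate_stream endpoint.
import json
import requests

resp = requests.post(
    "http://localhost:8080/generate_stream",
    json={"inputs": "What is deep learning?", "parameters": {"max_new_tokens": 10}},
    stream=True,
)
for line in resp.iter_lines():
    if not line or not line.startswith(b"data:"):
        continue
    payload = json.loads(line[len(b"data:"):])
    # Both regular StreamResponse tokens and serialization failures now reach
    # the client as events rather than tearing down the connection.
    if "error" in payload:
        print("error event:", payload["error"])
    elif not payload.get("token", {}).get("special"):
        print(payload["token"]["text"], end="")
```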
@ -466,9 +475,11 @@ async fn generate_stream_internal(
infer: Infer,
ComputeType(compute_type): ComputeType,
Json(req): Json<GenerateRequest>,
-on_message_callback: impl Fn(StreamResponse) -> Event,
span: tracing::Span,
-) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
+) -> (
+HeaderMap,
+impl Stream<Item = Result<StreamResponse, InferError>>,
+) {
let start_time = Instant::now();
metrics::counter!("tgi_request_count").increment(1);
@ -500,12 +511,12 @@ async fn generate_stream_internal(
let err = InferError::from(ValidationError::BestOfStream);
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}");
-yield Ok(Event::from(err));
+yield Err(err);
} else if req.parameters.decoder_input_details {
let err = InferError::from(ValidationError::PrefillDetailsStream);
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}");
-yield Ok(Event::from(err));
+yield Err(err);
} else {
match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
// Keep permit as long as generate_stream lives
@ -535,8 +546,7 @@ async fn generate_stream_internal(
generated_text: None, generated_text: None,
details: None, details: None,
}; };
let event = on_message_callback(stream_token); yield Ok(stream_token);
yield Ok(event);
} }
// Yield event for last token and compute timings // Yield event for last token and compute timings
InferStreamResponse::End { InferStreamResponse::End {
@ -600,9 +610,7 @@ async fn generate_stream_internal(
details details
}; };
yield Ok(stream_token);
let event = on_message_callback(stream_token);
yield Ok(event);
break; break;
} }
} }
@ -610,7 +618,7 @@ async fn generate_stream_internal(
// yield error // yield error
Err(err) => { Err(err) => {
error = true; error = true;
yield Ok(Event::from(err)); yield Err(err);
break; break;
} }
} }
@ -619,7 +627,7 @@ async fn generate_stream_internal(
// yield error // yield error
Err(err) => { Err(err) => {
error = true; error = true;
yield Ok(Event::from(err)); yield Err(err);
} }
} }
// Check if generation reached the end // Check if generation reached the end
@ -628,7 +636,7 @@ async fn generate_stream_internal(
let err = InferError::IncompleteGenerationStream; let err = InferError::IncompleteGenerationStream;
metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1); metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1);
tracing::error!("{err}"); tracing::error!("{err}");
yield Ok(Event::from(err)); yield Err(err);
} }
} }
}; };
@ -771,75 +779,85 @@ async fn completions(
// Create a future for each generate_stream_internal call. // Create a future for each generate_stream_internal call.
let generate_future = async move { let generate_future = async move {
let on_message_callback = move |stream_token: StreamResponse| {
let event = Event::default();
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
let message = match stream_token.details {
Some(details) => {
let completion_tokens = details.generated_tokens;
let prompt_tokens = details.input_length;
let total_tokens = prompt_tokens + completion_tokens;
Completion::Final(CompletionFinal {
id: String::new(),
created: current_time,
model: model_id.clone(),
system_fingerprint: system_fingerprint.clone(),
choices: vec![CompletionComplete {
finish_reason: details.finish_reason.to_string(),
index: index as u32,
logprobs: None,
text: stream_token.token.text,
}],
usage: Usage {
prompt_tokens,
completion_tokens,
total_tokens,
},
})
}
None => Completion::Chunk(Chunk {
id: String::new(),
created: current_time,
choices: vec![CompletionComplete {
finish_reason: String::new(),
index: index as u32,
logprobs: None,
text: stream_token.token.text,
}],
model: model_id.clone(),
system_fingerprint: system_fingerprint.clone(),
}),
};
event
.json_data(message)
.unwrap_or_else(|_e| Event::default())
};
let (header_tx, header_rx) = oneshot::channel(); let (header_tx, header_rx) = oneshot::channel();
let (sse_tx, sse_rx) = tokio::sync::mpsc::unbounded_channel(); let (sse_tx, sse_rx) = tokio::sync::mpsc::unbounded_channel();
tokio::spawn(async move { tokio::spawn(async move {
-let (header_map, sse) = generate_stream_internal(
+let (headers, response_stream) = generate_stream_internal(
infer_clone.clone(),
compute_type_clone.clone(),
Json(generate_request),
-on_message_callback,
span_clone.clone(),
)
.await;
let response_stream = async_stream::stream! {
let mut response_stream = Box::pin(response_stream);
while let Some(stream_token) = response_stream.next().await {
match stream_token {
Ok(stream_token) => {
let event = Event::default();
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
let message = match stream_token.details {
Some(details) => {
let completion_tokens = details.generated_tokens;
let prompt_tokens = details.input_length;
let total_tokens = prompt_tokens + completion_tokens;
Completion::Final(CompletionFinal {
id: String::new(),
created: current_time,
model: model_id.clone(),
system_fingerprint: system_fingerprint.clone(),
choices: vec![CompletionComplete {
finish_reason: details.finish_reason.to_string(),
index: index as u32,
logprobs: None,
text: stream_token.token.text,
}],
usage: Usage {
prompt_tokens,
completion_tokens,
total_tokens,
},
})
}
None => Completion::Chunk(Chunk {
id: String::new(),
created: current_time,
choices: vec![CompletionComplete {
finish_reason: String::new(),
index: index as u32,
logprobs: None,
text: stream_token.token.text,
}],
model: model_id.clone(),
system_fingerprint: system_fingerprint.clone(),
}),
};
let event = event
.json_data(message)
.unwrap_or_else(|_e| Event::default());
yield Ok(event);
}
Err(err) => yield Ok(Event::from(err)),
}
}
};
// send and don't wait for response // send and don't wait for response
let _ = header_tx.send(header_map); let _ = header_tx.send(headers);
// pin and emit messages to the sse_tx // pin and emit messages to the sse_tx
let mut sse = Box::pin(sse); let mut sse = Box::pin(response_stream);
while let Some(event) = sse.next().await { while let Some(event) = sse.next().await {
if sse_tx.send(event).is_err() { if sse_tx.send(event).is_err() {
tracing::error!("Failed to send event. Receiver dropped."); tracing::error!("Failed to send event. Receiver dropped.");
@ -1072,6 +1090,84 @@ async fn completions(
} }
} }
enum StreamState {
Buffering,
BufferTrailing,
Content { skip_close_quote: bool },
}
/// Convert a StreamResponse into an Event to be sent over SSE
fn create_event_from_stream_token(
stream_token: &StreamResponse,
logprobs: bool,
stream_options: Option<StreamOptions>,
inner_using_tools: bool,
system_fingerprint: String,
model_id: String,
) -> Event {
let event = Event::default();
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
let logprobs = logprobs.then(|| {
ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens.clone()))
});
// replace the content with the tool calls if grammar is present
let (content, tool_calls) = if inner_using_tools {
(None, Some(vec![stream_token.token.text.clone()]))
} else {
let content = if !stream_token.token.special {
Some(stream_token.token.text.clone())
} else {
None
};
(content, None)
};
let (usage, finish_reason) = match &stream_token.details {
Some(details) => {
let usage = if stream_options
.as_ref()
.map(|s| s.include_usage)
.unwrap_or(false)
{
let completion_tokens = details.generated_tokens;
let prompt_tokens = details.input_length;
let total_tokens = prompt_tokens + completion_tokens;
Some(Usage {
completion_tokens,
prompt_tokens,
total_tokens,
})
} else {
None
};
(usage, Some(details.finish_reason.format(true)))
}
None => (None, None),
};
let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
model_id.clone(),
system_fingerprint.clone(),
content,
tool_calls,
current_time,
logprobs,
finish_reason,
usage,
));
event.json_data(chat_complete).unwrap_or_else(|e| {
println!("Failed to serialize ChatCompletionChunk: {:?}", e);
Event::default()
})
}
/// Generate tokens /// Generate tokens
#[utoipa::path( #[utoipa::path(
post, post,
@ -1128,88 +1224,135 @@ async fn chat_completions(
// static values that will be returned in all cases // static values that will be returned in all cases
let model_id = info.model_id.clone(); let model_id = info.model_id.clone();
let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native")); let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
// switch on stream // switch on stream
if stream { if stream {
// pass this callback to the stream generation and build the required event structure let (headers, response_stream) =
let on_message_callback = move |stream_token: StreamResponse| { generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
let event = Event::default();
let current_time = std::time::SystemTime::now() // regex to match any function name
.duration_since(std::time::UNIX_EPOCH) let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) {
.unwrap_or_else(|_| std::time::Duration::from_secs(0)) Ok(regex) => regex,
.as_secs(); Err(e) => {
return Err((
let logprobs = logprobs.then(|| { StatusCode::INTERNAL_SERVER_ERROR,
ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens)) Json(ErrorResponse {
}); error: format!("Failed to compile regex: {}", e),
error_type: "regex".to_string(),
// replace the content with the tool calls if grammar is present }),
let (content, tool_calls) = if using_tools {
(None, Some(vec![stream_token.token.text]))
} else {
let content = if !stream_token.token.special {
Some(stream_token.token.text)
} else {
None
};
(content, None)
};
let (usage, finish_reason) = match stream_token.details {
Some(details) => {
let usage = if stream_options
.as_ref()
.map(|s| s.include_usage)
.unwrap_or(false)
{
let completion_tokens = details.generated_tokens;
let prompt_tokens = details.input_length;
let total_tokens = prompt_tokens + completion_tokens;
Some(Usage {
completion_tokens,
prompt_tokens,
total_tokens,
})
} else {
None
};
(usage, Some(details.finish_reason.format(true)))
}
None => (None, None),
};
event
.json_data(CompletionType::ChatCompletionChunk(
ChatCompletionChunk::new(
model_id.clone(),
system_fingerprint.clone(),
content,
tool_calls,
current_time,
logprobs,
finish_reason,
usage,
),
)) ))
.unwrap_or_else(|e| { }
println!("Failed to serialize ChatCompletionChunk: {:?}", e);
Event::default()
})
}; };
let (headers, response_stream) = generate_stream_internal( let response_stream = async_stream::stream! {
infer, let mut response_stream = Box::pin(response_stream);
compute_type, let mut buffer = Vec::new();
Json(generate_request), let mut json_buffer = String::new();
on_message_callback, let mut state = if using_tools {
span, StreamState::Buffering
) } else {
.await; StreamState::Content {
skip_close_quote: false,
}
};
let mut response_as_tool = using_tools;
while let Some(result) = response_stream.next().await {
if let Ok(stream_token) = result {
let token_text = &stream_token.token.text.clone();
match state {
StreamState::Buffering => {
json_buffer.push_str(&token_text.replace(" ", ""));
buffer.push(stream_token);
if let Some(captures) = function_regex.captures(&json_buffer) {
let function_name = captures[1].to_string();
if function_name == "no_tool" {
state = StreamState::BufferTrailing;
response_as_tool = false;
buffer.clear();
json_buffer.clear();
} else {
state = StreamState::Content {
skip_close_quote: false,
};
// send all the buffered messages
for stream_token in &buffer {
let event = create_event_from_stream_token(
stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
);
yield Ok::<Event, Infallible>(event);
}
}
}
}
// if we skipped sending the buffer we need to avoid sending the following json key and quotes
StreamState::BufferTrailing => {
let infix_text = "\"content\":\"";
json_buffer.push_str(&token_text.replace(" ", ""));
// keep capturing until we find the infix text
match json_buffer.find(infix_text) {
Some(content_key_index) => {
json_buffer =
json_buffer[content_key_index + infix_text.len()..].to_string();
}
None => {
continue;
}
}
// if there is leftover text after removing the infix text, we need to send it
if !json_buffer.is_empty() {
let event = Event::default();
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
let chat_complete =
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
model_id.clone(),
system_fingerprint.clone(),
Some(json_buffer.clone()),
None,
current_time,
None,
None,
None,
));
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
InferError::StreamSerializationError(e.to_string()).into()
}));
}
// cleanup the buffers
buffer.clear();
json_buffer.clear();
state = StreamState::Content {
skip_close_quote: true,
};
}
StreamState::Content { skip_close_quote } => {
if skip_close_quote && token_text.contains('"') {
break;
}
let response_stream = response_stream.chain(futures::stream::once(async { // send the content
Ok(Event::default().data("[DONE]")) let event = create_event_from_stream_token(
})); &stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
);
yield Ok::<Event, Infallible>(event);
}
}
}
}
yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
};
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
Ok((headers, sse).into_response()) Ok((headers, sse).into_response())
@ -1246,17 +1389,33 @@ async fn chat_completions(
if let Value::Object(ref mut props) = arguments { if let Value::Object(ref mut props) = arguments {
props.remove("_name"); props.remove("_name");
} }
match name.as_str() {
let tool_calls = vec![ToolCall { "no_tool" => {
id: "0".to_string(), // parse the content message
r#type: "function".to_string(), let content_message = arguments
function: FunctionDefinition { .get("content")
description: None, .and_then(Value::as_str)
name, .ok_or_else(|| {
arguments, InferError::ToolError(
}, "No `content` found in generated text".to_string(),
}]; )
(Some(tool_calls), None) })?
.to_string();
(None, Some(content_message))
}
_ => {
let tool_calls = vec![ToolCall {
id: "0".to_string(),
r#type: "function".to_string(),
function: FunctionDefinition {
description: None,
name,
arguments,
},
}];
(Some(tool_calls), None)
}
}
} else { } else {
(None, Some(generation.generated_text)) (None, Some(generation.generated_text))
}; };
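A hedged, self-contained illustration of the two shapes of generated text this branch distinguishes, matching the `{"function":{"_name": ...}}` grammar used above; the field values and the `classify` helper are made up, and `serde_json` is assumed as a dependency:

```rust
use serde_json::Value;

// Simplified split of tool-grammar output into (tool name, plain content),
// mirroring how the non-streaming path treats `no_tool` vs. a real tool call.
fn classify(generated_text: &str) -> (Option<String>, Option<String>) {
    let value: Value = serde_json::from_str(generated_text).expect("tool grammar output is JSON");
    let function = &value["function"];
    let name = function["_name"].as_str().unwrap_or_default();
    if name == "no_tool" {
        // Open-ended answer: surface the `content` field as the chat message text.
        (None, function["content"].as_str().map(|s| s.to_string()))
    } else {
        // Real tool call: keep the function name; the remaining keys are its arguments.
        (Some(name.to_string()), None)
    }
}

fn main() {
    let tool_call = r#"{"function": {"_name": "get_current_weather", "location": "New York, NY", "format": "celsius"}}"#;
    let no_tool = r#"{"function": {"_name": "no_tool", "content": "It is sunny in New York."}}"#;
    println!("{:?}", classify(tool_call)); // (Some("get_current_weather"), None)
    println!("{:?}", classify(no_tool));   // (None, Some("It is sunny in New York."))
}
```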
@ -2323,6 +2482,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::MissingTemplateVariable(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::MissingTemplateVariable(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::StreamSerializationError(_) => StatusCode::INTERNAL_SERVER_ERROR,
}; };
( (
@ -2500,8 +2660,8 @@ mod tests {
); );
assert!(result.is_ok()); assert!(result.is_ok());
let (inputs, _grammar, using_tools) = result.unwrap(); let (inputs, _grammar, using_tools) = result.expect("Failed to prepare chat input");
assert_eq!(using_tools, true); assert_eq!(using_tools, true);
assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"error\":{\"description\":\"The error or issue to notify\",\"type\":\"string\"}},\"required\":[\"error\"],\"type\":\"object\"}, \"description\": \"Notify an error or issue\", \"name\": \"notify_error\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string()); assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"content\":{\"description\":\"The response content\",\"type\":\"string\"}},\"required\":[\"content\"],\"type\":\"object\"}, \"description\": \"Open ened response with no specific tool selected\", \"name\": \"no_tool\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
} }
} }


@ -1,5 +1,5 @@
[toolchain] [toolchain]
# Released on: June 13, 2024 # Released on: June 13, 2024
# https://releases.rs/docs/1.79.0/ # https://releases.rs/docs/1.79.0/
channel = "1.80.0" channel = "1.80.1"
components = ["rustfmt", "clippy"] components = ["rustfmt", "clippy"]

server/poetry.lock generated

@ -1269,12 +1269,12 @@ files = [
[[package]] [[package]]
name = "moe-kernels" name = "moe-kernels"
version = "0.4.0" version = "0.6.0"
description = "MoE kernels" description = "MoE kernels"
optional = true optional = true
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
{file = "moe_kernels-0.4.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:3fc0475bb3b9c09bbf08f6f6e9767d10eaba55b558f67a605fe70ae0cbb5e6a4"}, {file = "moe_kernels-0.6.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:f28fd2a56c3ac7bfe74bc44cc7c8c0791a2644ad689b084ea4ed6decb7f41c25"},
] ]
[package.dependencies] [package.dependencies]
@ -1284,16 +1284,16 @@ triton = "*"
[package.source] [package.source]
type = "url" type = "url"
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl" url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl"
[[package]] [[package]]
name = "moe-kernels" name = "moe-kernels"
version = "0.4.0" version = "0.6.0"
description = "MoE kernels" description = "MoE kernels"
optional = true optional = true
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
{file = "moe_kernels-0.4.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:8ca72a064ceb84a23a3437cc6e6363907ad41588877f6acb1febc010fc7beb22"}, {file = "moe_kernels-0.6.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:db475948fd9f7a8647aa3f73256ff4d3bb111425305bcd0b0d3559ccc75b8937"},
] ]
[package.dependencies] [package.dependencies]
@ -1303,16 +1303,16 @@ triton = "*"
[package.source] [package.source]
type = "url" type = "url"
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl" url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl"
[[package]] [[package]]
name = "moe-kernels" name = "moe-kernels"
version = "0.4.0" version = "0.6.0"
description = "MoE kernels" description = "MoE kernels"
optional = true optional = true
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
{file = "moe_kernels-0.4.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:d302d6b16bb4905b2312dc68da6a6f51e87d0cd3c4bf1f23d995501162399a8e"}, {file = "moe_kernels-0.6.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:364be07c06aafbab1f51d9e26d9a4ff658defe1462a4c645abaf7b895ed163a8"},
] ]
[package.dependencies] [package.dependencies]
@ -1322,16 +1322,16 @@ triton = "*"
[package.source] [package.source]
type = "url" type = "url"
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl" url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl"
[[package]] [[package]]
name = "moe-kernels" name = "moe-kernels"
version = "0.4.0" version = "0.6.0"
description = "MoE kernels" description = "MoE kernels"
optional = true optional = true
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
{file = "moe_kernels-0.4.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:6aee3e723efa5113c338b40e6cb20fa62da6c442c65c1a6cc97751d34158a93a"}, {file = "moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:81e7fa25fb5ed5336f5151994f5e3f600df7e166fe013576968c59415e442894"},
] ]
[package.dependencies] [package.dependencies]
@ -1341,7 +1341,7 @@ triton = "*"
[package.source] [package.source]
type = "url" type = "url"
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl" url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl"
[[package]] [[package]]
name = "mpmath" name = "mpmath"
@ -3402,11 +3402,6 @@ files = [
{file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"}, {file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"},
{file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"}, {file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"},
{file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"}, {file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"},
{file = "triton-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b052da883351fdf6be3d93cedae6db3b8e3988d3b09ed221bccecfa9612230"},
{file = "triton-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd34f19a8582af96e6291d4afce25dac08cb2a5d218c599163761e8e0827208e"},
{file = "triton-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d5e10de8c011adeb7c878c6ce0dd6073b14367749e34467f1cff2bde1b78253"},
{file = "triton-3.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8903767951bf86ec960b4fe4e21bc970055afc65e9d57e916d79ae3c93665e3"},
{file = "triton-3.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41004fb1ae9a53fcb3e970745feb87f0e3c94c6ce1ba86e95fa3b8537894bef7"},
] ]
[package.dependencies] [package.dependencies]


@ -47,10 +47,10 @@ marlin-kernels = [
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
] ]
moe-kernels = [ moe-kernels = [
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true }, { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true }, { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
] ]
rich = "^13.7.1" rich = "^13.7.1"
@ -82,3 +82,6 @@ requires = [
"poetry-core>=1.0.0", "poetry-core>=1.0.0",
] ]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"
[tool.isort]
profile = "black"


@ -68,5 +68,5 @@ else:
def clamp(self, max): def clamp(self, max):
if SYSTEM == "rocm": if SYSTEM == "rocm":
return self return self
raise NotImplementedError("Not implemented seqlen for paged") self.input_lengths = torch.clamp(self.input_lengths, max=max)
return Seqlen(torch.clamp(self.input_lengths, max=max)) return self


@ -24,10 +24,8 @@ class KVCache:
): ):
"""Construct the key-value cache for a layer.""" """Construct the key-value cache for a layer."""
if ( if dtype == torch.float8_e5m2 and (
dtype == torch.float8_e5m2 ATTENTION != "flashinfer" or SYSTEM != "cuda"
and (ATTENTION != "flashinfer"
or SYSTEM != "cuda")
): ):
raise ValueError( raise ValueError(
"float8_e5m2 KV cache is currently only supported for flashinfer on CUDA" "float8_e5m2 KV cache is currently only supported for flashinfer on CUDA"


@ -43,7 +43,7 @@ def can_use_gptq_marlin(
and quant_method in {"awq", "gptq"} and quant_method in {"awq", "gptq"}
and bits in GPTQ_MARLIN_BITS and bits in GPTQ_MARLIN_BITS
and groupsize in GPTQ_MARLIN_GROUP_SIZES and groupsize in GPTQ_MARLIN_GROUP_SIZES
# We only suppord asymmetric quantization for AWQ. # We only support asymmetric quantization for AWQ.
and (sym or quant_method == "awq") and (sym or quant_method == "awq")
) )


@ -210,11 +210,17 @@ class SparseMoELayer(nn.Module):
and isinstance(weights.loader.weight_class, UnquantizedWeight) and isinstance(weights.loader.weight_class, UnquantizedWeight)
) or isinstance(weights.loader, HybridFP8UnquantLoader): ) or isinstance(weights.loader, HybridFP8UnquantLoader):
cls = UnquantizedSparseMoELayer cls = UnquantizedSparseMoELayer
elif isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym: elif isinstance(
weights.loader, GPTQMarlinWeightsLoader
) and can_use_marlin_moe_gemm(
quant_method=weights.loader.quant_method,
quantize=weights.loader.quantize,
sym=weights.loader.sym,
):
cls = GPTQMarlinSparseMoELayer cls = GPTQMarlinSparseMoELayer
else: else:
raise ValueError( raise ValueError(
f"Unsupported weights loader: {weights.loader}, sparse MoE is only supported for unquantized and GPTQ weights" f"Unsupported weights loader: {type(weights.loader)}, sparse MoE is only supported for unquantized, AWQ, and GPTQ weights"
) )
log_once( log_once(


@ -34,9 +34,10 @@ def can_use_marlin_moe_gemm(
SYSTEM == "cuda" SYSTEM == "cuda"
and fused_marlin_moe is not None and fused_marlin_moe is not None
and has_sm_8_0 and has_sm_8_0
and quantize == "gptq" and quantize in {"awq", "gptq"}
and quant_method == "gptq" and quant_method in {"awq", "gptq"}
and sym # We only support asymmetric quantization for AWQ.
and (sym or quant_method == "awq")
) )
@ -72,10 +73,15 @@ class GPTQMarlinSparseMoELayer(nn.Module):
super().__init__() super().__init__()
if not ( if not (
isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym isinstance(weights.loader, GPTQMarlinWeightsLoader)
and can_use_marlin_moe_gemm(
quant_method=weights.loader.quant_method,
quantize=weights.loader.quantize,
sym=weights.loader.sym,
)
): ):
raise ValueError( raise ValueError(
f"Unsupported weights loader: {weights.loader}, only GPTQMarlinWeightsLoader with symmetric quantization is supported" f"Unsupported weights loader: {type(weights.loader)}, only GPTQMarlinWeightsLoader with AWQ and symmetric GPTQ quantization is supported"
) )
assert (n_expert_group is None) == ( assert (n_expert_group is None) == (
@ -102,17 +108,24 @@ class GPTQMarlinSparseMoELayer(nn.Module):
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
return fused_marlin_moe( return fused_marlin_moe(
x, hidden_states=x,
w1=self.gate_up_proj.qweight, w1=self.gate_up_proj.qweight,
w2=self.down_proj.qweight, w2=self.down_proj.qweight,
g_idx1=self.gate_up_proj.g_idx,
g_idx2=self.down_proj.g_idx,
perm1=self.gate_up_proj.perm,
perm2=self.down_proj.perm,
w1_scale=self.gate_up_proj.scales, w1_scale=self.gate_up_proj.scales,
w2_scale=self.down_proj.scales, w2_scale=self.down_proj.scales,
is_full_k1=self.gate_up_proj.is_full_k, w1_zeros=(
is_full_k2=self.down_proj.is_full_k, self.gate_up_proj.qzeros
if self.gate_up_proj.qzeros.numel() > 0
else None
),
w2_zeros=(
self.down_proj.qzeros if self.down_proj.qzeros.numel() > 0 else None
),
g_idx1=self.gate_up_proj.g_idx,
g_idx2=self.down_proj.g_idx,
sort_indices1=self.gate_up_proj.perm,
sort_indices2=self.down_proj.perm,
is_k_full=self.gate_up_proj.is_full_k or self.down_proj.is_full_k,
gating_output=gating_output, gating_output=gating_output,
topk=self.topk, topk=self.topk,
renormalize=self.renormalize, renormalize=self.renormalize,


@ -517,14 +517,13 @@ class CausalLM(Model):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = default_dtype if dtype is None else dtype dtype = default_dtype if dtype is None else dtype
elif hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = default_dtype if dtype is None else dtype
elif SYSTEM == "ipex": elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available(): device = torch.device("cpu")
device = torch.device(f"xpu:{rank}") # Float16 doesn't exist on target.
dtype = default_dtype if dtype is None else dtype dtype = torch.bfloat16 if dtype is None else dtype
else:
device = torch.device("cpu")
# Float16 doesn't exist on target.
dtype = torch.bfloat16 if dtype is None else dtype
else: else:
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype dtype = torch.float32 if dtype is None else dtype
@ -593,8 +592,14 @@ class CausalLM(Model):
if speculator: if speculator:
raise RuntimeError("Speculator decoding is not enabled for AutoModel") raise RuntimeError("Speculator decoding is not enabled for AutoModel")
device_count = 0
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
device_count = torch.cuda.device_count()
dtype = torch.float16 if dtype is None else dtype
elif hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device("xpu")
device_count = torch.xpu.device_count()
dtype = torch.float16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
else: else:
if quantize: if quantize:
@ -616,18 +621,17 @@ class CausalLM(Model):
torch_dtype=dtype, torch_dtype=dtype,
device_map=( device_map=(
"auto" "auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1 if device_count > 1
else None else None
), ),
load_in_8bit=quantize == "bitsandbytes", load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
if ( if (
torch.cuda.is_available() device_count == 1
and torch.cuda.device_count() == 1
and quantize != "bitsandbytes" and quantize != "bitsandbytes"
): ):
model = model.cuda() model = model.to(device)
if tokenizer.pad_token_id is None: if tokenizer.pad_token_id is None:
if model.config.pad_token_id is not None: if model.config.pad_token_id is not None:


@ -558,14 +558,13 @@ class Seq2SeqLM(Model):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = default_dtype if dtype is None else dtype dtype = default_dtype if dtype is None else dtype
elif hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = default_dtype if dtype is None else dtype
elif SYSTEM == "ipex": elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available(): device = torch.device("cpu")
device = torch.device(f"xpu:{rank}") # Float16 doesn't exist on target.
dtype = default_dtype if dtype is None else dtype dtype = torch.bfloat16 if dtype is None else dtype
else:
device = torch.device("cpu")
# Float16 doesn't exist on target.
dtype = torch.bfloat16 if dtype is None else dtype
else: else:
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype dtype = torch.float32 if dtype is None else dtype
@ -630,8 +629,14 @@ class Seq2SeqLM(Model):
if speculator: if speculator:
raise RuntimeError("Speculator decoding is not enabled for AutoModel") raise RuntimeError("Speculator decoding is not enabled for AutoModel")
device_count = 0
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
device_count = torch.cuda.device_count()
dtype = torch.float16 if dtype is None else dtype
elif hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device("xpu")
device_count = torch.xpu.device_count()
dtype = torch.float16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
else: else:
if quantize: if quantize:
@ -646,14 +651,14 @@ class Seq2SeqLM(Model):
torch_dtype=dtype, torch_dtype=dtype,
device_map=( device_map=(
"auto" "auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1 if device_count > 1
else None else None
), ),
load_in_8bit=quantize == "bitsandbytes", load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
if torch.cuda.is_available() and torch.cuda.device_count() == 1: if device_count == 1:
model = model.cuda() model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
model_id, model_id,


@ -66,6 +66,11 @@ elif is_ipex_available():
empty_cache = noop empty_cache = noop
synchronize = noop synchronize = noop
get_free_memory = get_cpu_free_memory get_free_memory = get_cpu_free_memory
elif hasattr(torch, "xpu") and torch.xpu.is_available():
SYSTEM = "xpu"
empty_cache = torch.xpu.empty_cache
synchronize = torch.xpu.synchronize
get_free_memory = get_xpu_free_memory
else: else:
SYSTEM = "cpu" SYSTEM = "cpu"


@ -5,14 +5,13 @@ import json
import os import os
TEMPLATE = """ TEMPLATE = """
# Supported Models and Hardware # Supported Models
Text Generation Inference enables serving optimized models on specific hardware for the highest performance. The following sections list which models (VLMs & LLMs) are supported. Text Generation Inference enables serving optimized models. The following sections list which models (VLMs & LLMs) are supported.
## Supported Models
SUPPORTED_MODELS SUPPORTED_MODELS
If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models: If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models:
```python ```python