Merge branch 'huggingface:main' into fix/dockerfile-triton

Commit 8ae92e5d70 by Yaser Jaradeh, 2025-02-03 11:48:01 +01:00, committed by GitHub.
GPG signature key ID: B5690EEEBB952194 (no known key found for this signature in database).
97 changed files with 7048 additions and 4947 deletions

View File

@ -1,75 +0,0 @@
ARG CUDA_ARCH_LIST="75-real;80-real;86-real;89-real;90-real"
ARG OMPI_VERSION="4.1.7rc1"
# Build dependencies resolver stage
FROM lukemathwalker/cargo-chef:latest AS chef
WORKDIR /usr/src/text-generation-inference/backends/trtllm
FROM chef AS planner
COPY . .
RUN cargo chef prepare --recipe-path recipe.json
# CUDA dependent dependencies resolver stage
FROM nvidia/cuda:12.6.3-cudnn-devel-ubuntu24.04 AS cuda-builder
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
apt update && apt install -y \
build-essential \
cmake \
curl \
gcc-14 \
g++-14 \
git \
git-lfs \
libssl-dev \
libucx-dev \
ninja-build \
pkg-config \
pipx \
python3 \
python3-dev \
python3-setuptools \
tar \
wget && \
pipx ensurepath
ENV TGI_INSTALL_PREFIX=/usr/local/tgi
ENV TENSORRT_INSTALL_PREFIX=/usr/local/tensorrt
# Install OpenMPI
FROM cuda-builder AS mpi-builder
ARG OMPI_VERSION
ENV OMPI_TARBALL_FILENAME="openmpi-$OMPI_VERSION.tar.bz2"
RUN wget "https://download.open-mpi.org/release/open-mpi/v4.1/$OMPI_TARBALL_FILENAME" -P /opt/src && \
mkdir /usr/src/mpi && \
tar -xf "/opt/src/$OMPI_TARBALL_FILENAME" -C /usr/src/mpi --strip-components=1 && \
cd /usr/src/mpi && \
./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --with-slurm && \
make -j all && \
make install && \
rm -rf "/opt/src/$OMPI_TARBALL_FILENAME"
# Install TensorRT
FROM cuda-builder AS trt-builder
COPY backends/trtllm/scripts/install_tensorrt.sh /opt/install_tensorrt.sh
RUN chmod +x /opt/install_tensorrt.sh && \
/opt/install_tensorrt.sh
# Build Backend
FROM cuda-builder AS tgi-builder
WORKDIR /usr/src/text-generation-inference
# Install Rust
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y && \
chmod -R a+w /root/.rustup && \
chmod -R a+w /root/.cargo
ENV PATH="/root/.cargo/bin:$PATH"
RUN cargo install cargo-chef
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
ENV MPI_HOME=/usr/local/mpi
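
The file removed above is the devcontainer variant of the TRT-LLM build image. Assuming it lived under .devcontainer/ next to the devcontainer.json removed in the next hunk (which points at "Dockerfile_trtllm" with the repository root as context), it could have been built standalone roughly as follows; the image tag is made up for illustration.

```
# Hedged sketch: building the removed devcontainer image by hand from the repository root.
# Build args mirror the ARG defaults above; the tag is illustrative.
docker build \
  -f .devcontainer/Dockerfile_trtllm \
  --build-arg CUDA_ARCH_LIST="75-real;80-real;86-real;89-real;90-real" \
  --build-arg OMPI_VERSION="4.1.7rc1" \
  -t tgi-trtllm-devcontainer:local \
  .
```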

View File

@ -1,19 +0,0 @@
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
// README at: https://github.com/devcontainers/templates/tree/main/src/cpp
{
"name": "CUDA",
"build": {
"dockerfile": "Dockerfile_trtllm",
"context": ".."
},
"remoteEnv": {
"PATH": "${containerEnv:PATH}:/usr/local/cuda/bin",
"LD_LIBRARY_PATH": "$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64",
"XLA_FLAGS": "--xla_gpu_cuda_data_dir=/usr/local/cuda"
},
"customizations" : {
"jetbrains" : {
"backend" : "CLion"
}
}
}

View File

@ -31,16 +31,28 @@ jobs:
group: ${{ github.workflow }}-build-and-push-image-${{ inputs.hardware }}-${{ github.head_ref || github.run_id }} group: ${{ github.workflow }}-build-and-push-image-${{ inputs.hardware }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true cancel-in-progress: true
runs-on: runs-on:
group: aws-highmemory-32-plus-priv group: aws-highmemory-64-plus-priv
permissions: permissions:
contents: write contents: write
packages: write packages: write
id-token: write
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Inject slug/short variables - name: Inject slug/short variables
uses: rlespinasse/github-slug-action@v4.4.1 uses: rlespinasse/github-slug-action@v4.4.1
- name: Construct harware variables - name: Inject required variables for sccache to interact with Github Actions Cache
uses: actions/github-script@v7
with:
script: |
core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || '');
core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || '');
- name: Extract TensorRT-LLM version
run: |
echo "TENSORRT_LLM_VERSION=$(grep -oP '([a-z,0-9]{40})' $GITHUB_WORKSPACE/backends/trtllm/cmake/trtllm.cmake)" >> $GITHUB_ENV
echo "TensorRT-LLM version: ${{ env.TENSORRT_LLM_VERSION }}"
- name: Construct hardware variables
shell: bash shell: bash
run: | run: |
case ${{ inputs.hardware }} in case ${{ inputs.hardware }} in
@ -52,6 +64,7 @@ jobs:
export runs_on="aws-g6-12xl-plus-priv-cache" export runs_on="aws-g6-12xl-plus-priv-cache"
export platform="" export platform=""
export extra_pytest="" export extra_pytest=""
export target=""
;; ;;
cuda-trtllm) cuda-trtllm)
export dockerfile="Dockerfile_trtllm" export dockerfile="Dockerfile_trtllm"
@ -61,15 +74,24 @@ jobs:
export runs_on="ubuntu-latest" export runs_on="ubuntu-latest"
export platform="" export platform=""
export extra_pytest="" export extra_pytest=""
if [[ "${GITHUB_REF}" == refs/tags/* ]]; then
export build_type="release";
export target="";
else
export build_type="dev";
export target="ci-runtime";
fi
;; ;;
rocm) rocm)
export dockerfile="Dockerfile_amd" export dockerfile="Dockerfile_amd"
export label_extension="-rocm" export label_extension="-rocm"
export docker_devices="/dev/kfd,/dev/dri" export docker_devices="/dev/kfd,/dev/dri"
export docker_volume="/mnt" export docker_volume="/mnt"
export runs_on="amd-gpu-runners" # This runner was deactivated.
export runs_on="ubuntu-latest"
export platform="" export platform=""
export extra_pytest="-k test_flash_gemma_gptq_load" export extra_pytest="-k test_flash_gemma_gptq_load"
export target=""
;; ;;
intel-xpu) intel-xpu)
export dockerfile="Dockerfile_intel" export dockerfile="Dockerfile_intel"
@ -79,6 +101,7 @@ jobs:
export runs_on="ubuntu-latest" export runs_on="ubuntu-latest"
export platform="xpu" export platform="xpu"
export extra_pytest="" export extra_pytest=""
export target=""
;; ;;
intel-cpu) intel-cpu)
export dockerfile="Dockerfile_intel" export dockerfile="Dockerfile_intel"
@ -89,6 +112,7 @@ jobs:
export runs_on="aws-highmemory-32-plus-priv" export runs_on="aws-highmemory-32-plus-priv"
export platform="cpu" export platform="cpu"
export extra_pytest="-k test_flash_gemma_simple" export extra_pytest="-k test_flash_gemma_simple"
export target=""
;; ;;
esac esac
echo $dockerfile echo $dockerfile
@ -105,6 +129,8 @@ jobs:
echo "RUNS_ON=${runs_on}" >> $GITHUB_ENV echo "RUNS_ON=${runs_on}" >> $GITHUB_ENV
echo "EXTRA_PYTEST=${extra_pytest}" >> $GITHUB_ENV echo "EXTRA_PYTEST=${extra_pytest}" >> $GITHUB_ENV
echo REGISTRY_MIRROR=$REGISTRY_MIRROR >> $GITHUB_ENV echo REGISTRY_MIRROR=$REGISTRY_MIRROR >> $GITHUB_ENV
echo "TARGET=${target}" >> $GITHUB_ENV
echo "BUILD_TYPE=${build_type}" >> $GITHUB_ENV
- name: Initialize Docker Buildx - name: Initialize Docker Buildx
uses: docker/setup-buildx-action@v3 uses: docker/setup-buildx-action@v3
with: with:
@ -169,10 +195,15 @@ jobs:
GIT_SHA=${{ env.GITHUB_SHA }} GIT_SHA=${{ env.GITHUB_SHA }}
DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }} DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }}
PLATFORM=${{ env.PLATFORM }} PLATFORM=${{ env.PLATFORM }}
build_type=${{ env.BUILD_TYPE }}
sccache_gha_enabled=on
actions_cache_url=${{ env.ACTIONS_CACHE_URL }}
actions_runtime_token=${{ env.ACTIONS_RUNTIME_TOKEN }}
target: ${{ env.TARGET }}
tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }} tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }} labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
cache-from: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min cache-from: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=max
cache-to: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min cache-to: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=max
- name: Final - name: Final
id: final id: final
run: | run: |
@ -214,3 +245,23 @@ jobs:
echo $DOCKER_IMAGE echo $DOCKER_IMAGE
docker pull $DOCKER_IMAGE docker pull $DOCKER_IMAGE
pytest -s -vv integration-tests ${PYTEST_FLAGS} ${EXTRA_PYTEST} pytest -s -vv integration-tests ${PYTEST_FLAGS} ${EXTRA_PYTEST}
backend_trtllm_cxx_tests:
needs: build-and-push
if: needs.build-and-push.outputs.label == '-trtllm'
concurrency:
group: ${{ github.workflow }}-${{ github.job }}-trtllm-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
runs-on:
group: aws-g6-12xl-plus-priv-cache
container:
image: ${{ needs.build-and-push.outputs.docker_image }}
credentials:
username: ${{ secrets.REGISTRY_USERNAME }}
password: ${{ secrets.REGISTRY_PASSWORD }}
options: --gpus all --shm-size=8g
steps:
- name: Run C++/CUDA tests
if: ${{ env.LABEL == 'ci-runtime' }}
run: /usr/local/tgi/bin/tgi_trtllm_backend_tests
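
Two of the new workflow steps are easy to reproduce outside CI. The TensorRT-LLM version extraction is a plain grep over the backend's CMake pin, and the hardware case-switch now also exports a target/build_type pair that is forwarded to the Docker build. A hedged local equivalent of the extraction step (run from a repository checkout; the output may be empty once the CMake file pins a tag such as v0.16.0 rather than a 40-character commit hash):

```
# Hedged sketch: the workflow's TensorRT-LLM version extraction, run locally.
TENSORRT_LLM_VERSION=$(grep -oP '([a-z,0-9]{40})' backends/trtllm/cmake/trtllm.cmake)
echo "TensorRT-LLM version: ${TENSORRT_LLM_VERSION}"
```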

View File

@ -42,6 +42,7 @@ jobs:
permissions: permissions:
contents: write contents: write
packages: write packages: write
id-token: write
with: with:
hardware: ${{ matrix.hardware }} hardware: ${{ matrix.hardware }}
# https://github.com/actions/runner/issues/2206 # https://github.com/actions/runner/issues/2206

View File

@ -31,7 +31,7 @@ jobs:
with: with:
# Released on: 02 May, 2024 # Released on: 02 May, 2024
# https://releases.rs/docs/1.78.0/ # https://releases.rs/docs/1.78.0/
toolchain: 1.80.0 toolchain: 1.84.0
override: true override: true
components: rustfmt, clippy components: rustfmt, clippy
- name: Install Protoc - name: Install Protoc
@ -44,10 +44,14 @@ jobs:
run: | run: |
sudo apt update sudo apt update
sudo apt install python3.11-dev -y sudo apt install python3.11-dev -y
pip install -U pip uv
uv venv
source ./.venv/bin/activate
make install-cpu make install-cpu
- name: Run server tests - name: Run server tests
run: | run: |
pip install pytest source ./.venv/bin/activate
uv pip install pytest
export HF_TOKEN=${{ secrets.HF_TOKEN }} export HF_TOKEN=${{ secrets.HF_TOKEN }}
pytest -s -vv server/tests pytest -s -vv server/tests
- name: Pre-commit checks - name: Pre-commit checks
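
The server-test job now builds inside a uv-managed virtual environment instead of installing into the runner's system Python. A hedged local equivalent of the updated steps (assuming Python 3.11 and a working `make install-cpu` on the machine; the token value is a placeholder):

```
# Hedged sketch: running the server tests the way the updated workflow does.
pip install -U pip uv
uv venv                             # creates ./.venv
source ./.venv/bin/activate
make install-cpu
uv pip install pytest
export HF_TOKEN=<your-read-token>   # placeholder for the workflow secret
pytest -s -vv server/tests
```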

Cargo.lock (generated), 35 changes
View File

@ -1,6 +1,6 @@
# This file is automatically @generated by Cargo. # This file is automatically @generated by Cargo.
# It is not intended for manual editing. # It is not intended for manual editing.
version = 3 version = 4
[[package]] [[package]]
name = "addr2line" name = "addr2line"
@ -1544,7 +1544,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
dependencies = [ dependencies = [
"ahash", "ahash",
"allocator-api2",
] ]
[[package]] [[package]]
@ -2187,9 +2186,9 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8"
[[package]] [[package]]
name = "libc" name = "libc"
version = "0.2.164" version = "0.2.169"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f" checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a"
[[package]] [[package]]
name = "libfuzzer-sys" name = "libfuzzer-sys"
@ -4424,14 +4423,14 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-backends-trtllm" name = "text-generation-backends-trtllm"
version = "3.0.2-dev0" version = "3.1.1-dev0"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"clap 4.5.21", "clap 4.5.21",
"cmake", "cmake",
"cxx", "cxx",
"cxx-build", "cxx-build",
"hashbrown 0.14.5", "hashbrown 0.15.1",
"hf-hub", "hf-hub",
"pkg-config", "pkg-config",
"pyo3", "pyo3",
@ -4445,7 +4444,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-benchmark" name = "text-generation-benchmark"
version = "3.0.2-dev0" version = "3.1.1-dev0"
dependencies = [ dependencies = [
"average", "average",
"clap 4.5.21", "clap 4.5.21",
@ -4465,7 +4464,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-client" name = "text-generation-client"
version = "3.0.2-dev0" version = "3.1.1-dev0"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"base64 0.22.1", "base64 0.22.1",
@ -4483,7 +4482,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-launcher" name = "text-generation-launcher"
version = "3.0.2-dev0" version = "3.1.1-dev0"
dependencies = [ dependencies = [
"clap 4.5.21", "clap 4.5.21",
"ctrlc", "ctrlc",
@ -4504,7 +4503,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-router" name = "text-generation-router"
version = "3.0.2-dev0" version = "3.1.1-dev0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"async-stream", "async-stream",
@ -4555,7 +4554,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-router-v2" name = "text-generation-router-v2"
version = "3.0.2-dev0" version = "3.1.1-dev0"
dependencies = [ dependencies = [
"async-stream", "async-stream",
"async-trait", "async-trait",
@ -4604,7 +4603,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-router-v3" name = "text-generation-router-v3"
version = "3.0.2-dev0" version = "3.1.1-dev0"
dependencies = [ dependencies = [
"async-stream", "async-stream",
"async-trait", "async-trait",
@ -4791,9 +4790,9 @@ dependencies = [
[[package]] [[package]]
name = "tokio" name = "tokio"
version = "1.41.1" version = "1.43.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22cfb5bee7a6a52939ca9224d6ac897bb669134078daa8735560897f69de4d33" checksum = "3d61fa4ffa3de412bfea335c6ecff681de2b609ba3c77ef3e00e521813a9ed9e"
dependencies = [ dependencies = [
"backtrace", "backtrace",
"bytes", "bytes",
@ -4819,9 +4818,9 @@ dependencies = [
[[package]] [[package]]
name = "tokio-macros" name = "tokio-macros"
version = "2.4.0" version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -4862,9 +4861,9 @@ dependencies = [
[[package]] [[package]]
name = "tokio-stream" name = "tokio-stream"
version = "0.1.16" version = "0.1.17"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4f4e6ce100d0eb49a2734f8c0812bcd324cf357d21810932c5df6b96ef2b86f1" checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047"
dependencies = [ dependencies = [
"futures-core", "futures-core",
"pin-project-lite", "pin-project-lite",

View File

@ -20,7 +20,7 @@ default-members = [
resolver = "2" resolver = "2"
[workspace.package] [workspace.package]
version = "3.0.2-dev0" version = "3.1.1-dev0"
edition = "2021" edition = "2021"
authors = ["Olivier Dehaene"] authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference" homepage = "https://github.com/huggingface/text-generation-inference"

View File

@ -1,5 +1,5 @@
# Rust builder # Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef FROM lukemathwalker/cargo-chef:latest-rust-1.84.0 AS chef
WORKDIR /usr/src WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@ -47,7 +47,7 @@ RUN cargo build --profile release-opt --frozen
FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS pytorch-install FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS pytorch-install
# NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099 # NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099
ARG PYTORCH_VERSION=2.4.0 ARG PYTORCH_VERSION=2.5.1
ARG PYTHON_VERSION=3.11 ARG PYTHON_VERSION=3.11
# Keep in sync with `server/pyproject.toml # Keep in sync with `server/pyproject.toml
@ -58,7 +58,7 @@ ARG INSTALL_CHANNEL=pytorch
# Automatically set by buildx # Automatically set by buildx
ARG TARGETPLATFORM ARG TARGETPLATFORM
ENV PATH /opt/conda/bin:$PATH ENV PATH=/opt/conda/bin:$PATH
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \ build-essential \
@ -224,17 +224,19 @@ COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-
COPY --from=flashinfer-builder /opt/conda/lib/python3.11/site-packages/flashinfer/ /opt/conda/lib/python3.11/site-packages/flashinfer/ COPY --from=flashinfer-builder /opt/conda/lib/python3.11/site-packages/flashinfer/ /opt/conda/lib/python3.11/site-packages/flashinfer/
# Install flash-attention dependencies # Install flash-attention dependencies
RUN pip install einops --no-cache-dir # RUN pip install einops --no-cache-dir
# Install server # Install server
COPY proto proto COPY proto proto
COPY server server COPY server server
COPY server/Makefile server/Makefile COPY server/Makefile server/Makefile
ENV UV_SYSTEM_PYTHON=1
RUN cd server && \ RUN cd server && \
make gen-server && \ make gen-server && \
pip install -r requirements_cuda.txt && \ python -c "from text_generation_server.pb import generate_pb2" && \
pip install ".[attention, bnb, accelerate, compressed-tensors, marlin, moe, quantize, peft, outlines]" --no-cache-dir && \ pip install -U pip uv && \
pip install nvidia-nccl-cu12==2.22.3 uv pip install -e ".[attention, bnb, accelerate, compressed-tensors, marlin, moe, quantize, peft, outlines]" --no-cache-dir # && \
# uv pip install nvidia-nccl-cu12==2.22.3
ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2 ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2
# Required to find libpython within the rust binaries # Required to find libpython within the rust binaries
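
The server install in the CUDA image now goes through uv with UV_SYSTEM_PYTHON=1, installs the package in editable mode, and sanity-checks the generated protobuf stubs instead of installing from requirements_cuda.txt. Roughly what the amended RUN step executes, written out as plain shell (a sketch, not the exact layer):

```
# Hedged sketch of the updated server install sequence inside the CUDA image.
export UV_SYSTEM_PYTHON=1
cd server
make gen-server
python -c "from text_generation_server.pb import generate_pb2"   # verify the generated protos import
pip install -U pip uv
uv pip install -e ".[attention, bnb, accelerate, compressed-tensors, marlin, moe, quantize, peft, outlines]" --no-cache-dir
```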

View File

@ -1,5 +1,5 @@
# Rust builder # Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef FROM lukemathwalker/cargo-chef:latest-rust-1.84.0 AS chef
WORKDIR /usr/src WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@ -104,7 +104,7 @@ RUN case ${TARGETPLATFORM} in \
/opt/conda/bin/conda clean -ya /opt/conda/bin/conda clean -ya
# Install flash-attention, torch dependencies # Install flash-attention, torch dependencies
RUN python3 -m pip install --upgrade pip && pip install numpy einops ninja joblib msgpack cmake --no-cache-dir && rm -rf /var/lib/apt/lists/* RUN python3 -m pip install --upgrade pip uv && pip install numpy einops ninja joblib msgpack cmake --no-cache-dir && rm -rf /var/lib/apt/lists/*
RUN conda install mkl=2021 RUN conda install mkl=2021
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/opt/conda/lib/python3.11/site-packages/torch/lib:/opt/conda/lib/ ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/opt/conda/lib/python3.11/site-packages/torch/lib:/opt/conda/lib/
@ -268,9 +268,18 @@ COPY server/exllamav2_kernels/ .
RUN python setup.py build RUN python setup.py build
FROM kernel-builder AS marlin-kernels
WORKDIR /usr/src
ENV MARLIN_KERNELS_BRANCH=v0.3.6
ENV VLLM_TARGET_DEVICE=rocm
RUN git clone https://github.com/danieldk/marlin-kernels.git && \
cd marlin-kernels && \
git checkout ${MARLIN_KERNELS_BRANCH} && \
python setup.py install
FROM kernel-builder AS moe-kernels FROM kernel-builder AS moe-kernels
WORKDIR /usr/src WORKDIR /usr/src
ENV MOE_KERNELS_BRANCH=a67b35841774b2056a73806c36661134b5054edd ENV MOE_KERNELS_BRANCH=v0.8.2
ENV VLLM_TARGET_DEVICE=rocm ENV VLLM_TARGET_DEVICE=rocm
RUN git clone https://github.com/danieldk/moe-kernels.git && \ RUN git clone https://github.com/danieldk/moe-kernels.git && \
cd moe-kernels && \ cd moe-kernels && \
@ -299,6 +308,9 @@ COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311
# Copy build artifacts from exllamav2 kernels builder # Copy build artifacts from exllamav2 kernels builder
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from marlin kernels
COPY --from=marlin-kernels /usr/src/marlin-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from moe kernels # Copy build artifacts from moe kernels
COPY --from=moe-kernels /usr/src/moe-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages COPY --from=moe-kernels /usr/src/moe-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
@ -306,10 +318,11 @@ COPY --from=moe-kernels /usr/src/moe-kernels/build/lib.linux-x86_64-cpython-311
COPY proto proto COPY proto proto
COPY server server COPY server server
COPY server/Makefile server/Makefile COPY server/Makefile server/Makefile
ENV UV_SYSTEM_PYTHON=1
RUN cd server && \ RUN cd server && \
make gen-server && \ make gen-server && \
pip install -r requirements_rocm.txt && \ pip install -U pip uv && \
pip install ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir uv pip install -e ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir
# Install benchmarker # Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
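
The ROCm image gains a dedicated marlin-kernels build stage pinned to v0.3.6 (and bumps moe-kernels to v0.8.2), with the resulting artifacts copied into the final image. Outside Docker, the marlin stage boils down to the following, assuming a working ROCm PyTorch toolchain:

```
# Hedged sketch: building the pinned marlin-kernels for ROCm, mirroring the new stage.
git clone https://github.com/danieldk/marlin-kernels.git
cd marlin-kernels
git checkout v0.3.6
VLLM_TARGET_DEVICE=rocm python setup.py install
```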

View File

@ -1,6 +1,6 @@
ARG PLATFORM=xpu ARG PLATFORM=xpu
FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef FROM lukemathwalker/cargo-chef:latest-rust-1.84.0 AS chef
WORKDIR /usr/src WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@ -108,17 +108,19 @@ RUN pip install triton-xpu==3.0.0b2 --no-cache-dir
COPY proto proto COPY proto proto
COPY server server COPY server server
COPY server/Makefile server/Makefile COPY server/Makefile server/Makefile
ENV UV_SYSTEM_PYTHON=1
RUN cd server && \ RUN cd server && \
make gen-server && \ make gen-server && \
pip install -r requirements_intel.txt && \ pip install -U pip uv && \
pip install ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir uv pip install -e ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/pti/0.9/lib:/opt/conda/lib ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/pti/0.9/lib:/opt/conda/lib
ENV CCL_ZE_IPC_EXCHANGE=sockets ENV CCL_ZE_IPC_EXCHANGE=sockets
#ENV TORCH_LLM_ALLREDUCE=1 #ENV TORCH_LLM_ALLREDUCE=1
#ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0 #ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
ENV TORCH_DEVICE_BACKEND_AUTOLOAD=0
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout 033af6f63745ac748cccdadee5c6140c7971edf6 RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout 1ccf72b2d11cd00b47aef6d6cd054c088aa6f083
RUN cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc,ats-m150' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch RUN cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc,ats-m150' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch
# Install benchmarker # Install benchmarker
@ -211,10 +213,11 @@ ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
COPY proto proto COPY proto proto
COPY server server COPY server server
COPY server/Makefile server/Makefile COPY server/Makefile server/Makefile
ENV UV_SYSTEM_PYTHON=1
RUN cd server && \ RUN cd server && \
make gen-server && \ make gen-server && \
pip install -r requirements_intel.txt && \ pip install -U pip uv && \
pip install ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir uv pip install -e ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir
# Install benchmarker # Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
@ -224,9 +227,9 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
FROM ${PLATFORM} AS final FROM ${PLATFORM} AS final
ENV ATTENTION=paged ENV ATTENTION=flashdecoding-ipex
ENV PREFIX_CACHING=0 ENV PREFIX_CACHING=1
ENV PREFILL_CHUNKING=0 ENV PREFILL_CHUNKING=1
ENV CUDA_GRAPHS=0 ENV CUDA_GRAPHS=0
ENTRYPOINT ["text-generation-launcher"] ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"] CMD ["--json-output"]

View File

@ -1,20 +1,14 @@
ARG CUDA_ARCH_LIST="75-real;80-real;86-real;89-real;90-real" ARG cuda_arch_list="75-real;80-real;86-real;89-real;90-real"
ARG OMPI_VERSION="4.1.7rc1" ARG build_type=release
ARG ompi_version=4.1.7
# Build dependencies resolver stage ARG sccache_gha_enabled=off
FROM lukemathwalker/cargo-chef:latest AS chef ARG actions_cache_url=""
WORKDIR /usr/src/text-generation-inference/backends/trtllm ARG actions_runtime_token=""
FROM chef AS planner
COPY . .
RUN cargo chef prepare --recipe-path recipe.json
# CUDA dependent dependencies resolver stage # CUDA dependent dependencies resolver stage
FROM nvidia/cuda:12.6.3-cudnn-devel-ubuntu24.04 AS cuda-builder FROM nvidia/cuda:12.6.3-cudnn-devel-ubuntu24.04 AS cuda-builder
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
apt update && apt install -y \
build-essential \ build-essential \
cmake \ cmake \
curl \ curl \
@ -22,8 +16,11 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
g++-14 \ g++-14 \
git \ git \
git-lfs \ git-lfs \
lld \
libssl-dev \ libssl-dev \
libucx-dev \ libucx-dev \
libasan8 \
libubsan1 \
ninja-build \ ninja-build \
pkg-config \ pkg-config \
pipx \ pipx \
@ -31,7 +28,7 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
python3-dev \ python3-dev \
python3-setuptools \ python3-setuptools \
tar \ tar \
wget && \ wget --no-install-recommends && \
pipx ensurepath pipx ensurepath
ENV TGI_INSTALL_PREFIX=/usr/local/tgi ENV TGI_INSTALL_PREFIX=/usr/local/tgi
@ -39,17 +36,19 @@ ENV TENSORRT_INSTALL_PREFIX=/usr/local/tensorrt
# Install OpenMPI # Install OpenMPI
FROM cuda-builder AS mpi-builder FROM cuda-builder AS mpi-builder
ARG OMPI_VERSION WORKDIR /opt/src/mpi
ENV OMPI_TARBALL_FILENAME="openmpi-$OMPI_VERSION.tar.bz2" ARG ompi_version
RUN wget "https://download.open-mpi.org/release/open-mpi/v4.1/$OMPI_TARBALL_FILENAME" -P /opt/src && \ ENV OMPI_VERSION=${ompi_version}
mkdir /usr/src/mpi && \ ENV OMPI_TARBALL_FILENAME=openmpi-${OMPI_VERSION}.tar.bz2
tar -xf "/opt/src/$OMPI_TARBALL_FILENAME" -C /usr/src/mpi --strip-components=1 && \ ADD --checksum=sha256:54a33cb7ad81ff0976f15a6cc8003c3922f0f3d8ceed14e1813ef3603f22cd34 \
cd /usr/src/mpi && \ https://download.open-mpi.org/release/open-mpi/v4.1/${OMPI_TARBALL_FILENAME} .
RUN tar --strip-components=1 -xf ${OMPI_TARBALL_FILENAME} &&\
./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --with-slurm && \ ./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --with-slurm && \
make -j all && \ make -j all && \
make install && \ make install && \
rm -rf "/opt/src/$OMPI_TARBALL_FILENAME" rm -rf ${OMPI_TARBALL_FILENAME}/..
# Install TensorRT # Install TensorRT
FROM cuda-builder AS trt-builder FROM cuda-builder AS trt-builder
@ -61,30 +60,50 @@ RUN chmod +x /opt/install_tensorrt.sh && \
FROM cuda-builder AS tgi-builder FROM cuda-builder AS tgi-builder
WORKDIR /usr/src/text-generation-inference WORKDIR /usr/src/text-generation-inference
# Scoped global args reuse
ARG cuda_arch_list
ARG build_type
ARG sccache_gha_enabled
ARG actions_cache_url
ARG actions_runtime_token
# Install Rust # Install Rust
ENV PATH="/root/.cargo/bin:$PATH"
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y && \ RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y && \
chmod -R a+w /root/.rustup && \ chmod -R a+w /root/.rustup && \
chmod -R a+w /root/.cargo chmod -R a+w /root/.cargo && \
cargo install sccache --locked
ENV PATH="/root/.cargo/bin:$PATH"
RUN cargo install cargo-chef
# Cache dependencies
COPY --from=planner /usr/src/text-generation-inference/backends/trtllm/recipe.json .
RUN cargo chef cook --release --recipe-path recipe.json
# Build actual TGI
ARG CUDA_ARCH_LIST
ENV CMAKE_PREFIX_PATH="/usr/local/mpi:/usr/local/tensorrt:$CMAKE_PREFIX_PATH"
ENV LD_LIBRARY_PATH="/usr/local/mpi/lib:$LD_LIBRARY_PATH" ENV LD_LIBRARY_PATH="/usr/local/mpi/lib:$LD_LIBRARY_PATH"
ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig:$PKG_CONFIG_PATH" ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig"
ENV CMAKE_PREFIX_PATH="/usr/local/mpi:/usr/local/tensorrt"
COPY . . ENV USE_LLD_LINKER=ON
ENV CUDA_ARCH_LIST=${cuda_arch_list}
# SCCACHE Specifics args - before finding a better, more generic, way...
ENV SCCACHE_GHA_ENABLED=${sccache_gha_enabled}
ENV ACTIONS_CACHE_URL=${actions_cache_url}
ENV ACTIONS_RUNTIME_TOKEN=${actions_runtime_token}
COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY router router
COPY backends backends
COPY benchmark benchmark
COPY launcher launcher
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
RUN mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$TGI_INSTALL_PREFIX/lib" && \
cd backends/trtllm && \ ENV RUSTC_WRAPPER=sccache
CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release ENV CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX
RUN export CMAKE_C_COMPILER_LAUNCHER=sccache && \
export CMAKE_CXX_COMPILER_LAUNCHER=sccache && \
export CMAKE_CUDA_COMPILER_LAUNCHER=sccache && \
mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$TGI_INSTALL_PREFIX/lib" && \
cargo build --profile ${build_type} --package text-generation-backends-trtllm --bin text-generation-backends-trtllm && \
sccache --show-stats
FROM nvidia/cuda:12.6.3-cudnn-runtime-ubuntu24.04 AS runtime FROM nvidia/cuda:12.6.3-cudnn-runtime-ubuntu24.04 AS runtime
RUN apt update && apt install -y libucx0 pipx python3-minimal python3-dev python3-pip python3-venv && \ RUN apt update && apt install -y libucx0 pipx python3-minimal python3-dev python3-pip python3-venv && \
@ -104,10 +123,33 @@ COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
COPY --from=tgi-builder /usr/local/tgi /usr/local/tgi COPY --from=tgi-builder /usr/local/tgi /usr/local/tgi
COPY --from=tgi-builder /usr/src/text-generation-inference/target/release/text-generation-backends-trtllm /usr/local/tgi/bin/text-generation-launcher COPY --from=tgi-builder /usr/src/text-generation-inference/target/release/text-generation-backends-trtllm /usr/local/tgi/bin/text-generation-launcher
# This is used only for the CI/CD
FROM nvidia/cuda:12.6.3-cudnn-runtime-ubuntu24.04 AS ci-runtime
RUN apt update && apt install -y libasan8 libubsan1 libucx0 pipx python3-minimal python3-dev python3-pip python3-venv && \
rm -rf /var/lib/{apt,dpkg,cache,log}/ && \
pipx ensurepath && \
pipx install --include-deps transformers tokenizers
WORKDIR /usr/local/tgi/bin
ENV PATH=/root/.local/share/pipx/venvs/transformers/bin/:$PATH
ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/mpi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
ENV TOKENIZERS_PARALLELISM=false
ENV OMPI_MCA_plm_rsh_agent=""
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
COPY --from=tgi-builder /usr/local/tgi /usr/local/tgi
# Basically we copy from target/debug instead of target/release
COPY --from=tgi-builder /usr/src/text-generation-inference/target/debug/text-generation-backends-trtllm /usr/local/tgi/bin/text-generation-launcher
# This is the final image
FROM runtime FROM runtime
LABEL co.huggingface.vendor="Hugging Face Inc." LABEL co.huggingface.vendor="Hugging Face Inc."
LABEL org.opencontainers.image.authors="hardware@hf.co" LABEL org.opencontainers.image.authors="hardware@hf.co"
LABEL org.opencontainers.title="Text-Generation-Inference TensorRT-LLM Backend"
ENTRYPOINT ["./text-generation-launcher"] ENTRYPOINT ["./text-generation-launcher"]
CMD ["--executor-worker", "/usr/local/tgi/bin/executorWorker"] CMD ["--executor-worker", "/usr/local/tgi/bin/executorWorker"]

View File

@ -84,7 +84,7 @@ model=HuggingFaceH4/zephyr-7b-beta
volume=$PWD/data volume=$PWD/data
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.0.0 --model-id $model ghcr.io/huggingface/text-generation-inference:3.1.0 --model-id $model
``` ```
And then you can make requests like And then you can make requests like
@ -121,7 +121,7 @@ curl localhost:8080/v1/chat/completions \
**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar. **Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.0.0-rocm --model-id $model` instead of the command above. **Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.1.0-rocm --model-id $model` instead of the command above.
To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli): To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
``` ```
@ -151,7 +151,8 @@ model=meta-llama/Meta-Llama-3.1-8B-Instruct
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
token=<your cli READ token> token=<your cli READ token>
docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.0.0 --model-id $model docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.1.0 --model-id $model
``` ```
### A note on Shared Memory (shm) ### A note on Shared Memory (shm)

View File

@ -8,7 +8,7 @@ use tracing_opentelemetry::OpenTelemetrySpanExt;
/// Inject context in the metadata of a gRPC request. /// Inject context in the metadata of a gRPC request.
struct MetadataInjector<'a>(pub &'a mut tonic::metadata::MetadataMap); struct MetadataInjector<'a>(pub &'a mut tonic::metadata::MetadataMap);
impl<'a> Injector for MetadataInjector<'a> { impl Injector for MetadataInjector<'_> {
/// Set a key and value in the MetadataMap. Does nothing if the key or value are not valid inputs /// Set a key and value in the MetadataMap. Does nothing if the key or value are not valid inputs
fn set(&mut self, key: &str, value: String) { fn set(&mut self, key: &str, value: String) {
if let Ok(key) = tonic::metadata::MetadataKey::from_bytes(key.as_bytes()) { if let Ok(key) = tonic::metadata::MetadataKey::from_bytes(key.as_bytes()) {

View File

@ -1,13 +1,5 @@
cmake_minimum_required(VERSION 3.20) cmake_minimum_required(VERSION 3.20)
if (NOT DEFINED CMAKE_CXX_COMPILER_LAUNCHER AND CMAKE_BUILD_TYPE STREQUAL "Debug")
find_program(CCACHE_EXECUTABLE "ccache")
if (CCACHE_EXECUTABLE)
message(STATUS "Using ccache")
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_EXECUTABLE}" CACHE PATH "Path to ccache" FORCE)
endif ()
endif ()
if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0") if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0")
cmake_policy(SET CMP0135 NEW) cmake_policy(SET CMP0135 NEW)
endif () endif ()
@ -21,6 +13,7 @@ include(CheckCXXCompilerFlag)
option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF) option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF)
option(TGI_TRTLLM_BACKEND_BUILD_EXAMPLES "Enable building the examples suite" OFF) option(TGI_TRTLLM_BACKEND_BUILD_EXAMPLES "Enable building the examples suite" OFF)
option(TGI_TRTLLM_BACKEND_BUILD_USE_LLD "Enable lld linker instead of ld" OFF)
set(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST "89-real" CACHE STRING "List of CUDA architectures to support") set(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST "89-real" CACHE STRING "List of CUDA architectures to support")
set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING "Path where TensorRT libraries and headers are located") set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING "Path where TensorRT libraries and headers are located")
set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include" CACHE STRING "Path where TensorRT headers are located") set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include" CACHE STRING "Path where TensorRT headers are located")
@ -28,20 +21,22 @@ set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE ST
# We are using nvidia-ml to query at runtime device information to enable some architecture-specific features # We are using nvidia-ml to query at runtime device information to enable some architecture-specific features
find_package(CUDAToolkit 12.6 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml) find_package(CUDAToolkit 12.6 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml)
find_package(MPI REQUIRED)
#### External dependencies #### #### External dependencies ####
include(cmake/json.cmake) include(cmake/json.cmake)
include(cmake/spdlog.cmake) include(cmake/spdlog.cmake)
include(cmake/trtllm.cmake) include(cmake/trtllm.cmake)
if(${CMAKE_BUILD_TYPE} STREQUAL "Debug") if (CMAKE_BUILD_TYPE STREQUAL "Debug")
set(TGI_TRTLLM_BACKEND_DEBUG ON)
add_compile_definitions(TGI_TRTLLM_BACKEND_DEBUG=1) add_compile_definitions(TGI_TRTLLM_BACKEND_DEBUG=1)
add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_TRACE)
endif () endif ()
# This attempt to detect if the compiler can emit warning if it can't apply return value optimization from a function if (${TGI_TRTLLM_BACKEND_BUILD_USE_LLD})
check_cxx_compiler_flag("-Wnrvo" COMPILER_SUPPORT_WARNING_ON_NVRO) message(STATUS "Using lld linker")
if(${COMPILER_SUPPORT_WARNING_ON_NVRO}) add_link_options("-fuse-ld=lld")
set(CMAKE_CXX_FLAGS "{CMAKE_CXX_FLAGS} -Wnvro")
endif () endif ()
# Let's build TRTLLM as part of CMake # Let's build TRTLLM as part of CMake
@ -60,46 +55,63 @@ target_include_directories(tgi_trtllm_backend_impl PRIVATE
target_include_directories(tgi_trtllm_backend_impl PUBLIC "${trtllm_SOURCE_DIR}/cpp/include") target_include_directories(tgi_trtllm_backend_impl PUBLIC "${trtllm_SOURCE_DIR}/cpp/include")
target_link_libraries(tgi_trtllm_backend_impl PRIVATE CUDA::cudart CUDA::nvml) target_link_libraries(tgi_trtllm_backend_impl PRIVATE CUDA::cudart CUDA::nvml)
target_link_libraries(tgi_trtllm_backend_impl PUBLIC nlohmann_json::nlohmann_json spdlog::spdlog) target_link_libraries(tgi_trtllm_backend_impl PUBLIC nlohmann_json::nlohmann_json spdlog::spdlog)
target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper)
if(${CMAKE_BUILD_TYPE} STREQUAL "Debug")
target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm)
else()
target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapperm)
endif ()
# This install all the artifacts in CMAKE_INSTALL_PREFIX under include/ lib/ bin/ to make easy to link / find it back # This install all the artifacts in CMAKE_INSTALL_PREFIX under include/ lib/ bin/ to make easy to link / find it back
install(TARGETS tgi_trtllm_backend_impl tensorrt_llm nvinfer_plugin_tensorrt_llm decoder_attention executorWorker) install(TARGETS tgi_trtllm_backend_impl)
install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB) install(TARGETS tensorrt_llm nvinfer_plugin_tensorrt_llm decoder_attention executorWorker)
install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} TYPE LIB)
if (NOT ${TGI_TRTLLM_BACKEND_DEBUG})
install(FILES ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB)
endif ()
#### Unit Tests #### #### Unit Tests ####
if (${TGI_TRTLLM_BACKEND_BUILD_TESTS}) if (${TGI_TRTLLM_BACKEND_BUILD_TESTS} AND CMAKE_BUILD_TYPE MATCHES "Debug")
message(STATUS "Building tests") message(STATUS "Building tests")
option(TGI_TRTLLM_BACKEND_ENABLE_ASAN "Enable AddressSanitizer")
option(TGI_TRTLLM_BACKEND_ENABLE_UBSAN "Enable UndefinedSanitizer")
FetchContent_Declare( FetchContent_Declare(
Catch2 Catch2
URL https://github.com/catchorg/Catch2/archive/refs/tags/v3.7.1.tar.gz URL https://github.com/catchorg/Catch2/archive/refs/tags/v3.7.1.tar.gz
) )
FetchContent_MakeAvailable(Catch2) FetchContent_MakeAvailable(Catch2)
# This attempt to detect if the compiler can emit warning if it can't apply return value optimization from a function
check_cxx_compiler_flag("-Wnrvo" COMPILER_SUPPORT_WARNING_ON_NVRO)
if (${COMPILER_SUPPORT_WARNING_ON_NVRO})
message(STATUS "Enabling non-NVRO detection")
target_compile_options(tgi_trtllm_backend_impl "-Wnvro")
endif ()
cmake_path(GET TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH PARENT_PATH TRTLLM_NVRTC_WRAPPER_PARENT_LIBRARY_PATH)
message(STATUS "Adding linking path: ${TRTLLM_NVRTC_WRAPPER_PARENT_LIBRARY_PATH}")
add_executable(tgi_trtllm_backend_tests tests/test_hardware.cpp tests/test_backend.cpp) add_executable(tgi_trtllm_backend_tests tests/test_hardware.cpp tests/test_backend.cpp)
# target_compile_options(tgi_trtllm_backend_tests PRIVATE -Werror)
target_link_directories(tgi_trtllm_backend_tests PRIVATE "${TRTLLM_NVRTC_WRAPPER_PARENT_LIBRARY_PATH}")
target_include_directories(tgi_trtllm_backend_tests PUBLIC "${trtllm_SOURCE_DIR}/cpp/include") target_include_directories(tgi_trtllm_backend_tests PUBLIC "${trtllm_SOURCE_DIR}/cpp/include")
target_include_directories(tgi_trtllm_backend_tests PUBLIC "csrc/") target_include_directories(tgi_trtllm_backend_tests PUBLIC "csrc/")
target_link_libraries(tgi_trtllm_backend_tests PRIVATE ${TRTLLM_LIBS} CUDA::cudart CUDA::nvml) target_link_libraries(tgi_trtllm_backend_tests PRIVATE ${TRTLLM_LIBS} CUDA::cudart CUDA::nvml)
target_link_libraries(tgi_trtllm_backend_tests PUBLIC Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog tgi_trtllm_backend_impl) target_link_libraries(tgi_trtllm_backend_tests PUBLIC Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog tgi_trtllm_backend_impl)
target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper)
if(${CMAKE_BUILD_TYPE} STREQUAL "Debug") if (${TGI_TRTLLM_BACKEND_ENABLE_ASAN})
target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm) message(STATUS "Enabled AddressSanitizer")
else() target_link_options(tgi_trtllm_backend_tests BEFORE PUBLIC -fsanitize=address)
target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapperm)
endif () endif ()
if(CMAKE_BUILD_TYPE MATCHES "Debug") if (${TGI_TRTLLM_BACKEND_ENABLE_UBSAN})
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Werror -fsanitize=undefined -fsanitize=address") message(STATUS "Enabled UndefinedSanitizer")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -fsanitize=undefined -fsanitize=address") target_link_options(tgi_trtllm_backend_tests BEFORE PUBLIC -fsanitize=undefined)
target_link_options(tgi_trtllm_backend_tests BEFORE PUBLIC -fsanitize=undefined PUBLIC -fsanitize=address)
endif () endif ()
list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras) install(TARGETS tgi_trtllm_backend_tests)
include(CTest)
include(Catch) # list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras)
catch_discover_tests(tgi_trtllm_backend_tests) # include(CTest)
# include(Catch)
# catch_discover_tests(tgi_trtllm_backend_tests)
endif () endif ()
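
The backend's CMake build now only compiles the test suite for Debug builds and replaces the hard-coded -fsanitize flags with explicit TGI_TRTLLM_BACKEND_ENABLE_ASAN / TGI_TRTLLM_BACKEND_ENABLE_UBSAN options, alongside an opt-in lld linker switch. A hedged configure command exercising the new options (the source directory is backends/trtllm, as the workflow's grep over backends/trtllm/cmake/trtllm.cmake implies; generator and build directory are illustrative):

```
# Hedged sketch: Debug configure with tests, lld and both sanitizers enabled.
cmake -S backends/trtllm -B build -G Ninja \
  -DCMAKE_BUILD_TYPE=Debug \
  -DTGI_TRTLLM_BACKEND_BUILD_TESTS=ON \
  -DTGI_TRTLLM_BACKEND_BUILD_USE_LLD=ON \
  -DTGI_TRTLLM_BACKEND_ENABLE_ASAN=ON \
  -DTGI_TRTLLM_BACKEND_ENABLE_UBSAN=ON
cmake --build build
```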

View File

@ -7,20 +7,16 @@ homepage.workspace = true
[dependencies] [dependencies]
async-trait = "0.1" async-trait = "0.1"
#async-stream = "0.3"
clap = { version = "4.5", features = ["derive"] } clap = { version = "4.5", features = ["derive"] }
cxx = "1.0" cxx = "1.0"
hashbrown = "0.14" hashbrown = "0.15"
hf-hub = { workspace = true } hf-hub = { workspace = true }
#log = { version = "0.4", features = [] }
text-generation-router = { path = "../../router" } text-generation-router = { path = "../../router" }
tokenizers = { workspace = true } tokenizers = { workspace = true }
tokio = { version = "1.39", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tokio = { version = "1.43.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.15" tokio-stream = "0.1.17"
thiserror = "1.0.63" thiserror = "1.0.63"
tracing = "0.1" tracing = "0.1"
#tracing-opentelemetry = "0.25"
#tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
pyo3 = { workspace = true } pyo3 = { workspace = true }
[build-dependencies] [build-dependencies]

View File

@ -3,6 +3,7 @@ use pkg_config;
use std::env; use std::env;
use std::env::consts::ARCH; use std::env::consts::ARCH;
use std::path::{absolute, PathBuf}; use std::path::{absolute, PathBuf};
use std::sync::LazyLock;
const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 1] = ["spdlog"]; const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 1] = ["spdlog"];
const CUDA_ARCH_LIST: Option<&str> = option_env!("CUDA_ARCH_LIST"); const CUDA_ARCH_LIST: Option<&str> = option_env!("CUDA_ARCH_LIST");
@ -12,12 +13,20 @@ const INSTALL_PREFIX: Option<&str> = option_env!("CMAKE_INSTALL_PREFIX");
const TENSORRT_ROOT_DIR: Option<&str> = option_env!("TENSORRT_ROOT_DIR"); const TENSORRT_ROOT_DIR: Option<&str> = option_env!("TENSORRT_ROOT_DIR");
const NCCL_ROOT_DIR: Option<&str> = option_env!("NCCL_ROOT_DIR"); const NCCL_ROOT_DIR: Option<&str> = option_env!("NCCL_ROOT_DIR");
const IS_GHA_BUILD: LazyLock<bool> = LazyLock::new(|| {
option_env!("SCCACHE_GHA_ENABLED").map_or(false, |value| match value.to_lowercase().as_str() {
"on" => true,
"true" => true,
"1" => true,
_ => false,
})
});
// Dependencies // Dependencies
const BACKEND_DEPS: [&str; 2] = ["tgi_trtllm_backend_impl", "tgi_trtllm_backend"]; const BACKEND_DEPS: &str = "tgi_trtllm_backend_impl";
const CUDA_TRANSITIVE_DEPS: [&str; 4] = ["cuda", "cudart", "cublas", "nvidia-ml"]; const CUDA_TRANSITIVE_DEPS: [&str; 4] = ["cuda", "cudart", "cublas", "nvidia-ml"];
const TENSORRT_LLM_TRANSITIVE_DEPS: [(&str, &str); 5] = [ const TENSORRT_LLM_TRANSITIVE_DEPS: [(&str, &str); 4] = [
("dylib", "tensorrt_llm"), ("dylib", "tensorrt_llm"),
("static", "tensorrt_llm_executor_static"),
("dylib", "tensorrt_llm_nvrtc_wrapper"), ("dylib", "tensorrt_llm_nvrtc_wrapper"),
("dylib", "nvinfer_plugin_tensorrt_llm"), ("dylib", "nvinfer_plugin_tensorrt_llm"),
("dylib", "decoder_attention"), ("dylib", "decoder_attention"),
@ -32,6 +41,48 @@ macro_rules! probe {
}; };
} }
fn get_compiler_flag(
switch: bool,
true_case: &'static str,
false_case: &'static str,
) -> &'static str {
match switch {
true => true_case,
false => false_case,
}
}
fn get_library_architecture() -> &'static str {
let os = env::var("CARGO_CFG_TARGET_OS").unwrap();
let arch = env::var("CARGO_CFG_TARGET_ARCH").unwrap();
let env = env::var("CARGO_CFG_TARGET_ENV").unwrap();
match os.as_str() {
"linux" => {
if env != "gnu" {
panic!("unsupported linux ABI {env}, only 'gnu' is supported")
}
match arch.as_str() {
"x86_64" => "x86_64-linux-gnu",
"aarch64" => "aarch64-linux-gnu",
_ => panic!("unsupported linux architecture {arch}"),
}
}
"windows" => {
if env != "msvc" {
panic!("unsupported windows ABI {env}, only 'msvc' is supported")
}
match arch.as_str() {
"x86_64" => "x86_64-windows-msvc",
_ => panic!("unsupported windows architecture {arch}"),
}
}
_ => panic!("unsupported OS {os}"),
}
}
fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf, PathBuf) { fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf, PathBuf) {
// Build the backend implementation through CMake // Build the backend implementation through CMake
let install_path = INSTALL_PREFIX.unwrap_or("/usr/local/tgi"); let install_path = INSTALL_PREFIX.unwrap_or("/usr/local/tgi");
@ -54,10 +105,44 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf
.env("OPT_LEVEL", opt_level) .env("OPT_LEVEL", opt_level)
.define("CMAKE_INSTALL_PREFIX", &install_path) .define("CMAKE_INSTALL_PREFIX", &install_path)
.define("CMAKE_CUDA_COMPILER", "/usr/local/cuda/bin/nvcc") .define("CMAKE_CUDA_COMPILER", "/usr/local/cuda/bin/nvcc")
.define("Python3_ROOT_DIR", "../venv") .define("CMAKE_LIBRARY_ARCHITECTURE", get_library_architecture())
.define("TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST", cuda_arch_list) .define("TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST", cuda_arch_list)
.define(
"TGI_TRTLLM_BACKEND_DEBUG",
get_compiler_flag(is_debug, "ON", "OFF"),
)
.define("TGI_TRTLLM_BACKEND_TRT_ROOT", tensorrt_path); .define("TGI_TRTLLM_BACKEND_TRT_ROOT", tensorrt_path);
if is_debug || *IS_GHA_BUILD {
config.define("TGI_TRTLLM_BACKEND_BUILD_TESTS", "ON");
}
if option_env!("USE_LLD_LINKER").is_some() {
println!("cargo:warning=Using lld linker");
config.define("TGI_TRTLLM_BACKEND_BUILD_USE_LLD", "ON");
}
if (is_debug && option_env!("ENABLE_ASAN").is_some()) || *IS_GHA_BUILD {
println!("cargo:warning=Enabling Address Sanitizer");
config.define("TGI_TRTLLM_BACKEND_ENABLE_ASAN", "ON");
}
if (is_debug && option_env!("ENABLE_UBSAN").is_some()) || *IS_GHA_BUILD {
println!("cargo:warning=Enabling Undefined Sanitizer");
config.define("TGI_TRTLLM_BACKEND_ENABLE_UBSAN", "ON");
}
if let Some(nvcc_host_compiler) = option_env!("CMAKE_CUDA_HOST_COMPILER") {
config.define("CMAKE_CUDA_HOST_COMPILER", nvcc_host_compiler);
}
if let Some(wrapper) = option_env!("RUSTC_WRAPPER") {
println!("cargo:warning=Using caching tool: {wrapper}");
config.define("CMAKE_C_COMPILER_LAUNCHER", wrapper);
config.define("CMAKE_CXX_COMPILER_LAUNCHER", wrapper);
config.define("CMAKE_CUDA_COMPILER_LAUNCHER", wrapper);
}
// Allow to override which Python to use ... // Allow to override which Python to use ...
if let Some(python3) = option_env!("Python3_EXECUTABLE") { if let Some(python3) = option_env!("Python3_EXECUTABLE") {
config.define("Python3_EXECUTABLE", python3); config.define("Python3_EXECUTABLE", python3);
@ -78,23 +163,18 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf
} }
// Emit linkage information from the artifacts we just built // Emit linkage information from the artifacts we just built
let install_lib_path = install_path.join("lib"); for path in ["lib", "lib64"] {
let install_lib_path = install_path.join(path);
println!( println!(
r"cargo:warning=Adding link search path: {}", r"cargo:warning=Adding link search path: {}",
install_lib_path.display() install_lib_path.display()
); );
println!(r"cargo:rustc-link-search={}", install_lib_path.display()); println!(r"cargo:rustc-link-search={}", install_lib_path.display());
}
(PathBuf::from(install_path), deps_folder) (PathBuf::from(install_path), deps_folder)
} }
fn build_ffi_layer(deps_folder: &PathBuf, is_debug: bool) { fn build_ffi_layer(deps_folder: &PathBuf, is_debug: bool) {
let ndebug = match is_debug {
true => "1",
false => "0",
};
CFG.include_prefix = "backends/trtllm"; CFG.include_prefix = "backends/trtllm";
cxx_build::bridge("src/lib.rs") cxx_build::bridge("src/lib.rs")
.static_flag(true) .static_flag(true)
@ -106,7 +186,10 @@ fn build_ffi_layer(deps_folder: &PathBuf, is_debug: bool) {
.include("/usr/local/tensorrt/include") .include("/usr/local/tensorrt/include")
.include("csrc/") .include("csrc/")
.file("csrc/ffi.hpp") .file("csrc/ffi.hpp")
.define("TGI_TRTLLM_BACKEND_DEBUG", ndebug) .define(
"TGI_TRTLLM_BACKEND_DEBUG",
get_compiler_flag(is_debug, "ON", "OFF"),
)
.compile("tgi_trtllm_backend"); .compile("tgi_trtllm_backend");
println!("cargo:rerun-if-changed=CMakeLists.txt"); println!("cargo:rerun-if-changed=CMakeLists.txt");
@ -125,6 +208,7 @@ fn main() {
let build_profile = env::var("PROFILE").unwrap(); let build_profile = env::var("PROFILE").unwrap();
let (is_debug, opt_level) = match build_profile.as_ref() { let (is_debug, opt_level) = match build_profile.as_ref() {
"debug" => (true, "0"), "debug" => (true, "0"),
"dev" => (true, "0"),
_ => (false, "3"), _ => (false, "3"),
}; };
@ -161,7 +245,5 @@ fn main() {
}); });
// Backend // Backend
BACKEND_DEPS.iter().for_each(|name| { println!("cargo:rustc-link-lib=static={}", &BACKEND_DEPS);
println!("cargo:rustc-link-lib=static={}", name);
});
} }
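
The Rust build script now drives most of the above through compile-time environment variables: SCCACHE_GHA_ENABLED (builds tests on GHA), USE_LLD_LINKER, ENABLE_ASAN/ENABLE_UBSAN (debug only), RUSTC_WRAPPER (reused as the CMake compiler launcher), CMAKE_CUDA_HOST_COMPILER and Python3_EXECUTABLE. A hedged invocation mirroring the tgi-builder stage of the updated Dockerfile:

```
# Hedged sketch: driving the backend build with the new environment switches.
export RUSTC_WRAPPER=sccache          # also forwarded as the CMake C/C++/CUDA compiler launcher
export USE_LLD_LINKER=ON
export CMAKE_INSTALL_PREFIX=/usr/local/tgi
export CUDA_ARCH_LIST="75-real;80-real;86-real;89-real;90-real"
cargo build --profile dev \
  --package text-generation-backends-trtllm \
  --bin text-generation-backends-trtllm
```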

View File

@ -4,14 +4,14 @@ set(SPDLOG_FMT_EXTERNAL OFF)
# Define the level at which SPDLOG_ compilation level is defined # Define the level at which SPDLOG_ compilation level is defined
if (${CMAKE_BUILD_TYPE} STREQUAL "Debug") if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG) add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_TRACE)
else () else ()
add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_INFO) add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG)
endif () endif ()
fetchcontent_declare( fetchcontent_declare(
spdlog spdlog
# DOWNLOAD_EXTRACT_TIMESTAMP # DOWNLOAD_EXTRACT_TIMESTAMP
URL https://github.com/gabime/spdlog/archive/refs/tags/v1.14.1.tar.gz URL https://github.com/gabime/spdlog/archive/refs/tags/v1.15.0.tar.gz
) )
fetchcontent_makeavailable(spdlog) fetchcontent_makeavailable(spdlog)

View File

@ -14,19 +14,21 @@ message(STATUS "Building for CUDA Architectures: ${CMAKE_CUDA_ARCHITECTURES}")
set(ENABLE_UCX OFF) set(ENABLE_UCX OFF)
if (${CMAKE_BUILD_TYPE} STREQUAL "Debug") if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
set(FAST_BUILD ON) set(FAST_BUILD ON)
set(NVTX_DISABLE OFF) set(NVTX_DISABLE ON)
set(INDEX_RANGE_CHECK ON)
else () else ()
set(FAST_BUILD OFF) set(FAST_BUILD OFF)
set(FAST_MATH ON) set(FAST_MATH ON)
set(NVTX_DISABLE ON) set(NVTX_DISABLE OFF)
set(INDEX_RANGE_CHECK OFF)
endif () endif ()
find_package(Python3 REQUIRED Interpreter) find_package(Python3 REQUIRED Interpreter)
fetchcontent_declare( fetchcontent_declare(
trtllm trtllm
GIT_REPOSITORY https://github.com/huggingface/TensorRT-LLM.git GIT_REPOSITORY https://github.com/nvidia/TensorRT-LLM.git
GIT_TAG 1bb9ca4688805444f203647674bac1d7219d0579 GIT_TAG v0.16.0
GIT_SHALLOW ON GIT_SHALLOW ON
DOWNLOAD_EXTRACT_TIMESTAMP DOWNLOAD_EXTRACT_TIMESTAMP
) )

View File

@ -1,7 +1,6 @@
#include <ranges> #include <ranges>
#include <nlohmann/json.hpp> #include <nlohmann/json.hpp>
#include <spdlog/spdlog.h>
#include "backend.hpp" #include "backend.hpp"
#include "hardware.hpp" #include "hardware.hpp"
@ -17,7 +16,8 @@ namespace huggingface::tgi::backends::trtllm {
if (world_size > 1) { if (world_size > 1) {
SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode"); SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode");
mode = tle::CommunicationMode::kORCHESTRATOR; mode = tle::CommunicationMode::kORCHESTRATOR;
orchestratorConfig = std::make_optional<tle::OrchestratorConfig>(true, executor_worker_path_, nullptr, true); orchestratorConfig = std::make_optional<tle::OrchestratorConfig>(true, executor_worker_path_, nullptr,
true);
} else { } else {
SPDLOG_INFO("Detected single engine deployment, using leader mode"); SPDLOG_INFO("Detected single engine deployment, using leader mode");
} }
@ -51,13 +51,14 @@ namespace huggingface::tgi::backends::trtllm {
} }
std::expected<request_id_t, backend_error_t> std::expected<request_id_t, backend_error_t>
backend_t::submit(std::span<const token_id_t> token_ids, const generation_params_t generation_params, const sampling_params_t sampling_params) noexcept { backend_t::submit(std::span<const token_id_t> token_ids, const generation_params_t g_params,
SPDLOG_DEBUG("Submitting {:d} tokens to the executor for scheduling ({}, {})", token_ids.size(), generation_params, sampling_params); const sampling_params_t s_params) noexcept {
SPDLOG_DEBUG("Submit {:d} tokens for scheduling ({}, {})", token_ids.size(), g_params, s_params);
return executor_.enqueueRequest(tle::Request{ return executor_.enqueueRequest(tle::Request{
{token_ids.begin(), token_ids.end()}, // Making actual copy of the tokens {token_ids.begin(), token_ids.end()}, // Making actual copy of the tokens
static_cast<tle::SizeType32>(generation_params.max_new_tokens), static_cast<tle::SizeType32>(g_params.max_new_tokens),
true, true,
(tle::SamplingConfig) sampling_params, (tle::SamplingConfig) s_params,
tle::OutputConfig{ /* returnLogProbs= */ true}, tle::OutputConfig{ /* returnLogProbs= */ true},
std::nullopt, std::nullopt,
std::nullopt, std::nullopt,

View File

@ -28,9 +28,53 @@ namespace huggingface::tgi::backends::trtllm {
#include "backends/trtllm/src/lib.rs.h" #include "backends/trtllm/src/lib.rs.h"
namespace huggingface::tgi::backends::trtllm { namespace huggingface::tgi::backends::trtllm {
std::once_flag backend_initialized_flag; std::once_flag backend_initialized_flag;
constexpr finish_reason_t as_finish_reason_t(const tle::FinishReason reason) noexcept {
switch (reason) {
case tle::FinishReason::kNOT_FINISHED:
return finish_reason_t::kNOT_FINISHED;
case tle::FinishReason::kSTOP_WORDS:
return finish_reason_t::kSTOP_WORDS;
case tle::FinishReason::kEND_ID:
return finish_reason_t::kEND_ID;
case tle::FinishReason::kLENGTH:
return finish_reason_t::kLENGTH;
default:
std::unreachable();
}
}
static auto as_generation_step = [](const tle::Response &r) {
const auto reqId = r.getRequestId();
if (!r.hasError()) [[likely]] {
const auto result = r.getResult();
const auto logits = result.logProbs.value()[0];
return generation_step_t{
reqId,
static_cast<uint32_t>(result.outputTokenIds[0][0]),
logits.back(),
result.isFinal,
as_finish_reason_t(result.finishReasons[0]),
false,
std::string()
};
} else {
return generation_step_t{
reqId,
0,
0.0,
true,
finish_reason_t::kNOT_FINISHED,
true,
std::move(r.getErrorMsg())
};
}
};
class tensorrt_llm_backend_t { class tensorrt_llm_backend_t {
private: private:
backend_t inner_; backend_t inner_;
@ -39,9 +83,7 @@ namespace huggingface::tgi::backends::trtllm {
tensorrt_llm_backend_t(std::filesystem::path &&engine_folder, std::filesystem::path &&executor_worker_path) tensorrt_llm_backend_t(std::filesystem::path &&engine_folder, std::filesystem::path &&executor_worker_path)
: inner_(engine_folder, executor_worker_path) {} : inner_(engine_folder, executor_worker_path) {}
size_t num_tokens_ready() const noexcept { size_t num_tokens_ready() const noexcept { return inner_.num_tokens_ready(); }
return inner_.num_tokens_ready();
}
request_id_t submit( request_id_t submit(
rust::Slice<const uint32_t> tokens, rust::Slice<const uint32_t> tokens,
@ -78,41 +120,25 @@ namespace huggingface::tgi::backends::trtllm {
const auto responses = inner_.pull_tokens(); const auto responses = inner_.pull_tokens();
SPDLOG_TRACE("[FFI] Successfully pulled out {:d} responses from executor", responses.size()); SPDLOG_TRACE("[FFI] Successfully pulled out {:d} responses from executor", responses.size());
// Transform tle::Response to GenerationStep
auto steps = std::make_unique<std::vector<generation_step_t>>(); // Transform tle::Response to generation_step_t
std::ranges::transform(responses.begin(), responses.end(), std::back_inserter(*steps), [](const tle::Response &r) { #ifdef __cpp_lib_ranges_to_container
const auto reqId = r.getRequestId(); auto steps = responses | std::views::transform(as_generation_step) | std::ranges::to<std::vector>();
if (!r.hasError()) [[likely]] { #else
const auto result = r.getResult(); auto steps = std::vector<generation_step_t>();
return generation_step_t{ steps.reserve(responses.size());
reqId, std::transform(responses.begin(), responses.end(), std::back_inserter(steps), as_generation_step);
static_cast<uint32_t>(result.outputTokenIds[0][0]), #endif
result.logProbs.value()[0][0], return std::make_unique<std::vector<generation_step_t>>(steps);
result.isFinal,
false,
std::string()
};
} else {
return generation_step_t{
reqId,
0,
0.0,
true,
true,
std::move(r.getErrorMsg())
};
}
});
return steps;
} else { } else {
return std::make_unique<std::vector<generation_step_t>>(); return std::make_unique<std::vector<generation_step_t>>();
} }
} }
void cancel(request_id_t requestId) noexcept { void cancel(request_id_t request_id) noexcept {
SPDLOG_DEBUG("[FFI] cancelling request {:d}", requestId); SPDLOG_DEBUG("[FFI] cancelling request {:d}", request_id);
inner_.cancel(requestId); inner_.cancel(request_id);
} }
}; };
@ -151,11 +177,14 @@ namespace huggingface::tgi::backends::trtllm {
} }
} }
std::unique_ptr<tensorrt_llm_backend_t> create_backend_from_engine_folder(const rust::Str engines_folder, const rust::Str executor_worker_path) { std::unique_ptr<tensorrt_llm_backend_t>
create_backend_from_engine_folder(const rust::Str engines_folder, const rust::Str executor_worker_path) {
std::call_once(backend_initialized_flag, initialize_tensorrt_llm_backend); std::call_once(backend_initialized_flag, initialize_tensorrt_llm_backend);
return std::make_unique<tensorrt_llm_backend_t>( return std::make_unique<tensorrt_llm_backend_t>(
std::filesystem::path(std::string_view(engines_folder.begin(), engines_folder.end()), std::filesystem::path::format::auto_format), std::filesystem::path(std::string_view(engines_folder.begin(), engines_folder.end()),
std::filesystem::path(std::string_view(executor_worker_path.begin(), executor_worker_path.end()), std::filesystem::path::format::auto_format) std::filesystem::path::format::auto_format),
std::filesystem::path(std::string_view(executor_worker_path.begin(), executor_worker_path.end()),
std::filesystem::path::format::auto_format)
); );
} }
} }

View File

@ -2,8 +2,8 @@
set -ex set -ex
TRT_VER_BASE="10.6.0" TRT_VER_BASE="10.7.0"
TRT_VER_FULL="${TRT_VER_BASE}.26" TRT_VER_FULL="${TRT_VER_BASE}.23"
CUDA_VER="12.6" CUDA_VER="12.6"
CUDNN_VER="9.5.0.50-1" CUDNN_VER="9.5.0.50-1"
NCCL_VER="2.22.3-1+cuda12.6" NCCL_VER="2.22.3-1+cuda12.6"

View File

@ -0,0 +1,51 @@
from argparse import ArgumentParser
AWS_S3_CACHING_VARIABLES = {
"AWS_ACCESS_KEY_ID": "aws_access_key_id",
"AWS_SECRET_ACCESS_KEY": "aws_secret_access_key",
"AWS_SESSION_TOKEN": "aws_session_token",
"SCCACHE_REGION": "s3_region",
"SCCACHE_BUCKET": "s3_bucket_name",
}
ALL_CACHING_STORAGE_VARIABLES = {"AWS_S3_CACHING_VARIABLES"}
def setup_sccache_locally():
from os import environ
print("Setting up Local Caching Layer")
for target in ALL_CACHING_STORAGE_VARIABLES:
for envvar in globals()[target].keys():
if envvar in environ:
print(f"Deleted {envvar} from environment variables")
del environ[envvar]
def setup_sccache_for_s3():
from os import environ
print("Setting up AWS S3 Caching Layer")
for envvar in AWS_S3_CACHING_VARIABLES.keys():
if envvar not in environ or not environ[envvar] or len(environ[envvar]) == 0:
print(f"Missing definition for environment variable {envvar}")
if __name__ == "__main__":
parser = ArgumentParser("TensorRT-LLM Build Caching Setup")
parser.add_argument(
"--is-gha-build",
type=str,
default="FALSE",
help="Indicate if the build is from Github Actions",
)
# Parse args
args = parser.parse_args()
args.is_gha_build = args.is_gha_build.lower() in {"on", "true", "1"}
if args.is_gha_build:
setup_sccache_for_s3()
else:
setup_sccache_locally()

View File

@ -6,6 +6,26 @@ mod utils;
#[cxx::bridge(namespace = "huggingface::tgi::backends::trtllm")] #[cxx::bridge(namespace = "huggingface::tgi::backends::trtllm")]
mod ffi { mod ffi {
#[cxx_name = "finish_reason_t"]
#[derive(Debug, Clone, Copy)]
pub enum FinishReason {
/// The request is not finished.
#[cxx_name = "kNOT_FINISHED"]
NotFinished = 0u8,
/// The request finished because the end id was generated.
#[cxx_name = "kEND_ID"]
EndTokenId = 1u8,
/// The request finished because a stop word was generated.
#[cxx_name = "kSTOP_WORDS"]
StopWords = 2u8,
/// The request finished because the maximum number of tokens was reached.
#[cxx_name = "kLENGTH"]
MaxLength = 3u8,
}
/// Struct used as shared type between rust and C++ to represent the result /// Struct used as shared type between rust and C++ to represent the result
/// of a single decoding iteration /// of a single decoding iteration
#[cxx_name = "generation_step_t"] #[cxx_name = "generation_step_t"]
@ -15,6 +35,7 @@ mod ffi {
token_id: u32, token_id: u32,
log_prob: f32, log_prob: f32,
is_final: bool, is_final: bool,
finish_reason: FinishReason,
has_error: bool, has_error: bool,
error_msg: String, error_msg: String,
} }
@ -66,3 +87,17 @@ mod ffi {
fn cancel(self: Pin<&mut TensorRtLlmBackendImpl>, request_id: u64); fn cancel(self: Pin<&mut TensorRtLlmBackendImpl>, request_id: u64);
} }
} }
use ffi::FinishReason;
use text_generation_router::FinishReason as InferFinishReason;
impl From<FinishReason> for InferFinishReason {
fn from(reason: FinishReason) -> Self {
match reason {
FinishReason::StopWords => InferFinishReason::StopSequence,
FinishReason::MaxLength => InferFinishReason::Length,
FinishReason::EndTokenId => InferFinishReason::EndOfSequenceToken,
_ => panic!("Cannot convert {reason:?} to text_generation_router::FinishReason"),
}
}
}
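The `From<FinishReason>` conversion above is what later allows the backend glue to call `decoded_token.finish_reason.into()`; a hypothetical unit test, assuming it sits in this crate next to the impl, that exercises the same mapping:

```rust
// Hypothetical test module (assumed placement inside this crate's lib.rs).
#[cfg(test)]
mod finish_reason_mapping_tests {
    use super::FinishReason;
    use text_generation_router::FinishReason as InferFinishReason;

    #[test]
    fn maps_backend_finish_reasons_onto_router_ones() {
        // Each assertion mirrors one arm of the `From` implementation above.
        assert!(matches!(
            InferFinishReason::from(FinishReason::StopWords),
            InferFinishReason::StopSequence
        ));
        assert!(matches!(
            InferFinishReason::from(FinishReason::MaxLength),
            InferFinishReason::Length
        ));
        assert!(matches!(
            InferFinishReason::from(FinishReason::EndTokenId),
            InferFinishReason::EndOfSequenceToken
        ));
    }
}
```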

View File

@ -18,10 +18,12 @@ use text_generation_router::validation::ValidationError::{
EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality, EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality,
}; };
use text_generation_router::validation::{Chunk, ValidGenerateRequest}; use text_generation_router::validation::{Chunk, ValidGenerateRequest};
use text_generation_router::{FinishReason, Token}; use text_generation_router::Token;
use crate::errors::TensorRtLlmBackendError; use crate::errors::TensorRtLlmBackendError;
use crate::ffi::{create_backend_from_engine_folder, GenerationStep, TensorRtLlmBackendImpl}; use crate::ffi::{
create_backend_from_engine_folder, FinishReason, GenerationStep, TensorRtLlmBackendImpl,
};
use crate::utils::first_line; use crate::utils::first_line;
type InferResult<T> = Result<T, InferError>; type InferResult<T> = Result<T, InferError>;
@ -40,6 +42,7 @@ struct DecodedToken {
id: u32, id: u32,
log_prob: f32, log_prob: f32,
is_final: bool, is_final: bool,
finish_reason: FinishReason,
} }
impl<'step> TryFrom<&'step GenerationStep> for DecodedToken { impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
@ -51,6 +54,7 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
id: step.token_id, id: step.token_id,
log_prob: step.log_prob, log_prob: step.log_prob,
is_final: step.is_final, is_final: step.is_final,
finish_reason: step.finish_reason,
}) })
} else { } else {
Err(GenerationError(step.error_msg.clone())) Err(GenerationError(step.error_msg.clone()))
@ -192,7 +196,7 @@ fn post_process_decoded_token(
let generated_text = GeneratedText { let generated_text = GeneratedText {
text: text.unwrap(), text: text.unwrap(),
generated_tokens: ctx.tokens.len() as u32, generated_tokens: ctx.tokens.len() as u32,
finish_reason: FinishReason::EndOfSequenceToken, // TODO : Map FinishReason finish_reason: decoded_token.finish_reason.into(),
seed: None, seed: None,
}; };
@ -336,4 +340,8 @@ impl Backend for TensorRtLlmBackendV2 {
async fn health(&self, _: bool) -> bool { async fn health(&self, _: bool) -> bool {
true true
} }
fn name(&self) -> &'static str {
"TensorRT-LLM"
}
} }

View File

@ -11,7 +11,7 @@ use text_generation_router::server::{
get_hub_model_info, legacy_tokenizer_handle, py_resolve_tokenizer, get_hub_model_info, legacy_tokenizer_handle, py_resolve_tokenizer,
}; };
use text_generation_router::usage_stats::UsageStatsLevel; use text_generation_router::usage_stats::UsageStatsLevel;
use text_generation_router::{server, HubTokenizerConfig, Tokenizer}; use text_generation_router::{server, Tokenizer};
/// App Configuration /// App Configuration
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
@ -67,11 +67,7 @@ struct Args {
payload_limit: usize, payload_limit: usize,
} }
async fn get_tokenizer( async fn get_tokenizer(tokenizer_name: &str, revision: Option<&str>) -> Option<Tokenizer> {
tokenizer_name: &str,
tokenizer_config_path: Option<&str>,
revision: Option<&str>,
) -> Option<Tokenizer> {
// Parse Huggingface hub token // Parse Huggingface hub token
let authorization_token = std::env::var("HF_TOKEN") let authorization_token = std::env::var("HF_TOKEN")
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN")) .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
@ -182,19 +178,6 @@ async fn get_tokenizer(
} }
}; };
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
// let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
// {
// HubTokenizerConfig::from_file(filename)
// } else {
// tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
// };
// let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
// tracing::warn!("Could not find tokenizer config locally and no API specified");
// HubTokenizerConfig::default()
// });
let tokenizer: Tokenizer = { let tokenizer: Tokenizer = {
use pyo3::prelude::*; use pyo3::prelude::*;
pyo3::Python::with_gil(|py| -> PyResult<()> { pyo3::Python::with_gil(|py| -> PyResult<()> {
@ -292,11 +275,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
} }
// Create the backend // Create the backend
match get_tokenizer( match get_tokenizer(&tokenizer_name, revision.as_deref())
&tokenizer_name,
tokenizer_config_path.as_deref(),
revision.as_deref(),
)
.await .await
.expect("Failed to retrieve tokenizer implementation") .expect("Failed to retrieve tokenizer implementation")
{ {

View File

@ -8,13 +8,13 @@
#include "backend.hpp" #include "backend.hpp"
using namespace huggingface::tgi::backends::trtllm; using namespace huggingface::tgi::backends::trtllm;
TEST_CASE("parse generation_config.json all set", "[generation_config_t]") TEST_CASE("parse generation_config.json all set", "[generation_config_t]")
{ {
const json config_j = {{"temperature", 0.6}, {"top_p", 0.95}, {"eos_token_id", {1,2,3}}}; const json config_j = {{"temperature", 0.6},
{"top_p", 0.95},
{"eos_token_id", {1, 2, 3}}};
const auto generation_config = generation_config_t(config_j); const auto generation_config = generation_config_t(config_j);
REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(0.6, 1e-6)); REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(0.6, 1e-6));
@ -24,8 +24,9 @@ TEST_CASE("parse generation_config.json all set", "[generation_config_t]")
REQUIRE_FALSE(generation_config.stop_words.empty()); REQUIRE_FALSE(generation_config.stop_words.empty());
REQUIRE(generation_config.stop_words.size() == config_j["/eos_token_id"_json_pointer].size()); REQUIRE(generation_config.stop_words.size() == config_j["/eos_token_id"_json_pointer].size());
for (auto [lhs, rhs] : std::views::zip(generation_config.stop_words, std::list<std::vector<int32_t>>{{1}, {2}, {3}})) for (auto [lhs, rhs]: std::views::zip(generation_config.stop_words, std::list<std::vector<int32_t>>{{1},
{ {2},
{3}})) {
// Currently we do not support multi-tokens stop words // Currently we do not support multi-tokens stop words
REQUIRE(lhs.size() == 1); REQUIRE(lhs.size() == 1);
REQUIRE(rhs.size() == 1); REQUIRE(rhs.size() == 1);
@ -44,8 +45,9 @@ TEST_CASE("parse generation_config.json default", "[generation_config_t]")
REQUIRE_FALSE(generation_config.stop_words.empty()); REQUIRE_FALSE(generation_config.stop_words.empty());
REQUIRE(generation_config.stop_words.size() == config_j["/eos_token_id"_json_pointer].size()); REQUIRE(generation_config.stop_words.size() == config_j["/eos_token_id"_json_pointer].size());
for (auto [lhs, rhs] : std::views::zip(generation_config.stop_words, std::list<std::vector<int32_t>>{{1}, {2}, {3}})) for (auto [lhs, rhs]: std::views::zip(generation_config.stop_words, std::list<std::vector<int32_t>>{{1},
{ {2},
{3}})) {
// Currently we do not support multi-tokens stop words // Currently we do not support multi-tokens stop words
REQUIRE(lhs.size() == 1); REQUIRE(lhs.size() == 1);
REQUIRE(rhs.size() == 1); REQUIRE(rhs.size() == 1);

View File

@ -108,6 +108,10 @@ impl Backend for BackendV2 {
fn start_health(&self) -> bool { fn start_health(&self) -> bool {
true true
} }
fn name(&self) -> &'static str {
"tgi-v2"
}
} }
/// Batching logic /// Batching logic

View File

@ -213,8 +213,7 @@ impl State {
} }
// Pad prefill_token_budget to be a multiple of block size // Pad prefill_token_budget to be a multiple of block size
let prefill_token_budget = let prefill_token_budget = prefill_token_budget.div_ceil(self.block_size) * self.block_size;
((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
// Create span for this batch to add context to inference calls // Create span for this batch to add context to inference calls
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty); let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
@ -245,9 +244,8 @@ impl State {
prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
} else { } else {
// pad to block size // pad to block size
prefill_tokens += ((entry.request.input_length + self.block_size - 1) prefill_tokens +=
/ self.block_size) entry.request.input_length.div_ceil(self.block_size) * self.block_size;
* self.block_size;
} }
if self.requires_padding { if self.requires_padding {
@ -262,8 +260,7 @@ impl State {
}; };
// pad to block size // pad to block size
decode_tokens += decode_tokens += max_new_tokens.div_ceil(self.block_size) * self.block_size;
((max_new_tokens + self.block_size - 1) / self.block_size) * self.block_size;
} }
if prefill_tokens > prefill_token_budget if prefill_tokens > prefill_token_budget

View File

@ -115,6 +115,10 @@ impl Backend for BackendV3 {
fn start_health(&self) -> bool { fn start_health(&self) -> bool {
true true
} }
fn name(&self) -> &'static str {
"tgi-v3"
}
} }
/// Batching logic /// Batching logic

View File

@ -165,13 +165,13 @@ impl Allocator for SimpleAllocator {
let (tokens, repeats) = match self.window_size { let (tokens, repeats) = match self.window_size {
None => (tokens, 1), None => (tokens, 1),
Some(window_size) => { Some(window_size) => {
let repeats = (tokens + window_size - 1) / window_size; let repeats = tokens.div_ceil(window_size);
let tokens = core::cmp::min(tokens, window_size); let tokens = core::cmp::min(tokens, window_size);
(tokens, repeats as usize) (tokens, repeats as usize)
} }
}; };
// Pad to a multiple of block size // Pad to a multiple of block size
let required_blocks = (tokens + self.block_size - 1) / self.block_size; let required_blocks = tokens.div_ceil(self.block_size);
(required_blocks, repeats) (required_blocks, repeats)
}; };

View File

@ -257,8 +257,7 @@ impl State {
} }
// Pad prefill_token_budget to be a multiple of block size // Pad prefill_token_budget to be a multiple of block size
let prefill_token_budget = let prefill_token_budget = prefill_token_budget.div_ceil(self.block_size) * self.block_size;
((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
// Create span for this batch to add context to inference calls // Create span for this batch to add context to inference calls
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty); let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);

View File

@ -103,7 +103,7 @@ impl Allocator for RadixAllocator {
let prefix_len = blocks.len() * self.block_size as usize; let prefix_len = blocks.len() * self.block_size as usize;
let suffix_len = tokens - prefix_len as u32; let suffix_len = tokens - prefix_len as u32;
let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size; let suffix_blocks = suffix_len.div_ceil(self.block_size);
tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}"); tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}");

View File

@ -10,7 +10,7 @@
"name": "Apache 2.0", "name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0" "url": "https://www.apache.org/licenses/LICENSE-2.0"
}, },
"version": "3.0.2-dev0" "version": "3.1.1-dev0"
}, },
"paths": { "paths": {
"/": { "/": {

View File

@ -13,6 +13,8 @@
title: Using TGI with Intel Gaudi title: Using TGI with Intel Gaudi
- local: installation_inferentia - local: installation_inferentia
title: Using TGI with AWS Inferentia title: Using TGI with AWS Inferentia
- local: installation_tpu
title: Using TGI with Google TPUs
- local: installation_intel - local: installation_intel
title: Using TGI with Intel GPUs title: Using TGI with Intel GPUs
- local: installation - local: installation

View File

@ -4,8 +4,13 @@ The NVIDIA TensorRT-LLM (TRTLLM) backend is a high-performance backend for LLMs
that uses NVIDIA's TensorRT library for inference acceleration. that uses NVIDIA's TensorRT library for inference acceleration.
It makes use of specific optimizations for NVIDIA GPUs, such as custom kernels. It makes use of specific optimizations for NVIDIA GPUs, such as custom kernels.
To use the TRTLLM backend you need to compile `engines` for the models you want to use. To use the TRTLLM backend **you need to compile** `engines` for the models you want to use.
Each `engine` must be compiled on the same GPU architecture that you will use for inference. Each `engine` must be compiled for a given set of:
- GPU architecture that you will use for inference (e.g. A100, L40, etc.)
- Maximum batch size
- Maximum input length
- Maximum output length
- Maximum beam width
## Supported models ## Supported models
@ -19,63 +24,159 @@ want to use.
```bash ```bash
MODEL_NAME="meta-llama/Llama-3.1-8B-Instruct" MODEL_NAME="meta-llama/Llama-3.1-8B-Instruct"
DESTINATION="/tmp/engines/$MODEL_NAME"
# Install huggingface_cli HF_TOKEN="hf_xxx"
python -m pip install huggingface-cli[hf_transfer]
# Login to the Hugging Face Hub
huggingface-cli login
# Create a directory to store the model
mkdir -p /tmp/models/$MODEL_NAME
# Create a directory to store the compiled engine
mkdir -p /tmp/engines/$MODEL_NAME
# Download the model
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download --local-dir /tmp/models/$MODEL_NAME $MODEL_NAME
# Compile the engine using Optimum-NVIDIA # Compile the engine using Optimum-NVIDIA
# This will create a compiled engine in the /tmp/engines/meta-llama/Llama-3.1-8B-Instruct
# directory for 1 GPU
docker run \ docker run \
--rm \ --rm \
-it \ -it \
--gpus=1 \ --gpus=1 \
-v /tmp/models/$MODEL_NAME:/model \ --shm-size=1g \
-v /tmp/engines/$MODEL_NAME:/engine \ -v "$DESTINATION":/engine \
huggingface/optimum-nvidia \ -e HF_TOKEN=$HF_TOKEN \
optimum-cli export trtllm \ -e HF_HUB_ENABLE_HF_TRANSFER=1 \
huggingface/optimum-nvidia:v0.1.0b9-py310 \
bash -c "optimum-cli export trtllm \
--tp=1 \ --tp=1 \
--pp=1 \ --pp=1 \
--max-batch-size=128 \ --max-batch-size=64 \
--max-input-length 4096 \ --max-input-length 4096 \
--max-output-length 8192 \ --max-output-length 8192 \
--max-beams-width=1 \ --max-beams-width=1 \
--destination /engine \ --destination /tmp/engine \
$MODEL_NAME $MODEL_NAME && cp -rL /tmp/engine/* /engine/"
``` ```
Your compiled engine will be saved in the `/tmp/engines/$MODEL_NAME` directory. Your compiled engine will be saved in the `/tmp/engines/$MODEL_NAME` directory, in a subfolder named after the GPU used to compile the model.
## Using the TRTLLM backend ## Using the TRTLLM backend
Run TGI-TRTLLM Docker image with the compiled engine: Run TGI-TRTLLM Docker image with the compiled engine:
```bash ```bash
MODEL_NAME="meta-llama/Llama-3.1-8B-Instruct"
DESTINATION="/tmp/engines/$MODEL_NAME"
HF_TOKEN="hf_xxx"
docker run \ docker run \
--gpus 1 \ --gpus 1 \
--shm-size=1g \
-it \ -it \
--rm \ --rm \
-p 3000:3000 \ -p 3000:3000 \
-e MODEL=$MODEL_NAME \ -e MODEL=$MODEL_NAME \
-e PORT=3000 \ -e PORT=3000 \
-e HF_TOKEN='hf_XXX' \ -e HF_TOKEN=$HF_TOKEN \
-v /tmp/engines/$MODEL_NAME:/data \ -v "$DESTINATION"/<YOUR_GPU_ARCHITECTURE>/engines:/data \
ghcr.io/huggingface/text-generation-inference:latest-trtllm \ ghcr.io/huggingface/text-generation-inference:latest-trtllm \
--executor-worker executorWorker \ --model-id /data/ \
--model-id /data/$MODEL_NAME --tokenizer-name $MODEL_NAME
``` ```
## Development ## Development
To develop TRTLLM backend, you can use [dev containers](https://containers.dev/) located in To develop TRTLLM backend, you can use [dev containers](https://containers.dev/) with the following `.devcontainer.json` file:
`.devcontainer` directory. ```json
{
"name": "CUDA",
"build": {
"dockerfile": "Dockerfile_trtllm",
"context": ".."
},
"remoteEnv": {
"PATH": "${containerEnv:PATH}:/usr/local/cuda/bin",
"LD_LIBRARY_PATH": "$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64",
"XLA_FLAGS": "--xla_gpu_cuda_data_dir=/usr/local/cuda"
},
"customizations" : {
"jetbrains" : {
"backend" : "CLion"
}
}
}
```
and `Dockerfile_trtllm`:
```Dockerfile
ARG cuda_arch_list="75-real;80-real;86-real;89-real;90-real"
ARG build_type=release
ARG ompi_version=4.1.7
# CUDA dependent dependencies resolver stage
FROM nvidia/cuda:12.6.3-cudnn-devel-ubuntu24.04 AS cuda-builder
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
build-essential \
cmake \
curl \
gcc-14 \
g++-14 \
git \
git-lfs \
lld \
libssl-dev \
libucx-dev \
libasan8 \
libubsan1 \
ninja-build \
pkg-config \
pipx \
python3 \
python3-dev \
python3-setuptools \
tar \
wget --no-install-recommends && \
pipx ensurepath
ENV TGI_INSTALL_PREFIX=/usr/local/tgi
ENV TENSORRT_INSTALL_PREFIX=/usr/local/tensorrt
# Install OpenMPI
FROM cuda-builder AS mpi-builder
WORKDIR /opt/src/mpi
ARG ompi_version
ENV OMPI_VERSION=${ompi_version}
ENV OMPI_TARBALL_FILENAME=openmpi-${OMPI_VERSION}.tar.bz2
ADD --checksum=sha256:54a33cb7ad81ff0976f15a6cc8003c3922f0f3d8ceed14e1813ef3603f22cd34 \
https://download.open-mpi.org/release/open-mpi/v4.1/${OMPI_TARBALL_FILENAME} .
RUN tar --strip-components=1 -xf ${OMPI_TARBALL_FILENAME} &&\
./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --with-slurm && \
make -j all && \
make install && \
rm -rf ${OMPI_TARBALL_FILENAME}/..
# Install TensorRT
FROM cuda-builder AS trt-builder
COPY backends/trtllm/scripts/install_tensorrt.sh /opt/install_tensorrt.sh
RUN chmod +x /opt/install_tensorrt.sh && \
/opt/install_tensorrt.sh
# Build Backend
FROM cuda-builder AS tgi-builder
WORKDIR /usr/src/text-generation-inference
# Scoped global args reuse
ARG cuda_arch_list
ARG build_type
ARG sccache_gha_enabled
ARG actions_cache_url
ARG actions_runtime_token
# Install Rust
ENV PATH="/root/.cargo/bin:$PATH"
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y && \
chmod -R a+w /root/.rustup && \
chmod -R a+w /root/.cargo && \
cargo install sccache --locked
ENV LD_LIBRARY_PATH="/usr/local/mpi/lib:$LD_LIBRARY_PATH"
ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig"
ENV CMAKE_PREFIX_PATH="/usr/local/mpi:/usr/local/tensorrt"
ENV USE_LLD_LINKER=ON
ENV CUDA_ARCH_LIST=${cuda_arch_list}
```

View File

@ -19,6 +19,6 @@ docker run --gpus all \
--shm-size 1g \ --shm-size 1g \
-e HF_TOKEN=$token \ -e HF_TOKEN=$token \
-p 8080:80 \ -p 8080:80 \
-v $volume:/data ghcr.io/huggingface/text-generation-inference:3.0.1 \ -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.1.0 \
--model-id $model --model-id $model
``` ```

View File

@ -19,7 +19,7 @@ bitsandbytes is a library used to apply 8-bit and 4-bit quantization to models.
In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below 👇 In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below 👇
```bash ```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.0.1 --model-id $model --quantize bitsandbytes docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.1.0 --model-id $model --quantize bitsandbytes
``` ```
4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load. 4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load.
@ -27,7 +27,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingf
In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below 👇 In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below 👇
```bash ```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.0.1 --model-id $model --quantize bitsandbytes-nf4 docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.1.0 --model-id $model --quantize bitsandbytes-nf4
``` ```
You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes). You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes).
@ -48,7 +48,7 @@ $$({\hat{W}_{l}}^{*} = argmin_{\hat{W_{l}}} ||W_{l}X-\hat{W}_{l}X||^{2}_{2})$$
TGI allows you to both run an already GPTQ quantized model (see available models [here](https://huggingface.co/models?search=gptq)) or quantize a model of your choice using quantization script. You can run a quantized model by simply passing --quantize like below 👇 TGI allows you to both run an already GPTQ quantized model (see available models [here](https://huggingface.co/models?search=gptq)) or quantize a model of your choice using quantization script. You can run a quantized model by simply passing --quantize like below 👇
```bash ```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.0.1 --model-id $model --quantize gptq docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.1.0 --model-id $model --quantize gptq
``` ```
Note that TGI's GPTQ implementation doesn't use [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) under the hood. However, models quantized using AutoGPTQ or Optimum can still be served by TGI. Note that TGI's GPTQ implementation doesn't use [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) under the hood. However, models quantized using AutoGPTQ or Optimum can still be served by TGI.

View File

@ -27,7 +27,7 @@ You can check a few existing fine-tunes for popular models:
- [text-generation-inference/Mistral-7B-Instruct-v0.2-medusa](https://huggingface.co/text-generation-inference/Mistral-7B-Instruct-v0.2-medusa) - [text-generation-inference/Mistral-7B-Instruct-v0.2-medusa](https://huggingface.co/text-generation-inference/Mistral-7B-Instruct-v0.2-medusa)
In order to create your own medusa heads for your own finetune, you should check out the original medusa repo. [../basic_tutorials/train_medusa.md](../basic_tutorials/train_medusa.md) In order to create your own medusa heads for your own finetune, you should check out the original medusa repo. Read more in [Train Medusa](../basic_tutorials/train_medusa#training).
In order to use medusa models in TGI, simply point to a medusa enabled model, and everything will load automatically. In order to use medusa models in TGI, simply point to a medusa enabled model, and everything will load automatically.

View File

@ -11,7 +11,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
--device=/dev/kfd --device=/dev/dri --group-add video \ --device=/dev/kfd --device=/dev/dri --group-add video \
--ipc=host --shm-size 256g --net host -v $volume:/data \ --ipc=host --shm-size 256g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.0.1-rocm \ ghcr.io/huggingface/text-generation-inference:3.1.0-rocm \
--model-id $model --model-id $model
``` ```

View File

@ -12,7 +12,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm --privileged --cap-add=sys_nice \ docker run --rm --privileged --cap-add=sys_nice \
--device=/dev/dri \ --device=/dev/dri \
--ipc=host --shm-size 1g --net host -v $volume:/data \ --ipc=host --shm-size 1g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.0.1-intel-xpu \ ghcr.io/huggingface/text-generation-inference:3.1.0-intel-xpu \
--model-id $model --cuda-graphs 0 --model-id $model --cuda-graphs 0
``` ```
@ -29,7 +29,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm --privileged --cap-add=sys_nice \ docker run --rm --privileged --cap-add=sys_nice \
--device=/dev/dri \ --device=/dev/dri \
--ipc=host --shm-size 1g --net host -v $volume:/data \ --ipc=host --shm-size 1g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.0.1-intel-cpu \ ghcr.io/huggingface/text-generation-inference:3.1.0-intel-cpu \
--model-id $model --cuda-graphs 0 --model-id $model --cuda-graphs 0
``` ```

View File

@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \ docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.0.1 \ ghcr.io/huggingface/text-generation-inference:3.1.0 \
--model-id $model --model-id $model
``` ```

View File

@ -0,0 +1,3 @@
# Using TGI with Google TPUs
Check out this [guide](https://huggingface.co/docs/optimum-tpu) on how to serve models with TGI on TPUs.

View File

@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.0.1 \ ghcr.io/huggingface/text-generation-inference:3.1.0 \
--model-id $model --model-id $model
``` ```
@ -96,7 +96,7 @@ curl 127.0.0.1:8080/generate \
To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more. To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.
```bash ```bash
docker run ghcr.io/huggingface/text-generation-inference:3.0.1 --help docker run ghcr.io/huggingface/text-generation-inference:3.1.0 --help
``` ```
</Tip> </Tip>

View File

@ -163,7 +163,7 @@ hub = {
# create Hugging Face Model Class # create Hugging Face Model Class
huggingface_model = HuggingFaceModel( huggingface_model = HuggingFaceModel(
image_uri=get_huggingface_llm_image_uri("huggingface",version="3.0.1"), image_uri=get_huggingface_llm_image_uri("huggingface",version="3.1.0"),
env=hub, env=hub,
role=role, role=role,
) )

View File

@ -4,6 +4,7 @@
Text Generation Inference enables serving optimized models. The following sections list which models (VLMs & LLMs) are supported. Text Generation Inference enables serving optimized models. The following sections list which models (VLMs & LLMs) are supported.
- [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2) - [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2)
- [Deepseek V3](https://huggingface.co/deepseek-ai/DeepSeek-V3)
- [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal) - [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal)
- [Idefics 3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3) (Multimodal) - [Idefics 3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3) (Multimodal)
- [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal) - [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)

View File

@ -3,7 +3,7 @@
Text Generation Inference collects anonymous usage statistics to help us improve the service. The collected data is used to improve TGI and to understand what causes failures. The data is collected transparently and any sensitive information is omitted. Text Generation Inference collects anonymous usage statistics to help us improve the service. The collected data is used to improve TGI and to understand what causes failures. The data is collected transparently and any sensitive information is omitted.
Data is sent twice, once on server startup and once when server stops. Also, usage statistics are only enabled when TGI is running in docker to avoid collecting data when TGI runs directly on the host machine. Usage statistics are collected only when TGI is running in a Docker container. This prevents data collection when TGI is run directly on the host machine. The collected data includes startup and shutdown events, as well as a heartbeat signal sent every 15 minutes.
## What data is collected ## What data is collected

View File

@ -108,11 +108,11 @@
"pre-commit-hooks": "pre-commit-hooks_3" "pre-commit-hooks": "pre-commit-hooks_3"
}, },
"locked": { "locked": {
"lastModified": 1732039290, "lastModified": 1734429562,
"narHash": "sha256-LQKY7bShf2H9kJouxa9ZspfdrulnZF9o4kLTqGqCDYM=", "narHash": "sha256-V2XNs3Ir8WXNHdocfzkR/fu0FzkZ9uTDJkVecxJrGmQ=",
"owner": "nix-community", "owner": "nix-community",
"repo": "crate2nix", "repo": "crate2nix",
"rev": "9ff208ce7f5a482272b1bcefbe363c772d7ff914", "rev": "8537c2d7cb623679aaeff62c4c4c43a91566ab09",
"type": "github" "type": "github"
}, },
"original": { "original": {
@ -305,11 +305,11 @@
}, },
"flake-compat_4": { "flake-compat_4": {
"locked": { "locked": {
"lastModified": 1696426674, "lastModified": 1733328505,
"narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=", "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
"owner": "edolstra", "owner": "edolstra",
"repo": "flake-compat", "repo": "flake-compat",
"rev": "0f9255e01c2351cc7d116c072cb317785dd33b33", "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
"type": "github" "type": "github"
}, },
"original": { "original": {
@ -718,11 +718,11 @@
}, },
"nixpkgs_6": { "nixpkgs_6": {
"locked": { "locked": {
"lastModified": 1732034459, "lastModified": 1737453259,
"narHash": "sha256-Zais/zMRuJdlALidkUgEuasXOd37ZZLqkPkF9bIYSrY=", "narHash": "sha256-5LaFI9SQwCZmJDasMoYMdzNouWXNk3BvjKcO19tq1Rs=",
"owner": "danieldk", "owner": "danieldk",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "40280e7bf9743cdf563494db4ece2a43aa674fa8", "rev": "e0372dbcfd19ddd783b7c3b3868f19322f83318e",
"type": "github" "type": "github"
}, },
"original": { "original": {
@ -853,11 +853,11 @@
] ]
}, },
"locked": { "locked": {
"lastModified": 1732242723, "lastModified": 1737685583,
"narHash": "sha256-NWI8csIK0ujFlFuEXKnoc+7hWoCiEtINK9r48LUUMeU=", "narHash": "sha256-p+NVABRpGi+pT+xxf9HcLcFVxG6L+vEEy+NwzB9T0f8=",
"owner": "oxalica", "owner": "oxalica",
"repo": "rust-overlay", "repo": "rust-overlay",
"rev": "a229311fcb45b88a95fdfa5cecd8349c809a272a", "rev": "eb64cbcc8eee0fa87ebded92805280d2ec97415a",
"type": "github" "type": "github"
}, },
"original": { "original": {
@ -978,11 +978,11 @@
"nixpkgs": "nixpkgs_6" "nixpkgs": "nixpkgs_6"
}, },
"locked": { "locked": {
"lastModified": 1736436388, "lastModified": 1738323634,
"narHash": "sha256-CIyxVPpM9RrSwthNT/4DQ10YPk/uwzP7AeE83kBNsrE=", "narHash": "sha256-lKPzgEm7pEuQJVhacsxFHqg1MOtrUMZvr+9IuJzC5J4=",
"owner": "huggingface", "owner": "huggingface",
"repo": "text-generation-inference-nix", "repo": "text-generation-inference-nix",
"rev": "5103c3fb1f9ad1fd33b6e09ff05e957884b112d5", "rev": "eb5fede2756f544f75e01f55a4097f9c9a8c5005",
"type": "github" "type": "github"
}, },
"original": { "original": {

View File

@ -562,6 +562,7 @@ def launcher(event_loop):
docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]]) docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]])
] ]
client.api.timeout = 1000
container = client.containers.run( container = client.containers.run(
DOCKER_IMAGE, DOCKER_IMAGE,
command=args, command=args,
@ -573,7 +574,7 @@ def launcher(event_loop):
devices=devices, devices=devices,
volumes=volumes, volumes=volumes,
ports={"80/tcp": port}, ports={"80/tcp": port},
healthcheck={"timeout": int(60 * 1e9), "retries": 2}, # 60s healthcheck={"timeout": int(180 * 1e9), "retries": 2}, # 60s
shm_size="1G", shm_size="1G",
) )

View File

@ -0,0 +1,73 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 2284,
"logprob": -0.9355469,
"special": false,
"text": "():"
},
{
"id": 303,
"logprob": -0.40795898,
"special": false,
"text": "\n "
},
{
"id": 1489,
"logprob": -0.27954102,
"special": false,
"text": " print"
},
{
"id": 459,
"logprob": -0.6142578,
"special": false,
"text": "(\""
},
{
"id": 8302,
"logprob": -0.68310547,
"special": false,
"text": "Hello"
},
{
"id": 10914,
"logprob": -1.4599609,
"special": false,
"text": " World"
},
{
"id": 16013,
"logprob": -0.80126953,
"special": false,
"text": "!\")"
},
{
"id": 222,
"logprob": -0.625,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -0.23242188,
"special": false,
"text": "\n"
},
{
"id": 610,
"logprob": -1.2294922,
"special": false,
"text": "def"
}
],
"top_tokens": null
},
"generated_text": "():\n print(\"Hello World!\")\n\ndef"
}

View File

@ -0,0 +1,373 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 60,
"prefill": [],
"seed": 0,
"tokens": [
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": -0.7944336,
"special": false,
"text": "#"
},
{
"id": 494,
"logprob": 0.0,
"special": false,
"text": " +"
},
{
"id": 447,
"logprob": -0.1796875,
"special": false,
"text": " ["
},
{
"id": 9009,
"logprob": 0.0,
"special": false,
"text": "markdown"
},
{
"id": 98,
"logprob": 0.0,
"special": false,
"text": "]"
},
{
"id": 37402,
"logprob": 0.0,
"special": false,
"text": " slideshow"
},
{
"id": 8492,
"logprob": 0.0,
"special": false,
"text": "={\""
},
{
"id": 7277,
"logprob": 0.0,
"special": false,
"text": "slide"
},
{
"id": 100,
"logprob": 0.0,
"special": false,
"text": "_"
},
{
"id": 700,
"logprob": 0.0,
"special": false,
"text": "type"
},
{
"id": 582,
"logprob": 0.0,
"special": false,
"text": "\":"
},
{
"id": 332,
"logprob": 0.0,
"special": false,
"text": " \""
},
{
"id": 7277,
"logprob": -0.06994629,
"special": false,
"text": "slide"
},
{
"id": 3667,
"logprob": 0.0,
"special": false,
"text": "\"}"
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": 0.0,
"special": false,
"text": "#"
},
{
"id": 607,
"logprob": -0.8261719,
"special": false,
"text": " #"
},
{
"id": 244,
"logprob": -1.8574219,
"special": false,
"text": " "
},
{
"id": 55,
"logprob": -1.4541016,
"special": false,
"text": "2"
},
{
"id": 51,
"logprob": 0.0,
"special": false,
"text": "."
},
{
"id": 6208,
"logprob": -0.9794922,
"special": false,
"text": " What"
},
{
"id": 458,
"logprob": 0.0,
"special": false,
"text": " is"
},
{
"id": 341,
"logprob": 0.0,
"special": false,
"text": " the"
},
{
"id": 10609,
"logprob": -0.69189453,
"special": false,
"text": " difference"
},
{
"id": 3761,
"logprob": 0.0,
"special": false,
"text": " between"
},
{
"id": 331,
"logprob": 0.0,
"special": false,
"text": " a"
},
{
"id": 1168,
"logprob": -0.27172852,
"special": false,
"text": " list"
},
{
"id": 480,
"logprob": 0.0,
"special": false,
"text": " and"
},
{
"id": 331,
"logprob": 0.0,
"special": false,
"text": " a"
},
{
"id": 8871,
"logprob": 0.0,
"special": false,
"text": " tuple"
},
{
"id": 68,
"logprob": 0.0,
"special": false,
"text": "?"
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": -1.3359375,
"special": false,
"text": "#"
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": 0.0,
"special": false,
"text": "#"
},
{
"id": 449,
"logprob": -0.03164673,
"special": false,
"text": " -"
},
{
"id": 418,
"logprob": -1.0947266,
"special": false,
"text": " A"
},
{
"id": 1168,
"logprob": 0.0,
"special": false,
"text": " list"
},
{
"id": 458,
"logprob": 0.0,
"special": false,
"text": " is"
},
{
"id": 331,
"logprob": -0.3305664,
"special": false,
"text": " a"
},
{
"id": 14792,
"logprob": 0.0,
"special": false,
"text": " mutable"
},
{
"id": 6645,
"logprob": -0.40478516,
"special": false,
"text": " sequence"
},
{
"id": 451,
"logprob": 0.0,
"special": false,
"text": " of"
},
{
"id": 4725,
"logprob": -0.50390625,
"special": false,
"text": " elements"
},
{
"id": 49,
"logprob": -2.1269531,
"special": false,
"text": ","
},
{
"id": 2236,
"logprob": -0.1427002,
"special": false,
"text": " while"
},
{
"id": 331,
"logprob": 0.0,
"special": false,
"text": " a"
},
{
"id": 8871,
"logprob": 0.0,
"special": false,
"text": " tuple"
},
{
"id": 458,
"logprob": 0.0,
"special": false,
"text": " is"
},
{
"id": 619,
"logprob": 0.0,
"special": false,
"text": " an"
},
{
"id": 26079,
"logprob": 0.0,
"special": false,
"text": " immutable"
},
{
"id": 6645,
"logprob": 0.0,
"special": false,
"text": " sequence"
},
{
"id": 451,
"logprob": 0.0,
"special": false,
"text": " of"
},
{
"id": 4725,
"logprob": 0.0,
"special": false,
"text": " elements"
},
{
"id": 51,
"logprob": 0.0,
"special": false,
"text": "."
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": 0.0,
"special": false,
"text": "#"
},
{
"id": 449,
"logprob": 0.0,
"special": false,
"text": " -"
}
],
"top_tokens": null
},
"generated_text": "\n\n# + [markdown] slideshow={\"slide_type\": \"slide\"}\n# # 2. What is the difference between a list and a tuple?\n#\n# - A list is a mutable sequence of elements, while a tuple is an immutable sequence of elements.\n# -"
}

View File

@ -0,0 +1,294 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 222,
"logprob": -1.9091797,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -1.0478516,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": -3.015625,
"special": false,
"text": "#"
},
{
"id": 494,
"logprob": -1.4228516,
"special": false,
"text": " +"
},
{
"id": 447,
"logprob": -1.1025391,
"special": false,
"text": " ["
},
{
"id": 9009,
"logprob": -0.0008444786,
"special": false,
"text": "markdown"
},
{
"id": 98,
"logprob": -8.8095665e-05,
"special": false,
"text": "]"
},
{
"id": 37402,
"logprob": -0.5810547,
"special": false,
"text": " slideshow"
},
{
"id": 8492,
"logprob": -0.00022864342,
"special": false,
"text": "={\""
},
{
"id": 7277,
"logprob": -0.00030994415,
"special": false,
"text": "slide"
}
],
"top_tokens": null
},
"generated_text": "\n\n# + [markdown] slideshow={\"slide"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 222,
"logprob": -1.9091797,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -1.0478516,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": -3.015625,
"special": false,
"text": "#"
},
{
"id": 494,
"logprob": -1.4228516,
"special": false,
"text": " +"
},
{
"id": 447,
"logprob": -1.1025391,
"special": false,
"text": " ["
},
{
"id": 9009,
"logprob": -0.0008444786,
"special": false,
"text": "markdown"
},
{
"id": 98,
"logprob": -8.8095665e-05,
"special": false,
"text": "]"
},
{
"id": 37402,
"logprob": -0.5810547,
"special": false,
"text": " slideshow"
},
{
"id": 8492,
"logprob": -0.00022864342,
"special": false,
"text": "={\""
},
{
"id": 7277,
"logprob": -0.00030994415,
"special": false,
"text": "slide"
}
],
"top_tokens": null
},
"generated_text": "\n\n# + [markdown] slideshow={\"slide"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 222,
"logprob": -1.9091797,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -1.0478516,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": -3.015625,
"special": false,
"text": "#"
},
{
"id": 494,
"logprob": -1.4228516,
"special": false,
"text": " +"
},
{
"id": 447,
"logprob": -1.1025391,
"special": false,
"text": " ["
},
{
"id": 9009,
"logprob": -0.0008444786,
"special": false,
"text": "markdown"
},
{
"id": 98,
"logprob": -8.8095665e-05,
"special": false,
"text": "]"
},
{
"id": 37402,
"logprob": -0.5810547,
"special": false,
"text": " slideshow"
},
{
"id": 8492,
"logprob": -0.00022864342,
"special": false,
"text": "={\""
},
{
"id": 7277,
"logprob": -0.00030994415,
"special": false,
"text": "slide"
}
],
"top_tokens": null
},
"generated_text": "\n\n# + [markdown] slideshow={\"slide"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 222,
"logprob": -1.9091797,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -1.0478516,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": -3.015625,
"special": false,
"text": "#"
},
{
"id": 494,
"logprob": -1.4228516,
"special": false,
"text": " +"
},
{
"id": 447,
"logprob": -1.1025391,
"special": false,
"text": " ["
},
{
"id": 9009,
"logprob": -0.0008444786,
"special": false,
"text": "markdown"
},
{
"id": 98,
"logprob": -8.8095665e-05,
"special": false,
"text": "]"
},
{
"id": 37402,
"logprob": -0.5810547,
"special": false,
"text": " slideshow"
},
{
"id": 8492,
"logprob": -0.00022864342,
"special": false,
"text": "={\""
},
{
"id": 7277,
"logprob": -0.00030994415,
"special": false,
"text": "slide"
}
],
"top_tokens": null
},
"generated_text": "\n\n# + [markdown] slideshow={\"slide"
}
]

View File

@ -0,0 +1,71 @@
{
"details": {
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 100,
"logprob": -0.9824219,
"special": false,
"text": "_"
},
{
"id": 5879,
"logprob": -0.3017578,
"special": false,
"text": "world"
},
{
"id": 2284,
"logprob": -0.68652344,
"special": false,
"text": "():"
},
{
"id": 303,
"logprob": -0.27734375,
"special": false,
"text": "\n "
},
{
"id": 1489,
"logprob": -0.4482422,
"special": false,
"text": " print"
},
{
"id": 459,
"logprob": -0.54248047,
"special": false,
"text": "(\""
},
{
"id": 8302,
"logprob": -0.4296875,
"special": false,
"text": "Hello"
},
{
"id": 10914,
"logprob": -0.8544922,
"special": false,
"text": " World"
},
{
"id": 16013,
"logprob": -0.7573242,
"special": false,
"text": "!\")"
},
{
"id": 222,
"logprob": -0.81347656,
"special": false,
"text": "\n"
}
]
},
"generated_text": "_world():\n print(\"Hello World!\")\n"
}

View File

@ -0,0 +1,79 @@
import pytest
import requests
@pytest.fixture(scope="module")
def flash_starcoder2_handle(launcher):
with launcher(
"bigcode/starcoder2-3b", lora_adapters=["smangrul/starcoder-3b-hugcoder"]
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_starcoder2(flash_starcoder2_handle):
await flash_starcoder2_handle.health(300)
return flash_starcoder2_handle.client
@pytest.mark.asyncio
async def test_flash_starcoder2(flash_starcoder2, response_snapshot):
response = await flash_starcoder2.generate(
"def print_hello", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapshot):
response = await flash_starcoder2.generate(
"who are you?",
max_new_tokens=60,
temperature=0.2,
top_p=0.95,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 60
assert response == response_snapshot
@pytest.mark.asyncio
async def test_flash_starcoder2_load(
flash_starcoder2, generate_load, response_snapshot
):
responses = await generate_load(
flash_starcoder2, "who are you?", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == response_snapshot
@pytest.mark.asyncio
async def test_flash_starcoder2_with_hugcode_adapter(
flash_starcoder2, response_snapshot
):
response = requests.post(
f"{flash_starcoder2.base_url}/generate",
headers=flash_starcoder2.headers,
json={
"inputs": "def print_hello",
"parameters": {
"max_new_tokens": 10,
"adapter_id": "smangrul/starcoder-3b-hugcoder",
"details": True,
},
},
)
assert response.status_code == 200
data = response.json()
assert data["generated_text"] == '_world():\n print("Hello World!")\n'
assert data == response_snapshot

View File

@ -25,21 +25,23 @@ async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snap
    assert response == generous_response_snapshot

-@pytest.mark.release
-@pytest.mark.asyncio
-async def test_flash_starcoder_gptq_default_params(
-    flash_starcoder_gptq, generous_response_snapshot
-):
-    response = await flash_starcoder_gptq.generate(
-        "def geometric_mean(L: List[float]):",
-        max_new_tokens=20,
-        temperature=0.2,
-        top_p=0.95,
-        decoder_input_details=True,
-        seed=0,
-    )
-    assert response.details.generated_tokens == 2
-    assert response == generous_response_snapshot
+# Deactivated because it's flaky
+# Only this model seems affected and it's only a logprob precision issue.
+# @pytest.mark.release
+# @pytest.mark.asyncio
+# async def test_flash_starcoder_gptq_default_params(
+#     flash_starcoder_gptq, generous_response_snapshot
+# ):
+#     response = await flash_starcoder_gptq.generate(
+#         "def geometric_mean(L: List[float]):",
+#         max_new_tokens=20,
+#         temperature=0.2,
+#         top_p=0.95,
+#         decoder_input_details=True,
+#         seed=0,
+#     )
+#     assert response.details.generated_tokens == 2
+#     assert response == generous_response_snapshot

@pytest.mark.release

View File

@ -5,7 +5,6 @@ use hf_hub::{
};
use nix::sys::signal::{self, Signal};
use nix::unistd::Pid;
-use regex::Regex;
use serde::Deserialize;
use std::env;
use std::ffi::OsString;
@ -144,7 +143,9 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
    }
}

-    let fallback_attention = if matches!(compute_capability, Some((major, _)) if major < 8) {
+    let fallback_attention = if compute_capability.is_none()
+        || matches!(compute_capability, Some((major, _)) if major < 8)
+    {
        "paged"
    } else {
        "flashdecoding"
@ -1631,8 +1632,10 @@ enum Gpu {
L40, L40,
L40S, L40S,
A10G, A10G,
A40,
H100, H100,
A100, A100,
H200,
Unknown(String), Unknown(String),
} }
@ -1651,6 +1654,7 @@ impl From<&str> for Gpu {
"nvidia-l40" => Gpu::L40, "nvidia-l40" => Gpu::L40,
"nvidia-l40s" => Gpu::L40S, "nvidia-l40s" => Gpu::L40S,
"nvidia-a10g" => Gpu::A10G, "nvidia-a10g" => Gpu::A10G,
"nvidia-a40" => Gpu::A40,
"nvidia-h100-80gb-hbm3" => Gpu::H100, "nvidia-h100-80gb-hbm3" => Gpu::H100,
"nvidia-h100-nvl" => Gpu::H100, "nvidia-h100-nvl" => Gpu::H100,
"nvidia-h100" => Gpu::H100, "nvidia-h100" => Gpu::H100,
@ -1658,6 +1662,7 @@ impl From<&str> for Gpu {
"nvidia-a100-sxm4-40gb" => Gpu::A100, "nvidia-a100-sxm4-40gb" => Gpu::A100,
"nvidia-a100-80gb-pcie" => Gpu::A100, "nvidia-a100-80gb-pcie" => Gpu::A100,
"nvidia-a100" => Gpu::A100, "nvidia-a100" => Gpu::A100,
"nvidia-h200" => Gpu::H200,
card => Gpu::Unknown(card.to_string()), card => Gpu::Unknown(card.to_string()),
} }
} }
@ -1672,8 +1677,10 @@ impl std::fmt::Display for Gpu {
Gpu::L40 => write!(f, "nvida-l40"), Gpu::L40 => write!(f, "nvida-l40"),
Gpu::L40S => write!(f, "nvida-l40s"), Gpu::L40S => write!(f, "nvida-l40s"),
Gpu::A10G => write!(f, "nvidia-a10g"), Gpu::A10G => write!(f, "nvidia-a10g"),
Gpu::A40 => write!(f, "nvidia-a40"),
Gpu::H100 => write!(f, "nvidia-h100-80fb-hbm3"), Gpu::H100 => write!(f, "nvidia-h100-80fb-hbm3"),
Gpu::A100 => write!(f, "nvida-a100-sxm4-80gb"), Gpu::A100 => write!(f, "nvida-a100-sxm4-80gb"),
Gpu::H200 => write!(f, "nvida-h200"),
Gpu::Unknown(card) => write!(f, "{}", card), Gpu::Unknown(card) => write!(f, "{}", card),
} }
} }
@ -1695,11 +1702,16 @@ impl ComputeType {
Gpu::L40S => Some(363 * 10u64.pow(12)), Gpu::L40S => Some(363 * 10u64.pow(12)),
// https://www.nvidia.com/en-us/data-center/products/a10-gpu/ // https://www.nvidia.com/en-us/data-center/products/a10-gpu/
Gpu::A10G => Some(125 * 10u64.pow(12)), Gpu::A10G => Some(125 * 10u64.pow(12)),
// https://www.nvidia.com/en-us/data-center/a40/
// https://images.nvidia.com/content/Solutions/data-center/a40/nvidia-a40-datasheet.pdf
Gpu::A40 => Some(149 * 10u64.pow(12)),
// https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
Gpu::A100 => Some(312 * 10u64.pow(12)),
// https://www.nvidia.com/en-us/data-center/h100/ // https://www.nvidia.com/en-us/data-center/h100/
// https://www.techpowerup.com/gpu-specs/docs/nvidia-gh100-architecture.pdf // https://www.techpowerup.com/gpu-specs/docs/nvidia-gh100-architecture.pdf
Gpu::H100 => Some(900 * 10u64.pow(12)), Gpu::H100 => Some(900 * 10u64.pow(12)),
// https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf // https://www.nvidia.com/en-us/data-center/h200/
Gpu::A100 => Some(312 * 10u64.pow(12)), Gpu::H200 => Some(989 * 10u64.pow(12)),
Gpu::Unknown(card) => { Gpu::Unknown(card) => {
tracing::warn!("Unkown compute for card {card}"); tracing::warn!("Unkown compute for card {card}");
None None
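For quick reference, the peak-throughput table above now covers the A40 and H200 as well. An illustrative Python mirror of the values visible in this hunk (FLOP/s), not code from the repository:

# Peak dense compute per card as encoded in the launcher (illustrative copy).
PEAK_FLOPS = {
    "nvidia-a10g": 125e12,
    "nvidia-a40": 149e12,   # newly added
    "nvidia-l40s": 363e12,
    "nvidia-a100": 312e12,
    "nvidia-h100": 900e12,
    "nvidia-h200": 989e12,  # newly added
}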
@ -2079,14 +2091,7 @@ fn main() -> Result<(), LauncherError> {
    let cuda_graphs = match (&args.cuda_graphs, &quantize) {
        (Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(),
        #[allow(deprecated)]
-        (
-            None,
-            Some(
-                Quantization::Bitsandbytes
-                    | Quantization::BitsandbytesNf4
-                    | Quantization::BitsandbytesFp4,
-            ),
-        ) => {
+        (None, Some(Quantization::Bitsandbytes)) => {
            tracing::warn!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
            vec![]
        }
@ -2176,11 +2181,12 @@ fn main() -> Result<(), LauncherError> {
        }

        // capture adapter_id, path, revision in format of adapter_id=path@revision
-        let re = Regex::new(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$").unwrap();
-        if let Some(caps) = re.captures(adapter) {
-            let adapter_id = caps.get(1).map_or("", |m| m.as_str());
-            let revision = caps.get(3).map(|m| m.as_str());
+        // path is disabled beforehand.
+        let mut splits = adapter.split("@");
+        let adapter_id = splits.next().ok_or_else(|| {
+            LauncherError::ArgumentValidation("Missing adapter id".to_string())
+        })?;
+        let revision = splits.next();

        download_convert_model(
            adapter_id,
            revision,
@ -2190,12 +2196,6 @@ fn main() -> Result<(), LauncherError> {
            running.clone(),
            false, // avoid merging lora adapters if using multi-lora
        )?;
-        } else {
-            return Err(LauncherError::ArgumentValidation(format!(
-                "Invalid LoRA adapter format: {}",
-                adapter
-            )));
-        }
    }
}
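The regex-based adapter parsing is thus replaced by a plain split on "@" (the "=path" form is rejected earlier in the pipeline). A rough Python restatement of the new behaviour, for illustration only:

def parse_lora_adapter(spec: str):
    # "<adapter_id>[@<revision>]"; the "adapter_id=path" variant is disabled beforehand.
    adapter_id, _, revision = spec.partition("@")
    if not adapter_id:
        raise ValueError("Missing adapter id")
    return adapter_id, revision or None

# parse_lora_adapter("smangrul/starcoder-3b-hugcoder@main")
# -> ("smangrul/starcoder-3b-hugcoder", "main")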

View File

@ -224,6 +224,8 @@ pub enum Config {
Qwen2, Qwen2,
Opt, Opt,
T5, T5,
DeepseekV2,
DeepseekV3,
} }
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]

View File

@ -40,6 +40,8 @@ pub trait Backend {
fn start_health(&self) -> bool { fn start_health(&self) -> bool {
false false
} }
fn name(&self) -> &'static str;
} }
/// Inference struct /// Inference struct

View File

@ -79,7 +79,7 @@ impl TokenizerTrait for tokenizers::Tokenizer {
    }
}

-impl<'a> TokenizerTrait for PyTokenizer<'a> {
+impl TokenizerTrait for PyTokenizer<'_> {
    fn encode_trait(
        &self,
        query: String,
@ -460,7 +460,7 @@ pub struct CompletionRequest {
    pub prompt: Prompt,

    /// The maximum number of tokens that can be generated in the chat completion.
-    #[serde(default, alias = "max_completion_tokens")]
+    #[serde(default)]
    #[schema(default = "1024", example = "32")]
    pub max_tokens: Option<u32>,
@ -840,7 +840,7 @@ pub(crate) struct ChatRequest {
    pub top_logprobs: Option<u32>,

    /// The maximum number of tokens that can be generated in the chat completion.
-    #[serde(default)]
+    #[serde(default, alias = "max_completion_tokens")]
    #[schema(default = "1024", example = "32")]
    pub max_tokens: Option<u32>,
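In other words, the OpenAI-style max_completion_tokens alias now applies to chat completions rather than the legacy completions route. A hedged client-side example (URL and model name are placeholders):

import requests

payload = {
    "model": "tgi",  # placeholder model id
    "messages": [{"role": "user", "content": "Say hello"}],
    # Accepted via the serde alias above; equivalent to "max_tokens".
    "max_completion_tokens": 32,
}
r = requests.post("http://localhost:8080/v1/chat/completions", json=payload)
print(r.json()["choices"][0]["message"]["content"])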

View File

@ -54,6 +54,9 @@ use std::fs::File;
use std::io::BufReader; use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error; use thiserror::Error;
use tokio::select; use tokio::select;
use tokio::signal; use tokio::signal;
@ -1819,9 +1822,9 @@ pub async fn run(
HubTokenizerConfig::default() HubTokenizerConfig::default()
}); });
let tokenizer: Tokenizer = { let tokenizer: Result<Tokenizer, WebServerError> = {
use pyo3::prelude::*; use pyo3::prelude::*;
pyo3::Python::with_gil(|py| -> PyResult<()> { Python::with_gil(|py| -> PyResult<()> {
py_resolve_tokenizer(py, &tokenizer_name, revision.as_deref(), trust_remote_code)?; py_resolve_tokenizer(py, &tokenizer_name, revision.as_deref(), trust_remote_code)?;
Ok(()) Ok(())
}) })
@ -1832,16 +1835,16 @@ pub async fn run(
let out = legacy_tokenizer_handle(config_filename.as_ref()); let out = legacy_tokenizer_handle(config_filename.as_ref());
out.ok_or(err) out.ok_or(err)
}) })
.expect("We cannot load a tokenizer"); .map_err(|_| WebServerError::Tokenizer("Unable to load tokenizer.".to_string()))?;
let filename = "out/tokenizer.json"; let filename = "out/tokenizer.json";
if let Ok(tok) = tokenizers::Tokenizer::from_file(filename) { if let Ok(tok) = tokenizers::Tokenizer::from_file(filename) {
Tokenizer::Rust(tok) Ok(Tokenizer::Rust(tok))
} else { } else {
Tokenizer::Python { Ok(Tokenizer::Python {
tokenizer_name: tokenizer_name.clone(), tokenizer_name: tokenizer_name.clone(),
revision: revision.clone(), revision: revision.clone(),
trust_remote_code, trust_remote_code,
} })
} }
}; };
@ -1895,17 +1898,34 @@ pub async fn run(
            disable_grammar_support,
            max_client_batch_size,
            usage_stats_level,
+            backend.name(),
        );
        Some(usage_stats::UserAgent::new(reduced_args))
    }
    _ => None,
};

-if let Some(ref ua) = user_agent {
+let stop_usage_thread = Arc::new(AtomicBool::new(false));
+let stop_usage_thread_clone = stop_usage_thread.clone();
+if let Some(ua) = user_agent.clone() {
    let start_event =
        usage_stats::UsageStatsEvent::new(ua.clone(), usage_stats::EventType::Start, None);
    tokio::spawn(async move {
+        // send start event
        start_event.send().await;
+        let mut last_report = Instant::now();
+        while !stop_usage_thread_clone.load(Ordering::Relaxed) {
+            if last_report.elapsed() > Duration::from_secs(900) {
+                let report_event = usage_stats::UsageStatsEvent::new(
+                    ua.clone(),
+                    usage_stats::EventType::Ping,
+                    None,
+                );
+                report_event.send().await;
+                last_report = Instant::now();
+            }
+            tokio::time::sleep(Duration::from_secs(1)).await;
+        }
    });
};
let compat_return_full_text = match &model_info.pipeline_tag { let compat_return_full_text = match &model_info.pipeline_tag {
@ -1926,7 +1946,7 @@ pub async fn run(
validation_workers, validation_workers,
api_key, api_key,
config, config,
(tokenizer, tokenizer_config), (tokenizer?, tokenizer_config),
(preprocessor_config, processor_config), (preprocessor_config, processor_config),
hostname, hostname,
port, port,
@ -1943,6 +1963,7 @@ pub async fn run(
.await; .await;
if let Some(ua) = user_agent { if let Some(ua) = user_agent {
stop_usage_thread.store(true, Ordering::Relaxed);
match result { match result {
Ok(_) => { Ok(_) => {
let stop_event = usage_stats::UsageStatsEvent::new( let stop_event = usage_stats::UsageStatsEvent::new(
@ -2419,8 +2440,13 @@ async fn start(
        }
    } else {
        // Run server
-        let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
+        let listener = match tokio::net::TcpListener::bind(&addr).await {
+            Ok(listener) => listener,
+            Err(e) => {
+                tracing::error!("Failed to bind to {addr}: {e}");
+                return Err(WebServerError::Axum(Box::new(e)));
+            }
+        };
        axum::serve(listener, app)
            .with_graceful_shutdown(shutdown_signal())
            .await
@ -2535,4 +2561,6 @@ impl From<InferError> for Event {
pub enum WebServerError {
    #[error("Axum error: {0}")]
    Axum(#[from] axum::BoxError),
+    #[error("Tokenizer error: {0}")]
+    Tokenizer(String),
}
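The router therefore keeps a background telemetry task alive, sending a Ping event every 15 minutes until the stop flag is set on shutdown. The same control flow expressed as a small Python/asyncio sketch (the send_event coroutine is hypothetical, standing in for UsageStatsEvent::send):

import asyncio
import time

async def usage_stats_loop(send_event, stop: asyncio.Event, interval_s: int = 900):
    # Send a start event once, then a ping every `interval_s` seconds,
    # polling the stop flag roughly once per second.
    await send_event("start")
    last_report = time.monotonic()
    while not stop.is_set():
        if time.monotonic() - last_report > interval_s:
            await send_event("ping")
            last_report = time.monotonic()
        await asyncio.sleep(1)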

View File

@ -43,6 +43,7 @@ pub enum EventType {
Start, Start,
Stop, Stop,
Error, Error,
Ping,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
@ -70,7 +71,7 @@ impl UsageStatsEvent {
.post(TELEMETRY_URL) .post(TELEMETRY_URL)
.headers(headers) .headers(headers)
.body(body) .body(body)
.timeout(Duration::from_secs(5)) .timeout(Duration::from_secs(10))
.send() .send()
.await; .await;
} }
@ -96,6 +97,7 @@ pub struct Args {
disable_grammar_support: bool, disable_grammar_support: bool,
max_client_batch_size: usize, max_client_batch_size: usize,
usage_stats_level: UsageStatsLevel, usage_stats_level: UsageStatsLevel,
backend_name: &'static str,
} }
impl Args { impl Args {
@ -119,6 +121,7 @@ impl Args {
disable_grammar_support: bool, disable_grammar_support: bool,
max_client_batch_size: usize, max_client_batch_size: usize,
usage_stats_level: UsageStatsLevel, usage_stats_level: UsageStatsLevel,
backend_name: &'static str,
) -> Self { ) -> Self {
Self { Self {
model_config, model_config,
@ -139,6 +142,7 @@ impl Args {
disable_grammar_support, disable_grammar_support,
max_client_batch_size, max_client_batch_size,
usage_stats_level, usage_stats_level,
backend_name,
} }
} }
} }

View File

@ -1229,12 +1229,11 @@ mod tests {
assert!( assert!(
chunks chunks
== vec![ == vec![
Chunk::Text("test".to_string()).into(), Chunk::Text("test".to_string()),
Chunk::Image(Image { Chunk::Image(Image {
data: pixel_data.clone(), data: pixel_data.clone(),
mimetype: "image/gif".to_string() mimetype: "image/gif".to_string()
}) })
.into()
], ],
"Failed to process images", "Failed to process images",
); );
@ -1289,17 +1288,15 @@ mod tests {
assert!( assert!(
chunks chunks
== vec![ == vec![
Chunk::Text("test".to_string()).into(), Chunk::Text("test".to_string()),
Chunk::Image(Image {
data: pixel_data.clone(),
mimetype: "image/gif".to_string()
}),
Chunk::Image(Image { Chunk::Image(Image {
data: pixel_data.clone(), data: pixel_data.clone(),
mimetype: "image/gif".to_string() mimetype: "image/gif".to_string()
}) })
.into(),
Chunk::Image(Image {
data: pixel_data.clone(),
mimetype: "image/gif".to_string()
})
.into()
], ],
"Failed to process images", "Failed to process images",
); );

View File

@ -1,5 +1,5 @@
[toolchain]
# Released on: June 13, 2024
# https://releases.rs/docs/1.79.0/
-channel = "1.80.1"
+channel = "1.84.0"
components = ["rustfmt", "clippy"]

View File

@ -9,11 +9,14 @@ include Makefile-exllamav2
include Makefile-flashinfer

unit-tests:
+	pip install -U pip uv
+	uv pip install -e ".[dev]"
	pytest -s -vv -m "not private" tests

gen-server:
	# Compile protos
-	pip install grpcio-tools==1.62.2 mypy-protobuf==3.6.0 'types-protobuf' --no-cache-dir
+	pip install -U pip uv
+	uv pip install ".[gen]"
	mkdir text_generation_server/pb || true
	python -m grpc_tools.protoc -I../proto/v3 --python_out=text_generation_server/pb \
		--grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb ../proto/v3/generate.proto
@ -21,24 +24,14 @@ gen-server:
	touch text_generation_server/pb/__init__.py

install-server: gen-server
-	pip install pip --upgrade
-	pip install -r requirements_cuda.txt
-	pip install -e ".[accelerate, compressed-tensors, quantize, peft, outlines]"
+	uv pip install -e ".[accelerate, compressed-tensors, quantize, peft, outlines]"

install: install-cuda
	echo "Installed server"

install-cuda: install-server install-flash-attention-v2-cuda install-flash-attention
-	pip install -e ".[attention,bnb,marlin,moe]"
-	pip install nvidia-nccl-cu12==2.22.3
+	uv pip install -e ".[attention,bnb,marlin,moe]"
+	uv pip install nvidia-nccl-cu12==2.22.3

install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm
-
-run-dev:
-	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
-
-export-requirements:
-	poetry export -o requirements_cuda.txt --without-hashes
-	poetry export -o requirements_rocm.txt --without-hashes
-	poetry export -o requirements_intel.txt --without-hashes

View File

@ -1,5 +1,6 @@
install-flashinfer:
	# We need fsspec as an additional dependency, but
	# `pip install flashinfer` cannot resolve it.
-	pip install fsspec
-	pip install flashinfer==0.2.0.post1 -i https://flashinfer.ai/whl/cu124/torch2.4
+	pip install fsspec sympy==1.13.1 numpy
+	pip install -U setuptools
+	TORCH_CUDA_ARCH_LIST="8.0;8.6;8.9;9.0+PTX" FLASHINFER_ENABLE_AOT=1 pip install git+https://github.com/flashinfer-ai/flashinfer.git@v0.2.0.post1#egg=flashinfer --no-build-isolation

4100
server/poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,96 +1,87 @@
[tool.poetry] [project]
name = "text-generation-server" name = "text-generation-server"
version = "2.0.5-dev0" version = "2.0.5-dev0"
description = "Text Generation Inference Python gRPC Server" description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"] readme = "README.md"
requires-python = ">=3.9"
[tool.poetry.scripts] authors = [
text-generation-server = 'text_generation_server.cli:app' {name = "Olivier Dehaene", email = "olivier@huggingface.co"},
{name = "Nicolas Patry", email = "nicolas@huggingface.co"},
[tool.poetry.dependencies]
python = ">=3.9,<3.13"
protobuf = ">=4.25.3,<6"
grpcio = "^1.51.1"
grpcio-status = "^1.51.1"
grpcio-reflection = "^1.51.1"
grpc-interceptor = "^0.15.4"
typer = "^0.12.5"
accelerate = {version = "^1.1.0", optional = true}
bitsandbytes = { version = "^0.43.0", optional = true }
safetensors = "^0.4.5"
loguru = "^0.7.2"
opentelemetry-api = "^1.27.0"
opentelemetry-exporter-otlp = "^1.27.0"
opentelemetry-instrumentation-grpc = "^0.48b0"
hf-transfer = "^0.1.2"
sentencepiece = "^0.2.0"
tokenizers = "^0.20.3"
huggingface-hub = "^0.23"
transformers = "^4.46.2"
einops = "^0.8.0"
texttable = { version = "^1.6.7", optional = true }
datasets = {version = "^2.21.0", optional = true}
peft = {version = "^0.13.2", optional = true}
torch = {version = "^2.4.1", optional = true}
scipy = "^1.13.1"
pillow = "^11.0.0"
outlines= {version = "^0.1.3", optional = true}
prometheus-client = ">=0.20.0,<0.22"
py-cpuinfo = "^9.0.0"
compressed-tensors = {version = "^0.7.1", optional = true}
# Remove later, temporary workaround for outlines.
numpy = "^1.26.4"
attention-kernels = [
{ url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
{ url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
{ url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
{ url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
] ]
marlin-kernels = [ dependencies = [
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, "einops>=0.8.0",
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true }, "grpc-interceptor>=0.15.4",
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true }, "grpcio>=1.67.0",
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, "grpcio-reflection>=1.67.0",
"grpcio-status>=1.67.0",
"hf-transfer>=0.1.8",
"loguru>=0.7.3",
"numpy>=1.26,<3",
"opentelemetry-api>=1.27.0",
"opentelemetry-exporter-otlp>=1.27.0",
"opentelemetry-instrumentation-grpc>=0.50b0",
"pillow>=11.1.0",
"prometheus-client>=0.21.0",
"protobuf>=5.28.3",
"py-cpuinfo>=9.0.0",
"rich>=13.8.1",
"safetensors>=0.4.5",
"scipy>=1.13.1",
"sentencepiece>=0.2.0",
"tokenizers>=0.20.3",
"typer>=0.15.1",
] ]
moe-kernels = [
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.7.0/moe_kernels-0.7.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.7.0/moe_kernels-0.7.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.7.0/moe_kernels-0.7.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.7.0/moe_kernels-0.7.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
]
rich = "^13.8.1"
[tool.poetry.extras] [project.scripts]
torch = ["torch"] text-generation-server = "text_generation_server.cli:app"
accelerate = ["accelerate"]
attention = ["attention-kernels"] [project.optional-dependencies]
bnb = ["bitsandbytes"] accelerate = [
compressed-tensors = ["compressed-tensors"] "accelerate>=1.2.1,<2",
marlin = ["marlin-kernels"] ]
bnb = [
"bitsandbytes>=0.45.0",
]
compressed-tensors = [
"compressed-tensors>=0.9.0",
]
peft = [
"peft>=0.14.0",
]
outlines = [
"outlines>=0.1.13",
]
dev = [
"grpcio-tools>=1.51.1,<2.0",
"pytest>=7.3.0,<8"
]
quantize = [
"texttable>=1.6.7,<2",
"datasets>=2.21,<3",
]
moe = [ "moe-kernels" ] moe = [ "moe-kernels" ]
peft = ["peft"] attention = [ "attention-kernels" ]
quantize = ["texttable", "datasets", "accelerate"] marlin = [ "marlin-kernels" ]
outlines = ["outlines"] gen = [
"grpcio-tools>=1.69.0",
"mypy-protobuf>=3.6.0",
]
[tool.poetry.group.dev.dependencies] [tool.uv.sources]
grpcio-tools = "^1.51.1" attention-kernels.url = "https://github.com/danieldk/attention-kernels/releases/download/v0.2.0.post2/attention_kernels-0.2.0.post2+cu123torch2.5-cp39-abi3-linux_x86_64.whl"
pytest = "^7.3.0" marlin-kernels = [
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp39-cp39-linux_x86_64.whl", marker = "python_version == '3.9'" },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp310-cp310-linux_x86_64.whl", marker = "python_version == '3.10'" },
[[tool.poetry.source]] { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp311-cp311-linux_x86_64.whl", marker = "python_version == '3.11'" },
name = "pytorch-gpu-src" { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp312-cp312-linux_x86_64.whl", marker = "python_version == '3.12'" },
url = "https://download.pytorch.org/whl/cu121" ]
priority = "explicit" moe-kernels.url = "https://github.com/danieldk/moe-kernels/releases/download/v0.8.2/moe_kernels-0.8.2+cu123torch2.5-cp39-abi3-linux_x86_64.whl"
[tool.pytest.ini_options] [tool.pytest.ini_options]
markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"] markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]
[build-system]
requires = [
"poetry-core>=1.0.0",
]
build-backend = "poetry.core.masonry.api"
[tool.isort] [tool.isort]
profile = "black" profile = "black"
[tool.setuptools.packages.find]
include = ["text_generation_server*"]

View File

@ -94,6 +94,8 @@ def test_get_mlp_weights_with_gate_up_proj():
# assert the result # assert the result
expected = { expected = {
(3, "c_fc"): ("model.layers.3.mlp.c_fc", mock_layer.mlp.c_fc),
(3, "c_proj"): ("model.layers.3.mlp.c_proj", mock_layer.mlp.c_proj),
(3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj), (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
(3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj), (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
(3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj), (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
@ -188,6 +190,8 @@ def test_get_mlp_weights_llama_compatibility():
result = get_mlp_weights(3, mock_layer) result = get_mlp_weights(3, mock_layer)
expected = { expected = {
(3, "c_fc"): ("model.layers.3.mlp.c_fc", mock_layer.mlp.c_fc),
(3, "c_proj"): ("model.layers.3.mlp.c_proj", mock_layer.mlp.c_proj),
(3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj), (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
(3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj), (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
(3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj), (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
@ -240,6 +244,8 @@ def test_get_mlp_weights_gemma_compatibility():
result = get_mlp_weights(3, mock_layer) result = get_mlp_weights(3, mock_layer)
expected = { expected = {
(3, "c_fc"): ("model.layers.3.mlp.c_fc", mock_layer.mlp.c_fc),
(3, "c_proj"): ("model.layers.3.mlp.c_proj", mock_layer.mlp.c_proj),
(3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_proj), (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_proj),
(3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.up_proj), (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.up_proj),
(3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj), (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),

View File

@ -6,9 +6,11 @@ from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple, Type, Union from typing import Dict, List, Optional, Set, Tuple, Type, Union
from loguru import logger
import torch import torch
from peft import LoraConfig as _LoraConfig from peft import LoraConfig as _LoraConfig
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from text_generation_server.utils.log import log_master
from text_generation_server.adapters.config import AdapterConfig, ModuleMap from text_generation_server.adapters.config import AdapterConfig, ModuleMap
@ -203,8 +205,17 @@ class LoraWeights(AdapterWeights):
lora_a_list = [None] * nlayers lora_a_list = [None] * nlayers
lora_b_list = [None] * nlayers lora_b_list = [None] * nlayers
# import ipdb; ipdb.set_trace()
for layer_id in range(nlayers): for layer_id in range(nlayers):
key = (layer_id, layer_type) key = (layer_id, layer_type)
if key not in target_to_layer:
# There is no layer of this type in the model
log_master(
logger.warning,
f"Key specified in lora weights but not found in base model: {key}",
)
return None
weight_name, layer = target_to_layer[key] weight_name, layer = target_to_layer[key]
base_weight = layer.base_layer.linear.weight base_weight = layer.base_layer.linear.weight
base_device = base_weight.device base_device = base_weight.device

View File

@ -9,6 +9,8 @@ from enum import Enum
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from text_generation_server.utils.adapter import parse_lora_adapters from text_generation_server.utils.adapter import parse_lora_adapters
# Dummy change should cache hit.
app = typer.Typer() app = typer.Typer()

View File

@ -111,6 +111,8 @@ def paged_attention(
out = torch.empty_like(query) out = torch.empty_like(query)
kv_cache_dtype = "fp8" if kv_cache.dtype == torch.float8_e4m3fn else "auto"
use_v1 = max_s <= 8192 and ( use_v1 = max_s <= 8192 and (
max_num_partitions == 1 or num_seqs * num_heads > 512 max_num_partitions == 1 or num_seqs * num_heads > 512
) )
@ -120,15 +122,16 @@ def paged_attention(
query, query,
kv_cache.key, kv_cache.key,
kv_cache.value, kv_cache.value,
kv_head_mapping, kv_cache.key.shape[1],
softmax_scale, softmax_scale,
block_tables, block_tables,
input_lengths, input_lengths,
block_size, block_size,
max_s, max_s,
None, None,
"auto", kv_cache_dtype,
1.0, kv_scales.key_scale_cpu,
kv_scales.value_scale_cpu,
) )
else: else:
# Run PagedAttention V2. # Run PagedAttention V2.
@ -153,15 +156,16 @@ def paged_attention(
query, query,
kv_cache.key, kv_cache.key,
kv_cache.value, kv_cache.value,
kv_head_mapping, kv_cache.key.shape[1],
softmax_scale, softmax_scale,
block_tables, block_tables,
input_lengths, input_lengths,
block_size, block_size,
max_s, max_s,
None, None,
"auto", kv_cache_dtype,
1.0, kv_scales.key_scale_cpu,
kv_scales.value_scale_cpu,
) )
return out return out
@ -235,7 +239,6 @@ def attention(
paged_kv_cache=(kv_cache.key, kv_cache.value), paged_kv_cache=(kv_cache.key, kv_cache.value),
logits_soft_cap=softcap, logits_soft_cap=softcap,
sm_scale=softmax_scale, sm_scale=softmax_scale,
window_left=window_size_left,
k_scale=kv_scales.key_scale_cpu if can_scale else 1.0, k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
v_scale=kv_scales.value_scale_cpu if can_scale else 1.0, v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
) )
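The kernels are now told whether the KV cache is FP8 and receive the key/value scales instead of the hard-coded "auto"/1.0 pair. The dtype-string selection shown above boils down to a one-liner; a hedged Python restatement:

import torch

def kv_cache_dtype_for(cache_dtype: torch.dtype) -> str:
    # "fp8" tells the paged-attention kernel to dequantize keys/values
    # with the supplied k/v scales; anything else falls back to "auto".
    return "fp8" if cache_dtype == torch.float8_e4m3fn else "auto"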

View File

@ -84,7 +84,7 @@ def use_prefill_with_paged_kv_state(
token = prefill_with_paged_kv_state.set(state) token = prefill_with_paged_kv_state.set(state)
try: try:
state.begin_forward( state.plan(
qo_indptr=cu_seqlens, qo_indptr=cu_seqlens,
paged_kv_indptr=indptr, paged_kv_indptr=indptr,
paged_kv_indices=block_tables, paged_kv_indices=block_tables,
@ -99,7 +99,6 @@ def use_prefill_with_paged_kv_state(
) )
yield yield
finally: finally:
state.end_forward()
if token is not None: if token is not None:
prefill_with_paged_kv_state.reset(token) prefill_with_paged_kv_state.reset(token)
@ -200,7 +199,7 @@ def use_decode_state(
token = decode_state.set(state) token = decode_state.set(state)
try: try:
state.begin_forward( state.plan(
indptr=indptr, indptr=indptr,
indices=block_tables, indices=block_tables,
last_page_len=last_page_len, last_page_len=last_page_len,
@ -214,6 +213,5 @@ def use_decode_state(
) )
yield yield
finally: finally:
state.end_forward()
if token is not None: if token is not None:
decode_state.reset(token) decode_state.reset(token)

View File

@ -1,9 +1,12 @@
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex
import torch import torch
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen from text_generation_server.layers.attention import Seqlen
from typing import Optional from typing import Optional
from text_generation_server.models.globals import (
ATTENTION,
BLOCK_SIZE,
)
SUPPORTS_WINDOWING = False SUPPORTS_WINDOWING = False
@ -28,6 +31,22 @@ def attention(
out = torch.empty_like(query) out = torch.empty_like(query)
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
if ATTENTION == "flashdecoding-ipex":
ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
out,
query.contiguous() if query.device.type == "xpu" else query,
kv_cache.key,
kv_cache.value,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
seqlen.max_q,
seqlen.max_k,
softmax_scale,
causal,
block_tables,
None,
)
else:
ipex.llm.functional.varlen_attention( ipex.llm.functional.varlen_attention(
query.contiguous() if query.device.type == "xpu" else query, query.contiguous() if query.device.type == "xpu" else query,
key.contiguous() if key.device.type == "xpu" else key, key.contiguous() if key.device.type == "xpu" else key,
@ -64,6 +83,23 @@ def paged_attention(
raise NotImplementedError("softcap is not available in IPEX") raise NotImplementedError("softcap is not available in IPEX")
out = torch.empty_like(query) out = torch.empty_like(query)
if ATTENTION == "flashdecoding-ipex":
ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
out,
query.contiguous() if query.device.type == "xpu" else query,
kv_cache.key,
kv_cache.value,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
seqlen.max_q,
seqlen.max_k,
softmax_scale,
True,
block_tables,
None,
)
else:
input_lengths = seqlen.input_lengths + seqlen.cache_lengths input_lengths = seqlen.input_lengths + seqlen.cache_lengths
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
out, out,

View File

@ -52,12 +52,17 @@ class KVCache:
device: torch.device, device: torch.device,
): ):
"""Construct the key-value cache for a layer.""" """Construct the key-value cache for a layer."""
if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
if dtype in {torch.float8_e5m2, torch.float8_e4m3fn} and ( if not (
ATTENTION != "flashinfer" or SYSTEM != "cuda" (ATTENTION == "flashinfer" and SYSTEM == "cuda")
or (ATTENTION == "paged" and SYSTEM in ("cuda", "rocm"))
): ):
raise ValueError( raise ValueError(
"FP8 KV cache is currently only supported for flashinfer on CUDA" "FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on CUDA and ROCm. "
)
if SYSTEM == "rocm" and dtype == torch.float8_e5m2:
raise ValueError(
"float8_e5m2 FP8 KV cache is not supported on AMD ROCm"
) )
element_size = torch.tensor([], dtype=dtype).element_size() element_size = torch.tensor([], dtype=dtype).element_size()
@ -66,7 +71,9 @@ class KVCache:
else: else:
x = BLOCK_SIZE // element_size x = BLOCK_SIZE // element_size
if ATTENTION in {"flashdecoding", "flashinfer"}: if ATTENTION in {"flashdecoding", "flashinfer"} or (
ATTENTION == "flashdecoding-ipex" and device.type == "xpu"
):
self.kv_cache = ( self.kv_cache = (
torch.empty( torch.empty(
(num_blocks, BLOCK_SIZE, num_heads, head_size), (num_blocks, BLOCK_SIZE, num_heads, head_size),
@ -80,6 +87,7 @@ class KVCache:
), ),
) )
elif SYSTEM == "ipex" and device == torch.device("cpu"): elif SYSTEM == "ipex" and device == torch.device("cpu"):
# ipex cpu flashdecoding kernel and paged attention kernel share same layout
self.kv_cache = ( self.kv_cache = (
torch.empty( torch.empty(
(num_blocks, num_heads, BLOCK_SIZE, head_size), (num_blocks, num_heads, BLOCK_SIZE, head_size),
@ -110,21 +118,17 @@ class KVCache:
"""Check if the cache can be scaled by the given scales.""" """Check if the cache can be scaled by the given scales."""
if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0: if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0:
return False return False
elif ( elif self.dtype == torch.float8_e4m3fn and (
self.dtype == torch.float8_e4m3fn (ATTENTION == "flashinfer" and SYSTEM == "cuda")
and ATTENTION == "flashinfer" or (ATTENTION == "paged" and SYSTEM == "rocm")
and SYSTEM == "cuda"
): ):
log_once( log_once(logger.info, "Using FP8 KV cache scales")
logger.info,
"Using FP8 KV cache scales",
)
return True return True
else: else:
# We have scales, but not the correct FP8 cache type, so warn once. # We have scales, but not the correct FP8 cache type, so warn once.
log_once( log_once(
logger.info, logger.info,
"Ignoring FP8 KV cache scales, only float8_e4m3fn KV cache on flashinfer is supported", "Ignoring FP8 KV cache scales, supported only for float8_e4m3fn KV cache with flashinfer on CUDA and paged attention on ROCm",
) )
return False return False
@ -158,7 +162,7 @@ class KVCache:
key_cache = self.kv_cache[0] key_cache = self.kv_cache[0]
value_cache = self.kv_cache[1] value_cache = self.kv_cache[1]
if self.can_scale(kv_scales): if self.can_scale(kv_scales) and SYSTEM == "cuda":
if kv_scales.key_scale_cpu != 1.0: if kv_scales.key_scale_cpu != 1.0:
key = fp8_quantize( key = fp8_quantize(
key.float(), key.float(),
@ -187,8 +191,22 @@ class KVCache:
shape = key_cache.shape shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value value_cache.view(-1, shape[-2], shape[-1])[slots] = value
elif ATTENTION == "flashdecoding-ipex" and key.device.type == "xpu":
import intel_extension_for_pytorch as ipex
ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
key, value, key_cache, value_cache, slots
)
else: else:
paged_reshape_and_cache(key, value, key_cache, value_cache, slots) paged_reshape_and_cache(
key,
value,
key_cache,
value_cache,
slots,
kv_scales.key_scale_cpu,
kv_scales.value_scale_cpu,
)
def paged_reshape_and_cache( def paged_reshape_and_cache(
@ -197,7 +215,10 @@ def paged_reshape_and_cache(
key_cache: torch.Tensor, key_cache: torch.Tensor,
value_cache: torch.Tensor, value_cache: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
k_scale: float = 1.0,
v_scale: float = 1.0,
): ):
if SYSTEM == "cuda": if SYSTEM == "cuda":
try: try:
import attention_kernels import attention_kernels
@ -205,8 +226,13 @@ def paged_reshape_and_cache(
raise ImportError( raise ImportError(
f"Could not import attention_kernels. Make sure your installation is correct. Complete error: {e}" f"Could not import attention_kernels. Make sure your installation is correct. Complete error: {e}"
) )
kv_cache_dtype = "auto"
if key_cache.dtype == torch.float8_e4m3fn:
kv_cache_dtype = "fp8"
attention_kernels.reshape_and_cache( attention_kernels.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0 key, value, key_cache, value_cache, slots, kv_cache_dtype, k_scale, v_scale
) )
elif SYSTEM == "rocm": elif SYSTEM == "rocm":
try: try:
@ -215,8 +241,15 @@ def paged_reshape_and_cache(
raise ImportError( raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}" f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
) )
kv_cache_dtype = "auto"
if key_cache.dtype == torch.float8_e4m3fn:
key_cache = key_cache.view(torch.uint8)
value_cache = value_cache.view(torch.uint8)
kv_cache_dtype = "fp8"
ops.reshape_and_cache( ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0, 1.0 key, value, key_cache, value_cache, slots, kv_cache_dtype, k_scale, v_scale
) )
elif SYSTEM == "ipex": elif SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex
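Taken together, FP8 KV caches are now accepted for flashinfer on CUDA and for paged attention on CUDA and ROCm, with float8_e5m2 still rejected on ROCm. A compact Python restatement of that support check (illustrative only, not the class itself):

import torch

def fp8_kv_cache_supported(dtype: torch.dtype, attention: str, system: str) -> bool:
    # Mirrors the validation added to KVCache.__init__ above.
    if dtype not in (torch.float8_e5m2, torch.float8_e4m3fn):
        return True  # not an FP8 cache, nothing to check
    if system == "rocm" and dtype == torch.float8_e5m2:
        return False
    return (attention == "flashinfer" and system == "cuda") or (
        attention == "paged" and system in ("cuda", "rocm")
    )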

View File

@ -133,6 +133,15 @@ def paged_attention(
out = torch.empty_like(query) out = torch.empty_like(query)
if kv_cache.dtype == torch.float8_e4m3fn:
key = kv_cache.key.view(torch.uint8)
value = kv_cache.value.view(torch.uint8)
kv_cache_dtype = "fp8"
else:
key = kv_cache.key
value = kv_cache.value
kv_cache_dtype = "auto"
# NOTE(woosuk): We use a simple heuristic to decide whether to use # NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use # PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of # V1 to avoid the overhead of reduction. Also, if the number of
@ -147,8 +156,8 @@ def paged_attention(
ops.paged_attention_v1( ops.paged_attention_v1(
out, out,
query, query,
kv_cache.key, key,
kv_cache.value, value,
num_kv_heads, num_kv_heads,
softmax_scale, softmax_scale,
block_tables, block_tables,
@ -156,24 +165,24 @@ def paged_attention(
block_size, block_size,
max_s, max_s,
None, None,
"auto", kv_cache_dtype,
1.0, kv_scales.key_scale_cpu,
1.0, kv_scales.value_scale_cpu,
) )
else: else:
# Run PagedAttention V2. # Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0 assert _PARTITION_SIZE % block_size == 0
tmp_output = torch.empty( tmp_output = torch.zeros(
size=(num_seqs, num_heads, max_num_partitions, head_size), size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=out.dtype, dtype=out.dtype,
device=out.device, device=out.device,
) )
exp_sums = torch.empty( exp_sums = torch.zeros(
size=(num_seqs, num_heads, max_num_partitions), size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32, dtype=torch.float32,
device=out.device, device=out.device,
) )
max_logits = torch.empty_like(exp_sums) max_logits = torch.zeros_like(exp_sums)
if not use_custom: if not use_custom:
ops.paged_attention_v2( ops.paged_attention_v2(
@ -182,8 +191,8 @@ def paged_attention(
max_logits, max_logits,
tmp_output, tmp_output,
query, query,
kv_cache.key, key,
kv_cache.value, value,
num_kv_heads, num_kv_heads,
softmax_scale, softmax_scale,
block_tables, block_tables,
@ -191,9 +200,9 @@ def paged_attention(
block_size, block_size,
max_s, max_s,
None, None,
"auto", kv_cache_dtype,
1.0, kv_scales.key_scale_cpu,
1.0, kv_scales.value_scale_cpu,
) )
else: else:
ops.paged_attention_rocm( ops.paged_attention_rocm(
@ -202,8 +211,8 @@ def paged_attention(
max_logits, max_logits,
tmp_output, tmp_output,
query, query,
kv_cache.key, key,
kv_cache.value, value,
num_kv_heads, num_kv_heads,
softmax_scale, softmax_scale,
block_tables, block_tables,
@ -211,9 +220,9 @@ def paged_attention(
block_size, block_size,
max_s, max_s,
None, None,
"auto", kv_cache_dtype,
1.0, kv_scales.key_scale_cpu,
1.0, kv_scales.value_scale_cpu,
None, None,
_PARTITION_SIZE, _PARTITION_SIZE,
) )

View File

@ -3,8 +3,14 @@ from typing import List, Optional, Union
import torch import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationType from compressed_tensors.quantization import QuantizationArgs, QuantizationType
from text_generation_server.layers.fp8 import Fp8Weight, _load_scalar_or_matrix_scale from text_generation_server.layers.fp8 import (
Fp8Weight,
_load_scalar_or_matrix_scale,
requantize_with_max_scale,
normalize_e4m3fn_to_native_float8,
)
from text_generation_server.utils.weights import Weights, WeightsLoader from text_generation_server.utils.weights import Weights, WeightsLoader
from text_generation_server.utils.import_utils import SYSTEM
class W8ANFpLoader(WeightsLoader): class W8ANFpLoader(WeightsLoader):
@ -47,11 +53,10 @@ class W8ANFpLoader(WeightsLoader):
weight_scale = None weight_scale = None
if self.load_weight_scale: if self.load_weight_scale:
weight_scale = ( weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
.reshape(-1) if SYSTEM == "cuda":
.expand(w.shape[0]) weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
)
input_scale = None input_scale = None
if self.load_input_scale: if self.load_input_scale:
@ -87,6 +92,7 @@ class W8ANFpLoader(WeightsLoader):
block_sizes=block_sizes, block_sizes=block_sizes,
to_dtype=False, to_dtype=False,
) )
if SYSTEM == "cuda":
weight_scale = weight_scale.reshape(-1).expand(w.shape[0]) weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
input_scale = None input_scale = None
@ -141,6 +147,17 @@ class W8ANFpLoader(WeightsLoader):
else None else None
) )
if self.load_weight_scale and SYSTEM == "rocm":
w, weight_scale, input_scale = normalize_e4m3fn_to_native_float8(
w, weight_scale, input_scale
)
if weight_scale.numel() == len(prefixes):
logical_widths = [x[0] for x in shapes]
w, weight_scale = requantize_with_max_scale(
w, weight_scale.to(weights.device), logical_widths, weights.dtype
)
return Fp8Weight( return Fp8Weight(
weight=w, weight=w,
weight_scale=weight_scale, weight_scale=weight_scale,
@ -153,11 +170,10 @@ class W8ANFpLoader(WeightsLoader):
w = weights.get_sharded(f"{prefix}.weight", dim=1) w = weights.get_sharded(f"{prefix}.weight", dim=1)
weight_scale = None weight_scale = None
if self.load_weight_scale: if self.load_weight_scale:
weight_scale = ( weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
.reshape(-1) if SYSTEM == "cuda":
.expand(w.shape[0]) weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
)
input_scale = None input_scale = None
if self.load_input_scale: if self.load_input_scale:

View File

@ -19,6 +19,15 @@ try:
except ImportError: except ImportError:
marlin_kernels = None marlin_kernels = None
try:
from moe_kernels.fp8_utils import w8a8_block_fp8_matmul, per_token_group_quant_fp8
except ImportError:
w8a8_block_fp8_matmul = None
per_token_group_quant_fp8 = None
quant_dtype: torch.dtype = (
torch.float8_e4m3fnuz if SYSTEM == "rocm" else torch.float8_e4m3fn
)
if SYSTEM == "cuda" and marlin_kernels is not None: if SYSTEM == "cuda" and marlin_kernels is not None:
major, minor = torch.cuda.get_device_capability() major, minor = torch.cuda.get_device_capability()
@ -35,7 +44,6 @@ def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
""" """
if SYSTEM == "cuda": if SYSTEM == "cuda":
major, _ = torch.cuda.get_device_capability() major, _ = torch.cuda.get_device_capability()
# Marlin is W8A16, use it when: # Marlin is W8A16, use it when:
# #
@ -49,18 +57,28 @@ def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
# gives better decoding throughput on L4 and L40. # gives better decoding throughput on L4 and L40.
from text_generation_server.layers.marlin import GPTQMarlinFP8Linear from text_generation_server.layers.marlin import GPTQMarlinFP8Linear
if major == 8 and minor == 9:
log_once(
logger.info,
"GPU supports FP8, but using Marlin FP8 kernel for better performance",
)
else:
log_once(
logger.info, "GPU does not support FP8, using Marlin FP8 kernel"
)
return GPTQMarlinFP8Linear return GPTQMarlinFP8Linear
# On other systems let Torch decide if the hardware supports FP8. # On other systems let Torch decide if the hardware supports FP8.
return Fp8Linear return Fp8Linear
def normalize_e4m3fn_to_e4m3fnuz( def normalize_e4m3fn_to_native_float8(
weight: torch.Tensor, weight: torch.Tensor,
weight_scale: torch.Tensor, weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None, input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
assert weight.dtype == torch.float8_e4m3fn if weight.dtype == torch.float8_e4m3fn and SYSTEM == "rocm":
# The bits pattern 10000000(-128) represents zero in e4m3fn # The bits pattern 10000000(-128) represents zero in e4m3fn
# but NaN in e4m3fnuz. So here we set it to 0. # but NaN in e4m3fnuz. So here we set it to 0.
# https://onnx.ai/onnx/technical/float8.html # https://onnx.ai/onnx/technical/float8.html
@ -79,6 +97,39 @@ def normalize_e4m3fn_to_e4m3fnuz(
return weight, weight_scale, input_scale return weight, weight_scale, input_scale
def per_tensor_dequantize(
tensor: torch.Tensor,
inv_scale: Union[float, torch.Tensor],
dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
fake_qweight = tensor.to(dtype)
dq_weight = fake_qweight * inv_scale
return dq_weight
def requantize_with_max_scale(
weight: torch.Tensor,
weight_scale: torch.Tensor,
logical_widths: int,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Max scale to be used for requanitzation.
max_w_scale = weight_scale.max().float()
start = 0
for idx, logical_width in enumerate(logical_widths):
end = start + logical_width
weight_dq = per_tensor_dequantize(
weight[start:end, :], weight_scale[idx], dtype
)
weight[start:end, :], max_w_scale_normalized = fp8_quantize(
weight_dq, max_w_scale
)
start = end
return weight, max_w_scale_normalized
def fp8_quantize( def fp8_quantize(
weight: torch.Tensor, weight: torch.Tensor,
scale: Optional[torch.Tensor] = None, scale: Optional[torch.Tensor] = None,
@ -96,7 +147,7 @@ def fp8_quantize(
shape = weight.shape shape = weight.shape
qweight, scale = marlin_kernels.scaled_fp8_quant( qweight, scale = marlin_kernels.scaled_fp8_quant(
weight.reshape(-1, shape[-1]), weight.reshape(-1, shape[-1]),
dtype=qdtype, dtype=quant_dtype,
scale=scale, scale=scale,
scale_ub=scale_upper_bound, scale_ub=scale_upper_bound,
# TODO: don't do this when we have to use the Torch kernel. # TODO: don't do this when we have to use the Torch kernel.
@ -116,6 +167,8 @@ def fp8_quantize(
qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max) qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
scale = scale.float().reciprocal() scale = scale.float().reciprocal()
else: else:
if SYSTEM == "rocm":
scale = scale / 2.0
# Use reciprocal to avoid more expensive division. # Use reciprocal to avoid more expensive division.
qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max) qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)
@ -124,7 +177,7 @@ def fp8_quantize(
qweight = qweight.to(qdtype) qweight = qweight.to(qdtype)
if SYSTEM == "rocm": if SYSTEM == "rocm":
qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale) qweight, scale, _ = normalize_e4m3fn_to_native_float8(qweight, scale)
return qweight, scale return qweight, scale
@ -132,26 +185,42 @@ def fp8_quantize(
class HybridFP8UnquantLoader(WeightsLoader): class HybridFP8UnquantLoader(WeightsLoader):
"""Weight loader that loads FP8 and unquantized Torch tensors.""" """Weight loader that loads FP8 and unquantized Torch tensors."""
def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool): def __init__(
self,
activation_scale_ub: Optional[float],
to_fp8: bool,
weight_block_size: Optional[List[int]] = None,
):
self.activation_scale_ub = activation_scale_ub self.activation_scale_ub = activation_scale_ub
self.to_fp8 = to_fp8 self.to_fp8 = to_fp8
self.weight_block_size = weight_block_size
def get_weights(self, weights: "Weights", prefix: str): def get_weights(self, weights: "Weights", prefix: str):
w = weights.get_tensor(f"{prefix}.weight") w = weights.get_tensor(f"{prefix}.weight")
if w.dtype == torch.float8_e4m3fn: if w.dtype == torch.float8_e4m3fn:
# FP8 branch if self.weight_block_size is not None:
scale = ( scale = weights.get_tensor(f"{prefix}.weight_scale_inv")
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) return Fp8Weight(
.reshape(-1) weight=w,
.expand(w.shape[0]) weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
weight_block_size=self.weight_block_size,
) )
# FP8 branch
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
if SYSTEM == "cuda":
scale.reshape(-1).expand(w.shape[0])
input_scale = None input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"): if weights.has_tensor(f"{prefix}.input_scale"):
input_scale = weights.get_tensor( input_scale = (
f"{prefix}.input_scale", to_dtype=False weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
).reshape(-1) .reshape(-1)
.max()
)
return Fp8Weight( return Fp8Weight(
weight=w, weight=w,
@ -178,6 +247,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
if w.dtype == torch.float8_e4m3fn: if w.dtype == torch.float8_e4m3fn:
# FP8 branch # FP8 branch
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
if scale.numel() > 1: if scale.numel() > 1:
scale = weights.get_packed_sharded( scale = weights.get_packed_sharded(
f"{prefix}.weight_scale", f"{prefix}.weight_scale",
@ -185,6 +255,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
block_sizes=block_sizes, block_sizes=block_sizes,
to_dtype=False, to_dtype=False,
) )
if SYSTEM == "cuda":
scale = scale.reshape(-1).expand(w.shape[0]) scale = scale.reshape(-1).expand(w.shape[0])
input_scale = None input_scale = None
@ -225,6 +296,21 @@ class HybridFP8UnquantLoader(WeightsLoader):
# FP8 branch # FP8 branch
if w.dtype == torch.float8_e4m3fn: if w.dtype == torch.float8_e4m3fn:
if self.weight_block_size is not None:
scale = [
weights.get_sharded(f"{p}.weight_scale_inv", dim=0, to_device=False)
for p in prefixes
]
scale = torch.cat(scale, dim=dim)
scale = scale.to(weights.device)
return Fp8Weight(
weight=w,
weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
weight_block_size=self.weight_block_size,
)
scale = [ scale = [
_load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape) _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
for p, shape in zip(prefixes, shapes) for p, shape in zip(prefixes, shapes)
@ -243,6 +329,17 @@ class HybridFP8UnquantLoader(WeightsLoader):
else None else None
) )
if SYSTEM == "rocm":
w, scale, input_scale = normalize_e4m3fn_to_native_float8(
w, scale, input_scale
)
if scale.numel() == len(prefixes):
logical_widths = [x[0] for x in shapes]
w, scale = requantize_with_max_scale(
w, scale.to(weights.device), logical_widths, weights.dtype
)
return Fp8Weight( return Fp8Weight(
weight=w, weight=w,
weight_scale=scale, weight_scale=scale,
@ -259,16 +356,30 @@ class HybridFP8UnquantLoader(WeightsLoader):
w = weights.get_sharded(f"{prefix}.weight", dim=1) w = weights.get_sharded(f"{prefix}.weight", dim=1)
# FP8 branch # FP8 branch
if w.dtype == torch.float8_e4m3fn: if w.dtype == torch.float8_e4m3fn:
scale = ( if self.weight_block_size is not None:
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) # XXX: Yes the weights is named scale_inv, but corresponds to scale it seems.
.reshape(-1) scale = weights.get_sharded(f"{prefix}.weight_scale_inv", dim=1)
.expand(w.shape[0])
return Fp8Weight(
weight=w,
weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
weight_block_size=self.weight_block_size,
) )
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
if SYSTEM == "cuda":
scale = scale.reshape(-1).expand(w.shape[0])
input_scale = None input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"): if weights.has_tensor(f"{prefix}.input_scale"):
input_scale = weights.get_tensor( input_scale = (
f"{prefix}.input_scale", to_dtype=False weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
).reshape(-1) .reshape(-1)
.max()
)
return Fp8Weight( return Fp8Weight(
weight=w, weight=w,
@ -291,6 +402,7 @@ class Fp8Weight(Weight):
input_scale: Optional[torch.Tensor] = None input_scale: Optional[torch.Tensor] = None
activation_scale_ub: Optional[float] = None activation_scale_ub: Optional[float] = None
force_w8a16: bool = False force_w8a16: bool = False
weight_block_size: Optional[List[int]] = None
def get_linear(self, bias: torch.Tensor): def get_linear(self, bias: torch.Tensor):
if self.weight_scale is None: if self.weight_scale is None:
@ -307,6 +419,7 @@ class Fp8Weight(Weight):
bias=bias, bias=bias,
input_scale=self.input_scale, input_scale=self.input_scale,
scale_upper_bound=self.activation_scale_ub, scale_upper_bound=self.activation_scale_ub,
weight_block_size=self.weight_block_size,
) )
@ -321,19 +434,21 @@ class Fp8Linear(torch.nn.Module):
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
input_scale: Optional[torch.Tensor] = None, input_scale: Optional[torch.Tensor] = None,
scale_upper_bound: Optional[float] = None, scale_upper_bound: Optional[float] = None,
weight_block_size: Optional[List[int]] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
if CUTLASS_FP8_AVAILABLE: if CUTLASS_FP8_AVAILABLE:
log_once(logger.info, "Using cutlass w8a8 kernels") log_once(logger.info, "Using cutlass w8a8 kernels")
if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn: if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz( qweight, scale, input_scale = normalize_e4m3fn_to_native_float8(
weight=qweight, weight_scale=scale weight=qweight, weight_scale=scale, input_scale=input_scale
) )
self.dtype = dtype self.dtype = dtype
self.qweight = qweight self.qweight = qweight
self.scale = scale.float() self.scale = scale.float()
self.input_scale = input_scale.float() if input_scale is not None else None self.input_scale = input_scale.float() if input_scale is not None else None
self.weight_block_size = weight_block_size
if CUTLASS_FP8_AVAILABLE and scale_upper_bound is not None: if CUTLASS_FP8_AVAILABLE and scale_upper_bound is not None:
self.scale_upper_bound = torch.tensor( self.scale_upper_bound = torch.tensor(
@ -367,6 +482,7 @@ class Fp8Linear(torch.nn.Module):
) -> "Fp8Linear": ) -> "Fp8Linear":
input_scale = kwargs.get("input_scale", None) input_scale = kwargs.get("input_scale", None)
scale_upper_bound = kwargs.get("scale_upper_bound", None) scale_upper_bound = kwargs.get("scale_upper_bound", None)
weight_block_size = kwargs.get("weight_block_size", None)
return cls( return cls(
qweight=weight, qweight=weight,
@ -375,6 +491,7 @@ class Fp8Linear(torch.nn.Module):
scale_upper_bound=scale_upper_bound, scale_upper_bound=scale_upper_bound,
bias=bias, bias=bias,
dtype=dtype, dtype=dtype,
weight_block_size=weight_block_size,
) )
@classmethod @classmethod
@ -386,6 +503,25 @@ class Fp8Linear(torch.nn.Module):
return cls._device_identity_cache[device] return cls._device_identity_cache[device]
def forward(self, input: torch.Tensor) -> torch.Tensor: def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.weight_block_size is not None:
# https://arxiv.org/pdf/2412.19437
# At a more granular level, as illustrated in Figure 7 (a): (1) for activations, we group and
# scale elements on a 1x128 tile basis (i.e., per token per 128 channels); and (2) for weights, we
# group and scale elements on a 128x128 block basis (i.e., per 128 input channels per 128 output
# channels).
qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1])
output = w8a8_block_fp8_matmul(
qinput,
self.qweight,
scale,
self.scale,
self.weight_block_size,
output_dtype=input.dtype,
)
if self.bias is not None:
output = output + self.bias
return output.to(dtype=input.dtype)
if CUTLASS_FP8_AVAILABLE: if CUTLASS_FP8_AVAILABLE:
# cutlass FP8 supports per-token scales, so get non-scalar scales. # cutlass FP8 supports per-token scales, so get non-scalar scales.
qinput, scale = fp8_quantize( qinput, scale = fp8_quantize(
@ -443,6 +579,9 @@ class Fp8Linear(torch.nn.Module):
def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size): def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
scale = weights.get_tensor(prefix, to_dtype=False) scale = weights.get_tensor(prefix, to_dtype=False)
if scale.numel() > 1: if scale.numel() > 1:
scale = weights.get_sharded(prefix, dim=0, to_dtype=False) scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
elif SYSTEM == "rocm":
return scale.reshape(-1)
return scale.reshape(-1).expand(shape[0]) return scale.reshape(-1).expand(shape[0])
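The block-wise path added to Fp8Linear.forward above quantizes activations per token in 1x128 channel groups and keeps one scale per 128x128 weight block, as the quoted DeepSeek-V3 comment describes, while the requantize_with_max_scale call in the loader collapses the per-projection scales of fused weights into a single shared maximum (it appears to mirror the vLLM helper of the same name). The sketch below is a minimal pure-PyTorch reference of that block-wise scheme, not the actual per_token_group_quant_fp8 / w8a8_block_fp8_matmul kernels; it assumes dimensions divisible by 128 and a torch build with float8 support.

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448 for e4m3fn


def per_token_group_quant_ref(x: torch.Tensor, group_size: int = 128):
    # One scale per (token, 128-channel group), mirroring the 1x128 activation tiles.
    tokens, in_features = x.shape
    grouped = x.view(tokens, in_features // group_size, group_size)
    scale = grouped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / FP8_MAX
    q = (grouped / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q.view(tokens, in_features), scale.squeeze(-1)


def block_fp8_matmul_ref(q_x, x_scale, q_w, w_scale, block=128, out_dtype=torch.bfloat16):
    # q_w: [out_features, in_features] in FP8, w_scale: [out/128, in/128].
    tokens, in_features = q_x.shape
    out_features = q_w.shape[0]
    out = torch.zeros(tokens, out_features, dtype=torch.float32)
    for kb in range(in_features // block):
        ks = slice(kb * block, (kb + 1) * block)
        a = q_x[:, ks].to(torch.float32) * x_scale[:, kb : kb + 1]
        for nb in range(out_features // block):
            ns = slice(nb * block, (nb + 1) * block)
            b = q_w[ns, ks].to(torch.float32) * w_scale[nb, kb]
            out[:, ns] += a @ b.t()
    return out.to(out_dtype)


# Quick demo with made-up sizes: quantize, multiply, compare against full precision.
x, w = torch.randn(4, 256), torch.randn(512, 256)
qx, xs = per_token_group_quant_ref(x)
wb = w.view(512 // 128, 128, 256 // 128, 128)
ws = wb.abs().amax(dim=(1, 3)) / FP8_MAX
qw = (wb / ws[:, None, :, None]).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn).reshape(512, 256)
y = block_fp8_matmul_ref(qx, xs, qw, ws)
print((y.float() - x @ w.t()).abs().mean())  # small quantization error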

View File

@ -956,15 +956,24 @@ def quantize(
pack(model, quantizers, bits, groupsize) pack(model, quantizers, bits, groupsize)
from safetensors.torch import save_file from safetensors.torch import save_file
from transformers.modeling_utils import shard_checkpoint from huggingface_hub import split_torch_state_dict_into_shards
state_dict = model.state_dict() state_dict = model.state_dict()
state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()} state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
max_shard_size = "10GB" max_shard_size = "10GB"
shards, index = shard_checkpoint( state_dict_split = split_torch_state_dict_into_shards(
state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors" state_dict,
filename_pattern="model.safetensors",
max_shard_size=max_shard_size,
) )
index = None
if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
shards = state_dict_split.filename_to_tensors
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
for shard_file, shard in shards.items(): for shard_file, shard in shards.items():
save_file( save_file(
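Since this hunk swaps the removed transformers.modeling_utils.shard_checkpoint for huggingface_hub.split_torch_state_dict_into_shards, a minimal standalone sketch of the new API may help; the toy tensors, output directory, and filename pattern below are illustrative (the hub default pattern is "model{suffix}.safetensors").

import json
import os

import torch
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file

state_dict = {f"layer.{i}.weight": torch.zeros(4, 4) for i in range(3)}  # toy example
output_dir = "./sharded-out"  # illustrative path
os.makedirs(output_dir, exist_ok=True)

split = split_torch_state_dict_into_shards(
    state_dict,
    filename_pattern="model{suffix}.safetensors",
    max_shard_size="10GB",
)

# filename_to_tensors maps each shard file to the tensor names it should contain.
for shard_file, tensor_names in split.filename_to_tensors.items():
    shard = {name: state_dict[name].contiguous() for name in tensor_names}
    save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})

# The index file is only needed when the state dict actually got sharded.
if split.is_sharded:
    index = {"metadata": split.metadata, "weight_map": split.tensor_to_filename}
    with open(os.path.join(output_dir, "model.safetensors.index.json"), "w") as f:
        json.dump(index, f, indent=2)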

View File

@ -2,14 +2,12 @@ from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from loguru import logger
from text_generation_server.layers.fp8 import fp8_quantize from text_generation_server.layers.fp8 import fp8_quantize
from text_generation_server.layers.marlin.gptq import _check_valid_shape from text_generation_server.layers.marlin.gptq import _check_valid_shape
from text_generation_server.layers.marlin.util import ( from text_generation_server.layers.marlin.util import (
_check_marlin_kernels, _check_marlin_kernels,
permute_scales, permute_scales,
) )
from text_generation_server.utils.log import log_once
try: try:
import marlin_kernels import marlin_kernels
@ -36,8 +34,6 @@ class GPTQMarlinFP8Linear(nn.Module):
_check_marlin_kernels() _check_marlin_kernels()
assert marlin_kernels is not None assert marlin_kernels is not None
log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
scales = scales.unsqueeze(0) scales = scales.unsqueeze(0)
if scales.shape[1] == 1: if scales.shape[1] == 1:
out_features, in_features = qweight.shape out_features, in_features = qweight.shape

View File

@ -16,6 +16,7 @@ from text_generation_server.layers.moe.gptq_marlin import (
can_use_marlin_moe_gemm, can_use_marlin_moe_gemm,
) )
from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
from text_generation_server.layers.moe.fp8 import FP8SparseMoELayer
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import ( from text_generation_server.utils.weights import (
@ -25,7 +26,7 @@ from text_generation_server.utils.weights import (
) )
if SYSTEM == "ipex": if SYSTEM == "ipex":
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE from .fused_moe_ipex import fused_topk, grouped_topk
else: else:
from moe_kernels.fused_moe import fused_topk, grouped_topk from moe_kernels.fused_moe import fused_topk, grouped_topk
@ -51,6 +52,8 @@ class MoELayer(Protocol):
up_proj_name: str = "up_proj", up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj", down_proj_name: str = "down_proj",
hidden_act: str = "silu", hidden_act: str = "silu",
scoring_func: Optional[str] = None,
e_score_correction_bias: Optional[float] = None,
): ... ): ...
def forward( def forward(
@ -80,9 +83,14 @@ class DenseMoELayer(nn.Module):
up_proj_name: str = "up_proj", up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj", down_proj_name: str = "down_proj",
hidden_act: str = "silu", hidden_act: str = "silu",
scoring_func: Optional[str] = None,
e_score_correction_bias: Optional[float] = None,
): ):
super().__init__() super().__init__()
assert scoring_func is None, "scoring func is not handled"
assert e_score_correction_bias is None, "scoring correction bias is not handled"
log_once( log_once(
logger.info, logger.info,
"No fused layers are available for this model type, using (slower) dense MoE layer", "No fused layers are available for this model type, using (slower) dense MoE layer",
@ -139,10 +147,6 @@ class DenseMoELayer(nn.Module):
) )
for i in range(self.n_experts) for i in range(self.n_experts)
] ]
if SYSTEM == "ipex":
self.ipex_fused_moe = GatedMLPMOE(
W13=self.gate_proj, W2=self.down_proj, W3=self.up_proj, use_prepack=True
)
self.process_group = weights.process_group self.process_group = weights.process_group
@ -155,17 +159,6 @@ class DenseMoELayer(nn.Module):
input_shape = x.shape input_shape = x.shape
x = x.view(-1, input_shape[-1]) x = x.view(-1, input_shape[-1])
if SYSTEM == "ipex":
return self.ipex_fused_moe(
hidden_states=x,
router_logits=gating_output,
top_k=self.topk,
renormalize=self.renormalize,
use_grouped_topk=self.n_expert_group is not None,
num_expert_group=self.n_expert_group,
topk_group=self.topk_group,
)
if self.n_expert_group is not None and self.topk_group is not None: if self.n_expert_group is not None and self.topk_group is not None:
topk_weights, topk_ids = grouped_topk( topk_weights, topk_ids = grouped_topk(
x, x,
@ -213,16 +206,23 @@ class SparseMoELayer(nn.Module):
topk: int, topk: int,
topk_group: Optional[int], topk_group: Optional[int],
weights: Weights, weights: Weights,
scoring_func: Optional[str] = "softmax",
e_score_correction_bias: Optional[float] = None,
gate_proj_name: str = "gate_proj", gate_proj_name: str = "gate_proj",
up_proj_name: str = "up_proj", up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj", down_proj_name: str = "down_proj",
): ):
super().__init__() super().__init__()
if ( if (
isinstance(weights.loader, DefaultWeightsLoader) isinstance(weights.loader, DefaultWeightsLoader)
and isinstance(weights.loader.weight_class, UnquantizedWeight) and isinstance(weights.loader.weight_class, UnquantizedWeight)
) or isinstance(weights.loader, HybridFP8UnquantLoader): ) or isinstance(weights.loader, HybridFP8UnquantLoader):
if (
isinstance(weights.loader, HybridFP8UnquantLoader)
and weights.loader.to_fp8
):
cls = FP8SparseMoELayer
else:
cls = UnquantizedSparseMoELayer cls = UnquantizedSparseMoELayer
elif isinstance( elif isinstance(
weights.loader, GPTQMarlinWeightsLoader weights.loader, GPTQMarlinWeightsLoader
@ -250,6 +250,8 @@ class SparseMoELayer(nn.Module):
topk=topk, topk=topk,
topk_group=topk_group, topk_group=topk_group,
weights=weights, weights=weights,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
gate_proj_name=gate_proj_name, gate_proj_name=gate_proj_name,
up_proj_name=up_proj_name, up_proj_name=up_proj_name,
down_proj_name=down_proj_name, down_proj_name=down_proj_name,

View File

@ -0,0 +1,173 @@
from typing import Optional
import torch
import torch.nn as nn
from text_generation_server.utils.weights import Weights
from text_generation_server.layers.fp8 import (
Fp8Weight,
fp8_quantize,
quant_dtype,
normalize_e4m3fn_to_native_float8,
)
try:
from moe_kernels.fused_moe import fused_moe
except Exception:
fused_moe = None
class FP8SparseMoELayer(nn.Module):
def __init__(
self,
*,
n_expert_group: Optional[int],
n_experts: int,
prefix: str,
renormalize: bool,
topk: int,
topk_group: Optional[int],
weights: Weights,
scoring_func: Optional[str] = "softmax",
e_score_correction_bias: Optional[float] = None,
gate_proj_name: str = "gate_proj",
up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj",
):
super().__init__()
assert (n_expert_group is None) == (
topk_group is None
), "n_expert_group and topk_group must both be None or have some value"
self.n_expert_group = n_expert_group
self.topk = topk
self.topk_group = topk_group
self.renormalize = renormalize
self.weight_block_size = weights.weights_loader.weight_block_size
self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias
(
self.gate_up_proj,
self.gate_up_proj_weight_scale,
self.gate_up_proj_input_scale,
) = _load_expert_multi_weights_col(
prefix=prefix,
n_experts=n_experts,
gate_proj_name=gate_proj_name,
up_proj_name=up_proj_name,
weights=weights,
)
self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = (
_load_expert_weights_row(
prefix=prefix,
n_experts=n_experts,
name=down_proj_name,
weights=weights,
)
)
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
return fused_moe(
x,
w1=self.gate_up_proj,
w2=self.down_proj,
gating_output=gating_output,
topk=self.topk,
renormalize=self.renormalize,
inplace=True,
use_grouped_topk=self.n_expert_group is not None,
num_expert_group=self.n_expert_group,
topk_group=self.topk_group,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
use_fp8_w8a8=True,
w1_scale=self.gate_up_proj_weight_scale,
w2_scale=self.down_proj_weight_scale,
a1_scale=self.gate_up_proj_input_scale,
a2_scale=self.down_proj_input_scale,
)
def _load_expert_weights(
get_weight_fn,
*,
prefix: str,
n_experts: int,
name: str,
weights: Weights,
) -> torch.Tensor:
all_weight = None
all_weight_scales = None
max_input_scale = None
for i in range(n_experts):
weight = get_weight_fn(prefix, i, name, weights)
assert isinstance(weight, Fp8Weight)
if all_weight is None:
all_weight = torch.empty(
(n_experts,) + weight.weight.shape,
dtype=quant_dtype,
device=weight.weight.device,
)
if all_weight_scales is None:
all_weight_scales = torch.empty(
(n_experts,) + weight.weight_scale.shape,
dtype=torch.float32,
device=weight.weight.device,
)
if weight.weight.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz}:
all_weight[i], all_weight_scales[i], current_input_scale = (
normalize_e4m3fn_to_native_float8(
weight.weight, weight.weight_scale, weight.input_scale
)
)
if current_input_scale is not None:
if max_input_scale is None or current_input_scale > max_input_scale:
max_input_scale = current_input_scale
else:
all_weight[i], all_weight_scales[i] = fp8_quantize(
weight.weight, scalar=True
)
assert all_weight is not None
return all_weight, all_weight_scales, max_input_scale
def _load_expert_multi_weights_col(
*,
prefix: str,
n_experts: int,
gate_proj_name: str,
up_proj_name: str,
weights: Weights,
) -> torch.Tensor:
def get_weight_fn(prefix, i, name, weights):
return weights.get_multi_weights_col(
[f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
)
return _load_expert_weights(
get_weight_fn, prefix=prefix, n_experts=n_experts, name=None, weights=weights
)
def _load_expert_weights_row(
*,
prefix: str,
n_experts: int,
name: str,
weights: Weights,
) -> torch.Tensor:
def get_weight_fn(prefix, i, name, weights):
return weights.get_weights_row(f"{prefix}.{i}.{name}")
return _load_expert_weights(
get_weight_fn, prefix=prefix, n_experts=n_experts, name=name, weights=weights
)
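For orientation, the helpers above return one stacked FP8 tensor per projection with the expert index as the leading dimension, plus matching per-expert scales and the maximum input scale across experts; that is the layout fused_moe consumes. The toy shapes below are made up and only illustrate this contract.

import torch

n_experts, hidden, inter = 4, 16, 32  # made-up toy sizes

# _load_expert_multi_weights_col: fused [gate_proj; up_proj] per expert, stacked on dim 0.
gate_up = torch.empty(n_experts, 2 * inter, hidden, dtype=torch.float8_e4m3fn)
# _load_expert_weights_row: down_proj per expert, stacked on dim 0.
down = torch.empty(n_experts, hidden, inter, dtype=torch.float8_e4m3fn)

# Per-expert input scales are reduced to a single maximum shared by all experts.
input_scales = [torch.tensor(0.02), None, torch.tensor(0.05), torch.tensor(0.01)]
max_input_scale = max(s for s in input_scales if s is not None)  # tensor(0.0500)
print(gate_up.shape, down.shape, max_input_scale)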

View File

@ -0,0 +1,65 @@
# coding=utf-8
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch
def grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
scores = torch.softmax(gating_output, dim=-1)
num_token = scores.shape[0]
group_scores = (
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
) # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
1
] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = (
group_mask.unsqueeze(-1)
.expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
.reshape(num_token, -1)
) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
def fused_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
topk_weights = torch.nn.functional.softmax(
gating_output, dim=1, dtype=torch.float32
)
topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
if renormalize:
topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
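A quick sanity check of the routing helpers above, assuming the TGI server package (or a local copy of the two functions) is importable; the sizes are arbitrary. grouped_topk restricts each token to experts inside its best topk_group groups before taking the final top-k.

import torch

from text_generation_server.layers.moe.fused_moe_ipex import grouped_topk

tokens, n_experts, n_group, topk_group, topk = 5, 16, 4, 2, 4
hidden_states = torch.randn(tokens, 32)          # only its batch size matters here
gating_output = torch.randn(tokens, n_experts)   # router logits

weights, ids = grouped_topk(
    hidden_states,
    gating_output,
    topk=topk,
    renormalize=True,
    num_expert_group=n_group,
    topk_group=topk_group,
)
print(weights.shape, ids.shape)  # torch.Size([5, 4]) torch.Size([5, 4])
# With renormalize=True the kept routing weights sum to 1 per token.
assert torch.allclose(weights.sum(dim=-1), torch.ones(tokens))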

View File

@ -69,7 +69,11 @@ class GPTQMarlinSparseMoELayer(nn.Module):
gate_proj_name: str = "gate_proj", gate_proj_name: str = "gate_proj",
up_proj_name: str = "up_proj", up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj", down_proj_name: str = "down_proj",
scoring_func: Optional[str] = None,
e_score_correction_bias: Optional[float] = None,
): ):
assert scoring_func == "softmax", f"scoring func {scoring_func} is not handled"
assert e_score_correction_bias is None, "scoring correction bias is not handled"
super().__init__() super().__init__()
if not ( if not (

View File

@ -23,6 +23,8 @@ class UnquantizedSparseMoELayer(nn.Module):
topk: int, topk: int,
topk_group: Optional[int], topk_group: Optional[int],
weights: Weights, weights: Weights,
scoring_func: Optional[str] = "softmax",
e_score_correction_bias: Optional[float] = None,
gate_proj_name: str = "gate_proj", gate_proj_name: str = "gate_proj",
up_proj_name: str = "up_proj", up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj", down_proj_name: str = "down_proj",
@ -37,6 +39,9 @@ class UnquantizedSparseMoELayer(nn.Module):
self.topk = topk self.topk = topk
self.topk_group = topk_group self.topk_group = topk_group
self.renormalize = renormalize self.renormalize = renormalize
self.weight_block_size = weights.weights_loader.weight_block_size
self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias
self.gate_up_proj = _load_expert_multi_weights_col( self.gate_up_proj = _load_expert_multi_weights_col(
prefix=prefix, prefix=prefix,
@ -58,17 +63,7 @@ class UnquantizedSparseMoELayer(nn.Module):
) )
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
if SYSTEM == "rocm": if SYSTEM == "ipex":
return fused_moe(
x,
self.gate_up_proj,
self.down_proj,
gating_output,
self.topk,
renormalize=self.renormalize,
inplace=True,
)
elif SYSTEM == "ipex":
return self.ipex_fused_moe( return self.ipex_fused_moe(
hidden_states=x, hidden_states=x,
router_logits=gating_output, router_logits=gating_output,
@ -78,7 +73,6 @@ class UnquantizedSparseMoELayer(nn.Module):
num_expert_group=self.n_expert_group, num_expert_group=self.n_expert_group,
topk_group=self.topk_group, topk_group=self.topk_group,
) )
return fused_moe( return fused_moe(
x, x,
w1=self.gate_up_proj, w1=self.gate_up_proj,
@ -90,6 +84,8 @@ class UnquantizedSparseMoELayer(nn.Module):
use_grouped_topk=self.n_expert_group is not None, use_grouped_topk=self.n_expert_group is not None,
num_expert_group=self.n_expert_group, num_expert_group=self.n_expert_group,
topk_group=self.topk_group, topk_group=self.topk_group,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
) )

View File

@ -16,10 +16,12 @@ from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download, HfApi from huggingface_hub import hf_hub_download, HfApi
from typing import Optional, List, Dict from typing import Optional, List, Dict
from pathlib import Path from pathlib import Path
import transformers
from text_generation_server.utils.speculate import get_speculate, set_speculate from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models.custom_modeling.mpt_modeling import ( from text_generation_server.models.custom_modeling.mpt_modeling import (
MPTForCausalLM, MPTForCausalLM,
@ -87,6 +89,10 @@ try:
FlashDeepseekV2ForCausalLM, FlashDeepseekV2ForCausalLM,
DeepseekV2Config, DeepseekV2Config,
) )
from text_generation_server.models.custom_modeling.flash_deepseek_v3_modeling import (
FlashDeepseekV3ForCausalLM,
DeepseekV3Config,
)
from text_generation_server.models.custom_modeling.flash_llama_modeling import ( from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM, FlashLlamaForCausalLM,
) )
@ -178,6 +184,14 @@ except ImportError as e:
if MAMBA_AVAILABLE: if MAMBA_AVAILABLE:
__all__.append(Mamba) __all__.append(Mamba)
FLASH_TRANSFORMERS_BACKEND = torch.cuda.is_available()
try:
from text_generation_server.models.transformers_flash_causal_lm import (
TransformersFlashCausalLM,
)
except ImportError:
FLASH_TRANSFORMERS_BACKEND = False
class ModelType(enum.Enum): class ModelType(enum.Enum):
DEEPSEEK_V2 = { DEEPSEEK_V2 = {
@ -185,6 +199,11 @@ class ModelType(enum.Enum):
"name": "Deepseek V2", "name": "Deepseek V2",
"url": "https://huggingface.co/deepseek-ai/DeepSeek-V2", "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
} }
DEEPSEEK_V3 = {
"type": "deepseek_v3",
"name": "Deepseek V3",
"url": "https://huggingface.co/deepseek-ai/DeepSeek-V3",
}
IDEFICS2 = { IDEFICS2 = {
"type": "idefics2", "type": "idefics2",
"name": "Idefics 2", "name": "Idefics 2",
@ -632,6 +651,40 @@ def get_model(
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
elif model_type == DEEPSEEK_V3:
if FLASH_ATTENTION:
head_size = max(
config_dict.get("qk_nope_dim", 128)
+ config_dict.get("qk_rope_dim", 64),
config_dict.get("v_head_dim", 128),
)
return FlashCausalLM(
model_id=model_id,
model_class=FlashDeepseekV3ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
default_dtype=torch.bfloat16,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=DeepseekV3Config,
head_size=head_size,
)
elif sharded:
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V3")
)
else:
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == MAMBA: elif model_type == MAMBA:
return Mamba( return Mamba(
model_id, model_id,
@ -683,7 +736,7 @@ def get_model(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder") FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
) )
else: else:
return CausalLM.fallback( return transformers_causal_lm_class.fallback(
model_id=model_id, model_id=model_id,
revision=revision, revision=revision,
quantize=quantize, quantize=quantize,
@ -838,7 +891,7 @@ def get_model(
lora_adapter_ids=lora_adapter_ids, lora_adapter_ids=lora_adapter_ids,
) )
else: else:
return CausalLM.fallback( return TransformersFlashCausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -888,12 +941,43 @@ def get_model(
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
elif ( elif model_type == LLAMA or model_type == PHI3 or model_type == GRANITE:
model_type == LLAMA if FLASH_ATTENTION:
or model_type == BAICHUAN return FlashCausalLM(
or model_type == PHI3 model_id=model_id,
or model_type == GRANITE model_class=FlashLlamaForCausalLM,
): revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif FLASH_TRANSFORMERS_BACKEND:
return TransformersFlashCausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
)
else:
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == BAICHUAN:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashCausalLM( return FlashCausalLM(
model_id=model_id, model_id=model_id,
@ -919,6 +1003,7 @@ def get_model(
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
if model_type == GEMMA: if model_type == GEMMA:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashCausalLM( return FlashCausalLM(
@ -934,6 +1019,15 @@ def get_model(
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids, lora_adapter_ids=lora_adapter_ids,
) )
elif FLASH_TRANSFORMERS_BACKEND:
return TransformersFlashCausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
else: else:
@ -985,6 +1079,15 @@ def get_model(
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids, lora_adapter_ids=lora_adapter_ids,
) )
elif FLASH_TRANSFORMERS_BACKEND:
return TransformersFlashCausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
else: else:
@ -1088,6 +1191,15 @@ def get_model(
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids, lora_adapter_ids=lora_adapter_ids,
) )
elif FLASH_TRANSFORMERS_BACKEND:
return TransformersFlashCausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
else: else:
@ -1113,6 +1225,15 @@ def get_model(
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids, lora_adapter_ids=lora_adapter_ids,
) )
elif FLASH_TRANSFORMERS_BACKEND:
return TransformersFlashCausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
else: else:
@ -1138,6 +1259,15 @@ def get_model(
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids, lora_adapter_ids=lora_adapter_ids,
) )
elif FLASH_TRANSFORMERS_BACKEND:
return TransformersFlashCausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded: elif sharded:
raise NotImplementedError( raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2") FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
@ -1165,6 +1295,15 @@ def get_model(
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids, lora_adapter_ids=lora_adapter_ids,
) )
elif FLASH_TRANSFORMERS_BACKEND:
return TransformersFlashCausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
else: else:
@ -1314,8 +1453,6 @@ def get_model(
else: else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))
if sharded:
raise NotImplementedError("sharded is not supported for AutoModel")
if quantize == "gptq": if quantize == "gptq":
raise NotImplementedError( raise NotImplementedError(
"gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
@ -1328,8 +1465,19 @@ def get_model(
raise NotImplementedError("Eetq quantization is not supported for AutoModel") raise NotImplementedError("Eetq quantization is not supported for AutoModel")
elif quantize == "exl2": elif quantize == "exl2":
raise NotImplementedError("exl2 quantization is not supported for AutoModel") raise NotImplementedError("exl2 quantization is not supported for AutoModel")
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM.fallback( # Fast transformers if available
transformers_model_class = getattr(
transformers,
modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type, ""),
None,
)
if (
FLASH_TRANSFORMERS_BACKEND
and transformers_model_class is not None
and transformers_model_class._supports_flex_attn
):
return TransformersFlashCausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -1337,6 +1485,10 @@ def get_model(
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
if sharded:
raise NotImplementedError("sharded is not supported for AutoModel")
if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES: if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
return Seq2SeqLM.fallback( return Seq2SeqLM.fallback(
model_id, model_id,
@ -1449,6 +1601,9 @@ def get_model_with_lora_adapters(
"up_proj", "up_proj",
"down_proj", "down_proj",
"qkv_proj", "qkv_proj",
# add c_* layers used in starcoder2
"c_proj",
"c_fc",
] ]
for layer_name in adapter_layers: for layer_name in adapter_layers:
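The per-architecture branches added above all follow the same ladder, so a condensed sketch of the selection order may be easier to scan; the function below is an illustration of the pattern, not code from the repository.

# Illustrative condensation of the dispatch pattern repeated above (not actual TGI code).
def select_backend(flash_attention: bool, flash_transformers_backend: bool, sharded: bool) -> str:
    if flash_attention:
        return "FlashCausalLM"              # custom flash kernels, preferred path
    if flash_transformers_backend:
        return "TransformersFlashCausalLM"  # transformers weights + TGI flash attention glue
    if sharded:
        raise NotImplementedError("sharding requires one of the flash backends")
    return "CausalLM"                       # plain transformers fallback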

View File

@ -0,0 +1,676 @@
# coding=utf-8
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Type
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from text_generation_server.layers import (
FastLinear,
SpeculativeHead,
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
get_linear,
)
from text_generation_server.layers.attention import (
Seqlen,
attention,
paged_attention,
)
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import Weights
if SYSTEM == "rocm":
try:
import vllm._custom_ops as ops
except Exception as e:
raise ImportError(f"Could not load `vllm._custom_ops`. Full error: {e}")
class DeepseekV3Config(PretrainedConfig):
def __init__(
self,
vocab_size=102400,
hidden_size=4096,
intermediate_size=11008,
moe_intermediate_size=1407,
num_hidden_layers=30,
num_attention_heads=32,
num_key_value_heads=32,
n_shared_experts=2,
n_routed_experts=160,
ep_size=1,
routed_scaling_factor=1.0,
kv_lora_rank=512,
q_lora_rank=1536,
qk_rope_head_dim=64,
v_head_dim=128,
qk_nope_head_dim=128,
topk_method="gready",
n_group=8,
topk_group=3,
num_experts_per_tok=6,
moe_layer_freq=1,
first_k_dense_replace=0,
norm_topk_prob=False,
scoring_func="softmax",
aux_loss_alpha=0.001,
seq_aux=True,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=100000,
eos_token_id=100001,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.moe_intermediate_size = moe_intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.n_shared_experts = n_shared_experts
self.n_routed_experts = n_routed_experts
self.ep_size = ep_size
self.routed_scaling_factor = routed_scaling_factor
self.kv_lora_rank = kv_lora_rank
self.q_lora_rank = q_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.qk_nope_head_dim = qk_nope_head_dim
self.topk_method = topk_method
self.n_group = n_group
self.topk_group = topk_group
self.num_experts_per_tok = num_experts_per_tok
self.moe_layer_freq = moe_layer_freq
self.first_k_dense_replace = first_k_dense_replace
self.norm_topk_prob = norm_topk_prob
self.scoring_func = scoring_func
self.aux_loss_alpha = aux_loss_alpha
self.seq_aux = seq_aux
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
if tie_word_embeddings:
raise ValueError(
"tie_word_embeddings is not supported for Deepseek V2 models."
)
if ep_size != 1:
raise ValueError(
f"Currently only ep_size == 1 is supported for Deepseek V2 models, was {ep_size}"
)
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
class DeepseekV3Attention(torch.nn.Module):
def __init__(
self,
prefix: str,
config,
weights: Weights,
):
super().__init__()
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.kv_lora_rank = config.kv_lora_rank
self.q_lora_rank = config.q_lora_rank
self.qk_nope_head_dim = config.qk_nope_head_dim
self.qk_rope_head_dim = config.qk_rope_head_dim
self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim
self.value_head_size = config.v_head_dim
self.head_pad_size = max(self.head_size, self.value_head_size)
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.qk_rope_head_dim,
base=config.rope_theta,
device=weights.device,
)
mscale = get_mscale(
self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim
)
self.softmax_scale = self.head_size**-0.5 * mscale * mscale
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
config.num_key_value_heads // weights.process_group.size()
)
if self.q_lora_rank is None:
self.q_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.q_proj",
weights=weights,
bias=config.attention_bias,
)
else:
self.q_a_proj = get_linear(
weight=weights.get_weights(f"{prefix}.q_a_proj"),
bias=(
weights.get_tensor(f"{prefix}.q_a_proj.bias")
if config.attention_bias
else None
),
)
self.q_a_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.q_a_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
self.q_b_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.q_b_proj",
weights=weights,
bias=config.attention_bias,
)
self.kv_a_proj_with_mqa = get_linear(
weight=weights.get_weights(f"{prefix}.kv_a_proj_with_mqa"),
bias=(
weights.get_tensor(f"{prefix}.kv_a_proj_with_mqa.bias")
if config.attention_bias
else None
),
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.kv_a_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.kv_b_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.kv_b_proj",
weights=weights,
bias=config.attention_bias,
)
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
def forward(
self,
hidden_states: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
cu_seqlen_prefill: torch.Tensor,
kv_cache: KVCache,
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
):
if self.q_lora_rank is None:
query = self.q_proj(hidden_states)
else:
query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0])
query = query.view(-1, self.num_heads, self.head_size)
_, query_pe = torch.split(
query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, key_pe = torch.split(
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view(
-1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size
)
key_nope, value = torch.split(
kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
)
batch_size, heads, head_dim = query_pe.shape
query_pe = (
query_pe.view(batch_size, heads, head_dim // 2, 2)
.transpose(2, 3)
.reshape(batch_size, heads, head_dim)
)
batch_size, heads, head_dim = key_pe.shape
key_pe = (
key_pe.view(batch_size, heads, head_dim // 2, 2)
.transpose(2, 3)
.reshape(batch_size, heads, head_dim)
)
self.rotary_emb(query_pe, key_pe, cos, sin)
query[..., self.qk_nope_head_dim :] = query_pe
key = torch.empty_like(query)
key[..., : self.qk_nope_head_dim] = key_nope
key[..., self.qk_nope_head_dim :] = key_pe
# We need to pad the heads because Flash Attention does not support
# qk and v with different head sizes.
query = torch.nn.functional.pad(
query, (0, self.head_pad_size - self.head_size), value=0
)
key = torch.nn.functional.pad(
key, (0, self.head_pad_size - self.head_size), value=0
)
value = torch.nn.functional.pad(
value, (0, self.head_pad_size - self.value_head_size), value=0
)
kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query=query,
key=key,
value=value,
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
# Remove padding.
attn_output = attn_output[..., : self.value_head_size]
return self.o_proj(
attn_output.reshape(-1, self.num_heads * self.value_head_size)
)
class DeepseekV3MLP(nn.Module):
def __init__(self, prefix: str, config, weights, intermediate_size: int):
super().__init__()
self.hidden_act = config.hidden_act
if self.hidden_act != "silu":
# Bail out because MoE only supports silu.
raise NotImplementedError(
"Currently only `silu` is supported as an activation for Deepseek V2."
)
self.act = ACT2FN[self.hidden_act]
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
self.down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
)
self.intermediate_size = intermediate_size // weights.process_group.size()
# TODO: This is a hotfix to be removed & properly refactored.
self.quantize = config.quantize
def forward(self, hidden_states: torch.Tensor, reduce: bool = True):
if (
SYSTEM == "rocm"
and self.hidden_act == "silu"
and hidden_states.dtype == torch.float16
and hidden_states.shape[0] == 1
and not self.quantize
):
out = torch.empty(
hidden_states.shape[0],
self.intermediate_size,
dtype=hidden_states.dtype,
device="cuda",
)
ops.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
return self.down_proj(out, reduce=reduce)
else:
gate_up_states = self.gate_up_proj(hidden_states)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce
)
class DeepseekV3MoE(nn.Module):
def __init__(
self,
prefix,
config: DeepseekV3Config,
moe_layer_cls: Type[MoELayer],
weights,
):
super().__init__()
self.hidden_dim = config.hidden_size
self.moe_intermediate_size = (
config.moe_intermediate_size // weights.process_group.size()
)
self.routed_scaling_factor = config.routed_scaling_factor
# Gating
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
if config.topk_method == "noaux_tc":
self.gate.e_score_correction_bias = torch.zeros(
config.n_routed_experts, device=weights.device
)
else:
self.gate.e_score_correction_bias = None
self.moe_layer = moe_layer_cls(
prefix=f"{prefix}.experts",
n_experts=config.n_routed_experts,
n_expert_group=config.n_group,
renormalize=config.norm_topk_prob,
topk=config.num_experts_per_tok,
topk_group=config.topk_group,
weights=weights,
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
)
assert isinstance(self.moe_layer, MoELayer)
if config.n_shared_experts is not None:
self.shared_experts = DeepseekV3MLP(
prefix=f"{prefix}.shared_experts",
config=config,
weights=weights,
intermediate_size=config.moe_intermediate_size
* config.n_shared_experts,
)
else:
self.shared_experts = None
self.process_group = weights.process_group
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.shared_experts is not None:
shared_output = self.shared_experts(x, reduce=False)
else:
shared_output = None
router_logits = self.gate(x)
out = self.moe_layer(x, gating_output=router_logits)
if shared_output is not None:
out = out + shared_output
# Reduce sum
if self.process_group.size() > 1:
torch.distributed.all_reduce(out, group=self.process_group)
return out.view(*x.shape)
class DeepseekV3Layer(nn.Module):
def __init__(self, prefix, layer_id, config, weights):
super().__init__()
prefix = f"{prefix}.layers.{layer_id}"
self.self_attn = DeepseekV3Attention(
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
)
if (
config.n_routed_experts is not None
and layer_id >= config.first_k_dense_replace
and layer_id % config.moe_layer_freq == 0
):
moe_layer_cls = (
SparseMoELayer
if SparseMoELayer.is_supported(weights)
else DenseMoELayer
)
self.mlp = DeepseekV3MoE(f"{prefix}.mlp", config, moe_layer_cls, weights)
else:
self.mlp = DeepseekV3MLP(
prefix=f"{prefix}.mlp",
config=config,
weights=weights,
intermediate_size=config.intermediate_size,
)
self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.post_attention_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
def forward(
self,
hidden_states: torch.Tensor,
residual: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
cu_seqlen_prefill: torch.Tensor,
kv_cache,
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
):
normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)
# Self Attention
attn_output = self.self_attn(
normed_hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
)
# faster post attention rms norm
normed_attn_res_output, residual = self.post_attention_layernorm(
attn_output, residual
)
output = self.mlp(normed_attn_res_output)
return output, residual
class DeepseekV3Model(torch.nn.Module):
def __init__(self, prefix: str, config, weights: Weights):
super().__init__()
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens", weights=weights
)
self.layers = nn.ModuleList(
[
DeepseekV3Layer(
prefix,
layer_id,
config,
weights,
)
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = FastRMSNorm.load(
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
)
self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid indexing in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
max_s,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class FlashDeepseekV3ForCausalLM(torch.nn.Module):
def __init__(self, prefix: str, config, weights: Weights):
super().__init__()
self.model = DeepseekV3Model(
"model" if not prefix else f"{prefix}.model", config, weights
)
self.lm_head = SpeculativeHead.load(
config,
prefix="lm_head" if not prefix else f"{prefix}.lm_head",
weights=weights,
)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden_states = self.model(
input_ids,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states)
return logits, speculative_logits
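Two details of DeepseekV3Attention.forward above are easy to misread: the view/transpose/reshape on query_pe and key_pe reorders each head's rotary dimensions from interleaved pairs (x0, y0, x1, y1, ...) to a half-split layout (x0, x1, ..., y0, y1, ...), and the F.pad calls equalize the query/key head size (qk_nope_head_dim + qk_rope_head_dim = 192 by default) with the value head size (128), since the flash kernels require a single head size. The toy snippet below demonstrates just those two transforms and is independent of the TGI kernels.

import torch

# Interleaved -> half-split rotary layout on a toy (batch, heads, head_dim) tensor.
batch, heads, head_dim = 2, 1, 8
pe = torch.arange(batch * heads * head_dim, dtype=torch.float32).view(batch, heads, head_dim)
half_split = (
    pe.view(batch, heads, head_dim // 2, 2)  # adjacent (x_i, y_i) pairs
    .transpose(2, 3)                         # gather all x first, then all y
    .reshape(batch, heads, head_dim)
)
print(pe[0, 0])          # tensor([0., 1., 2., 3., 4., 5., 6., 7.])
print(half_split[0, 0])  # tensor([0., 2., 4., 6., 1., 3., 5., 7.])

# Padding value heads (128) up to the query/key head size (192) before attention;
# the pad is sliced off again afterwards (attn_output[..., :value_head_size]).
qk_head, v_head = 192, 128
value = torch.randn(4, heads, v_head)
value_padded = torch.nn.functional.pad(value, (0, max(qk_head, v_head) - v_head), value=0)
print(value_padded.shape)  # torch.Size([4, 1, 192])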

View File

@ -642,9 +642,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
embedding_multiplier = getattr(config, "embedding_multiplier", None) embedding_multiplier = getattr(config, "embedding_multiplier", None)
if embedding_multiplier is not None: if embedding_multiplier is not None:
self.embed_tokens.weight.data *= embedding_multiplier self.embed_tokens.weight.data *= embedding_multiplier
prefix = suffix if not prefix or name != "model" else f"{prefix}.{suffix}"
prefix = "lm_head" if not prefix or name != "model" else f"{prefix}.{suffix}"
with no_fp8(weights): with no_fp8(weights):
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, config,

View File

@ -32,6 +32,8 @@ from text_generation_server.layers.attention import (
Seqlen, Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
@ -109,17 +111,31 @@ class Starcoder2Config(PretrainedConfig):
) )
def load_attention(config, prefix, weights): def load_attention(config, prefix, weights, layer_id):
prefixes = [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
head_size = config.hidden_size // config.num_attention_heads
sizes = [
head_size * config.num_attention_heads,
head_size * config.num_key_value_heads,
head_size * config.num_key_value_heads,
]
if config.num_attention_heads != config.num_key_value_heads: if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights) base_layer = _load_gqa(config, prefix, weights)
else: else:
return TensorParallelColumnLinear.load_multi( base_layer = TensorParallelColumnLinear.load_multi(
config, config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], prefixes=prefixes,
dim=0, dim=0,
weights=weights, weights=weights,
bias=config.use_bias, bias=config.use_bias,
) )
return TensorParallelMultiAdapterLinear.load(
base_layer=base_layer,
layer_id=layer_id,
layer_names=prefixes,
sizes=sizes,
process_group=weights.process_group,
)
def _load_gqa(config, prefix: str, weights): def _load_gqa(config, prefix: str, weights):
@ -157,6 +173,7 @@ def _load_gqa(config, prefix: str, weights):
class Starcoder2Attention(torch.nn.Module): class Starcoder2Attention(torch.nn.Module):
def __init__( def __init__(
self, self,
index: int,
prefix: str, prefix: str,
config, config,
weights, weights,
@ -188,15 +205,23 @@ class Starcoder2Attention(torch.nn.Module):
config.num_key_value_heads // weights.process_group.size() config.num_key_value_heads // weights.process_group.size()
) )
self.query_key_value = load_attention(config, prefix, weights) self.query_key_value = load_attention(config, prefix, weights, index)
self.kv_scales = get_kv_scales(weights, f"{prefix}") self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load( o_proj = TensorParallelRowLinear.load(
config, config,
prefix=f"{prefix}.o_proj", prefix=f"{prefix}.o_proj",
weights=weights, weights=weights,
bias=config.use_bias, bias=getattr(config, "use_bias", False),
) )
self.o_proj = TensorParallelAdapterRowLinear.load(
o_proj,
index,
"o_proj",
process_group=weights.process_group,
)
self.num_groups = self.num_heads // self.num_key_value_heads self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange( self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@ -214,8 +239,9 @@ class Starcoder2Attention(torch.nn.Module):
seqlen, seqlen,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
adapter_data,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states, adapter_data)
query, kv = qkv.split( query, kv = qkv.split(
[ [
self.head_size * self.num_heads, self.head_size * self.num_heads,
@ -267,11 +293,13 @@ class Starcoder2Attention(torch.nn.Module):
kv_scales=self.kv_scales, kv_scales=self.kv_scales,
) )
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) return self.o_proj(
attn_output.view(-1, self.num_heads * self.head_size), adapter_data
)
class Starcoder2MLP(nn.Module): class Starcoder2MLP(nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, prefix, config, weights, index):
super().__init__() super().__init__()
act = config.hidden_act act = config.hidden_act
self.act = ( self.act = (
@ -285,27 +313,42 @@ class Starcoder2MLP(nn.Module):
) )
) )
# Fuse gate and up proj # Fuse gate and up proj
self.c_fc = TensorParallelColumnLinear.load( c_fc = TensorParallelColumnLinear.load(
config, config,
prefix=f"{prefix}.c_fc", prefix=f"{prefix}.c_fc",
weights=weights, weights=weights,
bias=config.use_bias, bias=config.use_bias,
) )
self.c_proj = TensorParallelRowLinear.load( c_proj = TensorParallelRowLinear.load(
config, config,
prefix=f"{prefix}.c_proj", prefix=f"{prefix}.c_proj",
weights=weights, weights=weights,
bias=config.use_bias, bias=config.use_bias,
) )
def forward(self, hidden_states): self.c_fc = TensorParallelMultiAdapterLinear.load(
hidden_states = self.c_fc(hidden_states) c_fc,
layer_id=index,
layer_names=[f"{prefix}.c_fc"],
sizes=[config.intermediate_size, config.intermediate_size],
process_group=weights.process_group,
)
self.c_proj = TensorParallelAdapterRowLinear.load(
c_proj,
index,
"c_proj",
process_group=weights.process_group,
)
def forward(self, hidden_states, adapter_data):
hidden_states = self.c_fc(hidden_states, adapter_data)
hidden_states = self.act(hidden_states) hidden_states = self.act(hidden_states)
return self.c_proj(hidden_states) return self.c_proj(hidden_states, adapter_data)
class Starcoder2GatedMLP(nn.Module): class Starcoder2GatedMLP(nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, index, prefix, config, weights):
super().__init__() super().__init__()
act = config.hidden_act act = config.hidden_act
self.act = ( self.act = (
@ -319,27 +362,47 @@ class Starcoder2GatedMLP(nn.Module):
) )
) )
# Fuse gate and up proj # Fuse gate and up proj
self.gate_up_proj = TensorParallelColumnLinear.load_multi( prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"]
sizes = [
config.intermediate_size,
config.intermediate_size,
]
gate_up_proj = TensorParallelColumnLinear.load_multi(
config, config,
-            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+            prefixes=prefixes,
             weights=weights,
             dim=0,
             bias=config.use_bias,
         )
-        self.down_proj = TensorParallelRowLinear.load(
+        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
+            gate_up_proj,
+            index,
+            layer_names=prefixes,
+            sizes=sizes,
+            process_group=weights.process_group,
+        )
+        down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",
             weights=weights,
             bias=config.use_bias,
         )
+        self.down_proj = TensorParallelAdapterRowLinear.load(
+            down_proj,
+            index,
+            "down_proj",
+            process_group=weights.process_group,
+        )
         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
         )
 
-    def forward(self, hidden_states):
-        gate_up_states = self.gate_up_proj(hidden_states)
+    def forward(self, hidden_states, adapter_data):
+        gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
         gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+        return self.down_proj(
+            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
+        )
 
 STARCODER2_NORMALIZATION_CLASSES = {
@@ -358,11 +421,11 @@ class Starcoder2Layer(nn.Module):
         super().__init__()
         prefix = f"model.layers.{layer_id}"
         self.self_attn = Starcoder2Attention(
-            prefix=f"{prefix}.self_attn", config=config, weights=weights
+            prefix=f"{prefix}.self_attn", config=config, weights=weights, index=layer_id
         )
         self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](
-            prefix=f"{prefix}.mlp", config=config, weights=weights
+            prefix=f"{prefix}.mlp", config=config, weights=weights, index=layer_id
         )
         self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
@@ -389,6 +452,7 @@ class Starcoder2Layer(nn.Module):
         seqlen,
         max_s,
         prefill_cache_indices,
+        adapter_data,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -404,6 +468,7 @@ class Starcoder2Layer(nn.Module):
             seqlen,
             max_s,
             prefill_cache_indices,
+            adapter_data,
         )
 
         # faster post attention rms norm
@@ -411,7 +476,7 @@ class Starcoder2Layer(nn.Module):
             attn_output, res
         )
 
-        mlp_output = self.mlp(normed_attn_res_output)
+        mlp_output = self.mlp(normed_attn_res_output, adapter_data)
         return mlp_output, attn_res
@@ -458,6 +523,7 @@ class Starcoder2Model(torch.nn.Module):
         max_s: int,
         true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
+        adapter_data,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
@@ -481,6 +547,7 @@ class Starcoder2Model(torch.nn.Module):
                 seqlen,
                 max_s,
                 prefill_cache_indices,
+                adapter_data,
             )
 
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -552,6 +619,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
             max_s,
             true_max_s,
             prefill_cache_indices,
+            adapter_data,
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
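
The hunks above wrap each tensor-parallel linear in an adapter-aware layer and thread `adapter_data` through every forward call. A minimal sketch of that wrapping pattern, assuming a hypothetical `AdapterLinear` class (the real TGI classes are `TensorParallelMultiAdapterLinear` and `TensorParallelAdapterRowLinear`, whose internals differ):

import torch
import torch.nn as nn

class AdapterLinear(nn.Module):
    """Illustrative wrapper: keep the base linear, optionally add a per-request delta."""

    def __init__(self, base: nn.Module, layer_name: str):
        super().__init__()
        self.base = base
        self.layer_name = layer_name

    def forward(self, x: torch.Tensor, adapter_data=None) -> torch.Tensor:
        out = self.base(x)
        if adapter_data is not None:
            # A real implementation would apply LoRA deltas selected by adapter_data here.
            out = out + adapter_data.get(self.layer_name, 0.0)
        return out

up_proj = AdapterLinear(nn.Linear(16, 32, bias=False), "up_proj")
y = up_proj(torch.randn(4, 16), adapter_data={"up_proj": 0.0})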


@@ -1595,7 +1595,9 @@ class FlashCausalLM(Model):
         if max_total_tokens is None:
             if get_support_chunking():
                 model_max_length = self.tokenizer.model_max_length
-                max_position_embeddings = self.config.max_position_embeddings
+                max_position_embeddings = getattr(
+                    self.config, "max_position_embeddings", model_max_length
+                )
                 max_total_tokens = min(
                     num_blocks * BLOCK_SIZE, model_max_length, max_position_embeddings
                 )
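
The `getattr` fallback above matters for configs that do not define `max_position_embeddings`; a minimal sketch of the behaviour, using a throwaway config object (illustrative names only):

class DummyConfig:
    pass  # deliberately has no max_position_embeddings

config = DummyConfig()
model_max_length = 8192
num_blocks, BLOCK_SIZE = 16, 256

# Falls back to model_max_length instead of raising AttributeError.
max_position_embeddings = getattr(config, "max_position_embeddings", model_max_length)
max_total_tokens = min(num_blocks * BLOCK_SIZE, model_max_length, max_position_embeddings)
assert max_total_tokens == 4096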


@@ -14,26 +14,33 @@ PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {
 }
 PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "1").lower() in {"1", "true"}
 log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
-_expected = {"paged", "flashdecoding", "flashinfer"}
+_expected = {"paged", "flashdecoding", "flashdecoding-ipex", "flashinfer"}
 assert (
     ATTENTION in _expected
 ), f"Attention is not valid {ATTENTION}, expected {_expected}"
 log_master(logger.info, f"Using Attention = {ATTENTION}")
-if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
+if PREFIX_CACHING and ATTENTION not in {
+    "flashinfer",
+    "flashdecoding",
+    "flashdecoding-ipex",
+}:
     raise RuntimeError("Prefix caching is only supported with flashinfer")
 
 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
-TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.95"))
+TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.93"))
 assert TGI_WIGGLE_ROOM > 0
 assert TGI_WIGGLE_ROOM < 1
 
 # This is overridden by the cli
 BLOCK_SIZE: int
 if ATTENTION == "flashdecoding":
     BLOCK_SIZE = 256
 elif ATTENTION == "flashinfer":
     BLOCK_SIZE = 1
+elif ATTENTION == "flashdecoding-ipex":
+    BLOCK_SIZE = 64
 else:
     BLOCK_SIZE = 16
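
For readability, the block-size selection above can be restated as a lookup table; `block_size_for` is a hypothetical helper, not part of the module:

def block_size_for(attention: str) -> int:
    return {
        "flashdecoding": 256,
        "flashdecoding-ipex": 64,
        "flashinfer": 1,
    }.get(attention, 16)  # "paged" and anything else keep the default of 16

assert block_size_for("flashdecoding-ipex") == 64
assert block_size_for("paged") == 16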


@@ -79,10 +79,13 @@ class Model(ABC):
                 "Prefill chunking will be turned off",
             )
             support_chunking = False
-        if ATTENTION not in ["flashinfer", "flashdecoding"] and support_chunking:
+        if (
+            ATTENTION not in ["flashinfer", "flashdecoding", "flashdecoding-ipex"]
+            and support_chunking
+        ):
             log_master(
                 logger.warning,
-                "Prefill chunking is only supported with `flashinfer` or `flashdecoding` attention types.",
+                "Prefill chunking is only supported with `flashinfer` or `flashdecoding` or `flashdecoding-ipex` attention types.",
             )
             support_chunking = False
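
The guard above can be summarised as a small predicate; `supports_chunking` is illustrative only and mirrors the condition, not the full method:

CHUNKING_ATTENTION = {"flashinfer", "flashdecoding", "flashdecoding-ipex"}

def supports_chunking(attention: str, requested: bool) -> bool:
    # Chunked prefill stays enabled only for the attention backends listed above.
    return requested and attention in CHUNKING_ATTENTION

assert supports_chunking("flashdecoding-ipex", True) is True
assert supports_chunking("paged", True) is False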


@@ -0,0 +1,270 @@
import math
from typing import List, Optional

import torch
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM
import transformers.modeling_utils

from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.utils import initialize_torch_distributed
from text_generation_server.layers.attention import paged_attention, attention, Seqlen
from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
from text_generation_server.models.globals import ATTENTION

tracer = trace.get_tracer(__name__)


def tgi_flash_attention_forward(
    module,
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor],  # This is a positional arg in Transformers
    kv_cache: List[KVCache],
    kv_head_mapping: torch.Tensor,
    slots: torch.Tensor,
    cu_seqlen_prefill: Optional[torch.Tensor],
    seqlen: Seqlen,
    block_tables: torch.Tensor,
    max_s: int,
    kv_scales: KVScales,
    softmax_scale: Optional[float] = None,
    sliding_window: Optional[int] = None,
    softcap: Optional[float] = None,
    **kwargs,  # This is needed to "absorb" other args passed by Transformers modeling
):
    kv_cache = kv_cache[module.layer_idx]

    query_states = query_states.transpose(1, 2).squeeze(dim=0)
    key_states = key_states.transpose(1, 2).squeeze(dim=0)
    value_states = value_states.transpose(1, 2).squeeze(dim=0)

    # Take care of updating the cache in-place
    kv_cache.store(key=key_states, value=value_states, slots=slots, kv_scales=kv_scales)

    _, num_heads, head_dim = query_states.shape
    softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
    sliding_window = -1 if sliding_window is None else sliding_window

    if cu_seqlen_prefill is not None:
        attn_output = attention(
            query=query_states,
            key=key_states,
            value=value_states,
            kv_cache=kv_cache,
            kv_scales=kv_scales,
            seqlen=seqlen,
            block_tables=block_tables,
            softmax_scale=softmax_scale,
            window_size_left=sliding_window,
            softcap=softcap,
        )
    else:
        attn_output = paged_attention(
            query_states,
            kv_cache,
            kv_head_mapping,
            softmax_scale,
            block_tables,
            seqlen,
            max_s,
            kv_scales=kv_scales,
            softcap=softcap,
        )

    attn_output = attn_output.view(-1, num_heads * head_dim)

    return attn_output, None


transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward


class TransformersFlashCausalLM(FlashCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        default_dtype=torch.float16,
        trust_remote_code: bool = False,
        tokenizer_class=AutoTokenizer,
        kv_cache_dtype: Optional[torch.dtype] = None,
    ):
        self.quantize = quantize
        self.process_group, rank, world_size = initialize_torch_distributed()

        if speculator:
            raise RuntimeError("Speculator decoding is not enabled for AutoModel")

        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = default_dtype if dtype is None else dtype
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            device = torch.device("xpu")
            dtype = default_dtype if dtype is None else dtype
        else:
            raise ValueError(
                "Flash `Transformers` modeling backend is not available on cpu."
            )

        tokenizer = tokenizer_class.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=trust_remote_code,
            attn_implementation="tgi",
            device_map=device if world_size == 1 else None,
            tp_plan="auto" if world_size > 1 else None,
        )

        if tokenizer.pad_token_id is None:
            if model.config.pad_token_id is not None:
                tokenizer.pad_token_id = model.config.pad_token_id
            elif model.config.eos_token_id is not None and isinstance(
                model.config.eos_token_id, int
            ):
                tokenizer.pad_token_id = model.config.eos_token_id
            elif tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            else:
                tokenizer.add_special_tokens({"pad_token": "[PAD]"})

        self.num_layers = model.config.num_hidden_layers
        self.num_heads = model.config.num_attention_heads // self.process_group.size()
        self.num_kv_heads = model.config.num_key_value_heads
        self.num_kv_heads = (
            self.num_kv_heads // self.process_group.size()
            if self.num_kv_heads > 1
            else self.num_kv_heads
        )
        self.head_size = model.config.hidden_size // model.config.num_attention_heads

        self.cuda_graphs = {}
        self.kv_cache = []
        self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype

        if ATTENTION == "flashinfer":
            from text_generation_server.layers.attention.flashinfer import (
                create_prefill_state,
                create_decode_state,
                create_prefill_with_paged_kv_state,
            )

            self.prefill_state = create_prefill_state(device=device)
            self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state(
                device=device
            )
            self.decode_state = create_decode_state(
                device=device,
                num_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
            )

        self.num_groups = self.num_heads // self.num_kv_heads

        # Those will never change and will be used in the forwards
        self.kv_head_mapping = torch.arange(
            0, self.num_kv_heads, dtype=torch.int32, device=device
        ).repeat_interleave(self.num_groups)
        # This means no scale
        self.kv_scales = KVScales(
            torch.tensor(1.0, device=device),
            torch.tensor(1.0, device=device),
        )

        torch.distributed.barrier(group=self.process_group)
        # Skip FlashCausalLM init.
        super(FlashCausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

        # Monkey patch of `self.model.forward` to match `FlashCausalLM`. It avoids duplicating a lot of code
        # We first copy the original model.forward because we still need it in the monkey patch
        self.model.original_forward = self.model.forward
        self.model.forward = self._model_forward

    @classmethod
    def fallback(
        cls,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        return cls(
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    def _model_forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[KVCache],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
        lm_head_indices: Optional[torch.Tensor],
        prefill_cache_indices=None,  # not used, but passed to match original signature
        adapter_data=None,  # not supported, but passed to match original signature
    ):
        hidden_states = self.model.model.forward(
            input_ids=input_ids.unsqueeze(0),  # expand dim to fit Transformers
            position_ids=position_ids.unsqueeze(0),  # expand dim to fit Transformers
            past_key_values=None,  # we use self.kv_cache instead of transformers cache object
            use_cache=False,  # we use self.kv_cache instead of transformers cache object
            return_dict=True,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            seqlen=seqlen,
            max_s=max_s,
            kv_head_mapping=self.kv_head_mapping,
            kv_scales=self.kv_scales,
        )[0].squeeze(dim=0)

        # And compute logits from the lm_head, slicing correctly the indices
        # NOTE: some logits post-processing (e.g. in gemma2) may be absent here with the split of the modules
        # To update with full Transformers support asap
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits = self.model.lm_head(hidden_states)

        # For Granite while next transformers version is released and we can use `lm_head_indices` natively
        if hasattr(self.model.config, "logits_scaling"):
            logits = logits / self.model.config.logits_scaling
        # For Cohere for similar reasons
        elif hasattr(self.model, "logit_scale"):
            logits = logits * self.model.logit_scale

        return logits, None
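
The file above registers a custom attention function under the name "tgi" and selects it through `attn_implementation` when loading the model. A hedged, standalone sketch of that registration mechanism (whether a given architecture routes through this registry depends on the installed Transformers version; the model id and function are illustrative):

import torch
import transformers.modeling_utils
from transformers import AutoModelForCausalLM

def toy_attention(module, query, key, value, attention_mask, **kwargs):
    # Toy stand-in: plain SDPA, no KV-cache handling, mask ignored.
    out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
    return out.transpose(1, 2).contiguous(), None

transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["toy"] = toy_attention

model = AutoModelForCausalLM.from_pretrained(
    "gpt2",  # illustrative model id
    attn_implementation="toy",
)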


@@ -68,9 +68,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         self.quantize = model.quantize
         self.server_urls = server_urls
         # For some reason, inference_mode does not work well with GLOO which we use on CPU
-        if model.device.type == "cuda":
-            # Force inference mode for the lifetime of TextGenerationService
-            self._inference_mode_raii_guard = torch._C._InferenceMode(True)
+        # if model.device.type == "cuda":
+        #     # Force inference mode for the lifetime of TextGenerationService
+        #     self._inference_mode_raii_guard = torch._C._InferenceMode(True)
 
     async def Info(self, request, context):
         return self.model.info


@@ -281,6 +281,12 @@ def get_mlp_weights(i, layer):
     if hasattr(mlp, "up_proj"):
         weights[(i, "up_proj")] = (f"model.layers.{i}.mlp.up_proj", mlp.up_proj)
 
+    if hasattr(mlp, "c_fc"):
+        weights[(i, "c_fc")] = (f"model.layers.{i}.mlp.c_fc", mlp.c_fc)
+
+    if hasattr(mlp, "c_proj"):
+        weights[(i, "c_proj")] = (f"model.layers.{i}.mlp.c_proj", mlp.c_proj)
+
     if hasattr(mlp, "down_proj"):
         weights[(i, "down_proj")] = (
             f"model.layers.{i}.mlp.down_proj",
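
The two new `hasattr` branches above pick up GPT-style MLPs (`c_fc`/`c_proj`, as used by Starcoder2's default MLP) in addition to the Llama-style projections. A small illustration of the collection pattern with a toy module (not the TGI code):

import torch.nn as nn

class GPTStyleMLP(nn.Module):
    def __init__(self, hidden: int = 8, inner: int = 32):
        super().__init__()
        self.c_fc = nn.Linear(hidden, inner)
        self.c_proj = nn.Linear(inner, hidden)

weights, i, mlp = {}, 0, GPTStyleMLP()
for name in ("gate_proj", "up_proj", "c_fc", "c_proj", "down_proj"):
    if hasattr(mlp, name):
        weights[(i, name)] = (f"model.layers.{i}.mlp.{name}", getattr(mlp, name))

assert {k[1] for k in weights} == {"c_fc", "c_proj"}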


@@ -81,12 +81,14 @@ def initialize_torch_distributed():
                     pg_options=options,
                 )
             else:
+                device = torch.device(f"cuda:{RANK}")
                 torch.distributed.init_process_group(
                     backend=backend,
                     world_size=WORLD_SIZE,
                     rank=RANK,
                     timeout=timedelta(seconds=120),
                     pg_options=options,
+                    device_id=device,
                 )
         else:
             logger.warning("torch.distributed is already initialized.")
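
The added `device_id` argument binds the process group to a specific accelerator at initialisation time, which newer PyTorch releases use for backend optimisations such as eager NCCL communicator setup. A hedged sketch outside of TGI (requires a PyTorch version whose `init_process_group` accepts `device_id`):

import os
from datetime import timedelta
import torch
import torch.distributed as dist

rank = int(os.environ.get("RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))

device = torch.device(f"cuda:{rank}")
dist.init_process_group(
    backend="nccl",
    rank=rank,
    world_size=world_size,
    timeout=timedelta(seconds=120),
    device_id=device,
)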


@@ -5,13 +5,12 @@ import torch
 from typing import List, Optional, DefaultDict
 from loguru import logger
-from typing import Dict, Union
+from typing import Dict
 from text_generation_server.pb.generate_pb2 import GrammarType
 
 from outlines.fsm.guide import RegexGuide
 from transformers import (
-    LogitsWarper,
     LogitsProcessor,
     PreTrainedTokenizerBase,
     TemperatureLogitsWarper,
@@ -219,7 +218,7 @@ class HeterogeneousTemperatureLogitsWarper:
         return None
 
 
-class HeterogeneousTopPLogitsWarper(LogitsWarper):
+class HeterogeneousTopPLogitsWarper(LogitsProcessor):
     """
     [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
 
     This version allows for a separate value for each sample and runs inplace when possible.
@@ -278,7 +277,7 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
         return None
 
 
-class HeterogeneousTopKLogitsWarper(LogitsWarper):
+class HeterogeneousTopKLogitsWarper(LogitsProcessor):
     r"""
     [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
 
     This version allows for a separate value for each sample and runs inplace when possible.
@@ -359,7 +358,7 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
         return None
 
 
-class HeterogeneousTypicalLogitsWarper(LogitsWarper):
+class HeterogeneousTypicalLogitsWarper(LogitsProcessor):
     r"""
     [`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language
     Generation](https://arxiv.org/abs/2202.00666) for more information.
@@ -453,13 +452,13 @@ class HeterogeneousProcessorWrapper(LogitsProcessor):
     r"""
     A wrapper for logit warpers or processors without heterogeneous parameter support.
 
     Args:
-        processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`):
+        processors (`Dict[int, LogitsProcessor]`):
             A mapping of sample indices to logit warpers or processors, to be run sequentially.
     """
 
     def __init__(
         self,
-        processors: Dict[int, Union[LogitsProcessor, LogitsWarper]],
+        processors: Dict[int, LogitsProcessor],
     ):
         self.processors = processors
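
The class changes above track the deprecation of `LogitsWarper` in recent Transformers releases: warpers now simply subclass `LogitsProcessor`, whose `__call__(input_ids, scores)` contract is identical. An illustrative (non-TGI) top-p warper written against the new base class:

import torch
from transformers import LogitsProcessor

class TopPWarper(LogitsProcessor):
    def __init__(self, top_p: float):
        self.top_p = top_p

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        # Standard top-p filtering: mask tokens in the low-probability tail.
        sorted_logits, sorted_idx = torch.sort(scores, descending=False)
        cum_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        remove = cum_probs <= (1 - self.top_p)
        remove = remove.scatter(1, sorted_idx, remove)
        return scores.masked_fill(remove, float("-inf"))

warped = TopPWarper(0.9)(None, torch.randn(2, 10))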


@@ -1,7 +1,7 @@
 import json
 import os
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, List
 
 from huggingface_hub import hf_hub_download
 from text_generation_server.layers.marlin.gptq import can_use_gptq_marlin
@@ -20,6 +20,7 @@ class _QuantizerConfig:
     groupsize: int
     quant_method: str
     sym: bool
+    weight_block_size: Optional[List[int]]
 
 
 @dataclass
@@ -49,16 +50,17 @@ def _get_quantizer_config(model_id, revision):
     checkpoint_format = None
     sym = False
     desc_act = False
+    weight_block_size = None
 
     filename = "config.json"
     try:
         data = _get_config_json(model_id, revision, filename)
         # FP8 config
         if data["quantization_config"]["quant_method"] == "fbgemm_fp8":
             return _FP8QuantizerConfig(
                 activation_scale_ub=data["quantization_config"]["activation_scale_ub"]
             )
+        weight_block_size = data["quantization_config"].get("weight_block_size", None)
 
         if "zero_point" in data["quantization_config"]:
             sym = not data["quantization_config"]["zero_point"]
@@ -107,6 +109,7 @@ def _get_quantizer_config(model_id, revision):
         checkpoint_format=checkpoint_format,
         sym=sym,
         desc_act=desc_act,
+        weight_block_size=weight_block_size,
     )
@@ -196,9 +199,14 @@ def get_loader(
         # Since the default for the quantize config is _QuantizerConfig,
         # we need to add this check to not get an attribute error
         activation_scale_ub = None
+        weight_block_size = quantizer_config.weight_block_size
         if isinstance(quantizer_config, _FP8QuantizerConfig):
             activation_scale_ub = quantizer_config.activation_scale_ub
-        return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8")
+        return HybridFP8UnquantLoader(
+            activation_scale_ub,
+            to_fp8=quantize == "fp8",
+            weight_block_size=weight_block_size,
+        )
     else:
         raise ValueError(f"Unknown quantization method: {quantize}")
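
`weight_block_size` is read straight from the checkpoint's `config.json`, so block-wise FP8 checkpoints can advertise their block shape to the loader. A small sketch of the parsing step (the JSON snippet is illustrative; [128, 128] is a commonly used block shape):

import json

config_text = """
{
  "quantization_config": {
    "quant_method": "fp8",
    "weight_block_size": [128, 128]
  }
}
"""
data = json.loads(config_text)
weight_block_size = data["quantization_config"].get("weight_block_size", None)
print(weight_block_size)  # [128, 128]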

server/uv.lock (3268 lines): diff suppressed because it is too large.