Merge remote-tracking branch 'origin/main' into fix-issue-2864

This commit is contained in:
Nicolas Casademont 2025-02-10 11:15:04 +01:00
commit ab92e153e1
No known key found for this signature in database
GPG Key ID: 6DFD8231DE0D1AC9
122 changed files with 9537 additions and 5510 deletions

View File

@ -1,75 +0,0 @@
ARG CUDA_ARCH_LIST="75-real;80-real;86-real;89-real;90-real"
ARG OMPI_VERSION="4.1.7rc1"
# Build dependencies resolver stage
FROM lukemathwalker/cargo-chef:latest AS chef
WORKDIR /usr/src/text-generation-inference/backends/trtllm
FROM chef AS planner
COPY . .
RUN cargo chef prepare --recipe-path recipe.json
# CUDA dependent dependencies resolver stage
FROM nvidia/cuda:12.6.3-cudnn-devel-ubuntu24.04 AS cuda-builder
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
apt update && apt install -y \
build-essential \
cmake \
curl \
gcc-14 \
g++-14 \
git \
git-lfs \
libssl-dev \
libucx-dev \
ninja-build \
pkg-config \
pipx \
python3 \
python3-dev \
python3-setuptools \
tar \
wget && \
pipx ensurepath
ENV TGI_INSTALL_PREFIX=/usr/local/tgi
ENV TENSORRT_INSTALL_PREFIX=/usr/local/tensorrt
# Install OpenMPI
FROM cuda-builder AS mpi-builder
ARG OMPI_VERSION
ENV OMPI_TARBALL_FILENAME="openmpi-$OMPI_VERSION.tar.bz2"
RUN wget "https://download.open-mpi.org/release/open-mpi/v4.1/$OMPI_TARBALL_FILENAME" -P /opt/src && \
mkdir /usr/src/mpi && \
tar -xf "/opt/src/$OMPI_TARBALL_FILENAME" -C /usr/src/mpi --strip-components=1 && \
cd /usr/src/mpi && \
./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --with-slurm && \
make -j all && \
make install && \
rm -rf "/opt/src/$OMPI_TARBALL_FILENAME"
# Install TensorRT
FROM cuda-builder AS trt-builder
COPY backends/trtllm/scripts/install_tensorrt.sh /opt/install_tensorrt.sh
RUN chmod +x /opt/install_tensorrt.sh && \
/opt/install_tensorrt.sh
# Build Backend
FROM cuda-builder AS tgi-builder
WORKDIR /usr/src/text-generation-inference
# Install Rust
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y && \
chmod -R a+w /root/.rustup && \
chmod -R a+w /root/.cargo
ENV PATH="/root/.cargo/bin:$PATH"
RUN cargo install cargo-chef
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
ENV MPI_HOME=/usr/local/mpi

View File

@ -1,19 +0,0 @@
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
// README at: https://github.com/devcontainers/templates/tree/main/src/cpp
{
"name": "CUDA",
"build": {
"dockerfile": "Dockerfile_trtllm",
"context": ".."
},
"remoteEnv": {
"PATH": "${containerEnv:PATH}:/usr/local/cuda/bin",
"LD_LIBRARY_PATH": "$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64",
"XLA_FLAGS": "--xla_gpu_cuda_data_dir=/usr/local/cuda"
},
"customizations" : {
"jetbrains" : {
"backend" : "CLion"
}
}
}

View File

@ -31,16 +31,28 @@ jobs:
group: ${{ github.workflow }}-build-and-push-image-${{ inputs.hardware }}-${{ github.head_ref || github.run_id }} group: ${{ github.workflow }}-build-and-push-image-${{ inputs.hardware }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true cancel-in-progress: true
runs-on: runs-on:
group: aws-highmemory-32-plus-priv group: aws-highmemory-64-plus-priv
permissions: permissions:
contents: write contents: write
packages: write packages: write
id-token: write
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Inject slug/short variables - name: Inject slug/short variables
uses: rlespinasse/github-slug-action@v4.4.1 uses: rlespinasse/github-slug-action@v4.4.1
- name: Construct harware variables - name: Inject required variables for sccache to interact with Github Actions Cache
uses: actions/github-script@v7
with:
script: |
core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || '');
core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || '');
- name: Extract TensorRT-LLM version
run: |
echo "TENSORRT_LLM_VERSION=$(grep -oP '([a-z,0-9]{40})' $GITHUB_WORKSPACE/backends/trtllm/cmake/trtllm.cmake)" >> $GITHUB_ENV
echo "TensorRT-LLM version: ${{ env.TENSORRT_LLM_VERSION }}"
- name: Construct hardware variables
shell: bash shell: bash
run: | run: |
case ${{ inputs.hardware }} in case ${{ inputs.hardware }} in
@ -52,6 +64,7 @@ jobs:
export runs_on="aws-g6-12xl-plus-priv-cache" export runs_on="aws-g6-12xl-plus-priv-cache"
export platform="" export platform=""
export extra_pytest="" export extra_pytest=""
export target=""
;; ;;
cuda-trtllm) cuda-trtllm)
export dockerfile="Dockerfile_trtllm" export dockerfile="Dockerfile_trtllm"
@ -61,15 +74,24 @@ jobs:
export runs_on="ubuntu-latest" export runs_on="ubuntu-latest"
export platform="" export platform=""
export extra_pytest="" export extra_pytest=""
if [[ "${GITHUB_REF}" == refs/tags/* ]]; then
export build_type="release";
export target="";
else
export build_type="dev";
export target="ci-runtime";
fi
;; ;;
rocm) rocm)
export dockerfile="Dockerfile_amd" export dockerfile="Dockerfile_amd"
export label_extension="-rocm" export label_extension="-rocm"
export docker_devices="/dev/kfd,/dev/dri" export docker_devices="/dev/kfd,/dev/dri"
export docker_volume="/mnt" export docker_volume="/mnt"
export runs_on="amd-gpu-runners" # This runner was deactivated.
export runs_on="ubuntu-latest"
export platform="" export platform=""
export extra_pytest="-k test_flash_gemma_gptq_load" export extra_pytest="-k test_flash_gemma_gptq_load"
export target=""
;; ;;
intel-xpu) intel-xpu)
export dockerfile="Dockerfile_intel" export dockerfile="Dockerfile_intel"
@ -79,6 +101,7 @@ jobs:
export runs_on="ubuntu-latest" export runs_on="ubuntu-latest"
export platform="xpu" export platform="xpu"
export extra_pytest="" export extra_pytest=""
export target=""
;; ;;
intel-cpu) intel-cpu)
export dockerfile="Dockerfile_intel" export dockerfile="Dockerfile_intel"
@ -89,6 +112,7 @@ jobs:
export runs_on="aws-highmemory-32-plus-priv" export runs_on="aws-highmemory-32-plus-priv"
export platform="cpu" export platform="cpu"
export extra_pytest="-k test_flash_gemma_simple" export extra_pytest="-k test_flash_gemma_simple"
export target=""
;; ;;
esac esac
echo $dockerfile echo $dockerfile
@ -105,6 +129,8 @@ jobs:
echo "RUNS_ON=${runs_on}" >> $GITHUB_ENV echo "RUNS_ON=${runs_on}" >> $GITHUB_ENV
echo "EXTRA_PYTEST=${extra_pytest}" >> $GITHUB_ENV echo "EXTRA_PYTEST=${extra_pytest}" >> $GITHUB_ENV
echo REGISTRY_MIRROR=$REGISTRY_MIRROR >> $GITHUB_ENV echo REGISTRY_MIRROR=$REGISTRY_MIRROR >> $GITHUB_ENV
echo "TARGET=${target}" >> $GITHUB_ENV
echo "BUILD_TYPE=${build_type}" >> $GITHUB_ENV
- name: Initialize Docker Buildx - name: Initialize Docker Buildx
uses: docker/setup-buildx-action@v3 uses: docker/setup-buildx-action@v3
with: with:
@ -169,9 +195,14 @@ jobs:
GIT_SHA=${{ env.GITHUB_SHA }} GIT_SHA=${{ env.GITHUB_SHA }}
DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }} DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }}
PLATFORM=${{ env.PLATFORM }} PLATFORM=${{ env.PLATFORM }}
build_type=${{ env.BUILD_TYPE }}
sccache_gha_enabled=on
actions_cache_url=${{ env.ACTIONS_CACHE_URL }}
actions_runtime_token=${{ env.ACTIONS_RUNTIME_TOKEN }}
target: ${{ env.TARGET }}
tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }} tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }} labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
cache-from: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min cache-from: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=max,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=max
cache-to: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min cache-to: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min
- name: Final - name: Final
id: final id: final
@ -214,3 +245,23 @@ jobs:
echo $DOCKER_IMAGE echo $DOCKER_IMAGE
docker pull $DOCKER_IMAGE docker pull $DOCKER_IMAGE
pytest -s -vv integration-tests ${PYTEST_FLAGS} ${EXTRA_PYTEST} pytest -s -vv integration-tests ${PYTEST_FLAGS} ${EXTRA_PYTEST}
backend_trtllm_cxx_tests:
needs: build-and-push
if: needs.build-and-push.outputs.label == '-trtllm'
concurrency:
group: ${{ github.workflow }}-${{ github.job }}-trtllm-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
runs-on:
group: aws-g6-12xl-plus-priv-cache
container:
image: ${{ needs.build-and-push.outputs.docker_image }}
credentials:
username: ${{ secrets.REGISTRY_USERNAME }}
password: ${{ secrets.REGISTRY_PASSWORD }}
options: --gpus all --shm-size=8g
steps:
- name: Run C++/CUDA tests
if: ${{ env.LABEL == 'ci-runtime' }}
run: /usr/local/tgi/bin/tgi_trtllm_backend_tests

View File

@ -42,6 +42,7 @@ jobs:
permissions: permissions:
contents: write contents: write
packages: write packages: write
id-token: write
with: with:
hardware: ${{ matrix.hardware }} hardware: ${{ matrix.hardware }}
# https://github.com/actions/runner/issues/2206 # https://github.com/actions/runner/issues/2206

View File

@ -31,7 +31,7 @@ jobs:
with: with:
# Released on: 02 May, 2024 # Released on: 02 May, 2024
# https://releases.rs/docs/1.78.0/ # https://releases.rs/docs/1.78.0/
toolchain: 1.80.0 toolchain: 1.84.0
override: true override: true
components: rustfmt, clippy components: rustfmt, clippy
- name: Install Protoc - name: Install Protoc
@ -44,10 +44,14 @@ jobs:
run: | run: |
sudo apt update sudo apt update
sudo apt install python3.11-dev -y sudo apt install python3.11-dev -y
pip install -U pip uv
uv venv
source ./.venv/bin/activate
make install-cpu make install-cpu
- name: Run server tests - name: Run server tests
run: | run: |
pip install pytest source ./.venv/bin/activate
uv pip install pytest
export HF_TOKEN=${{ secrets.HF_TOKEN }} export HF_TOKEN=${{ secrets.HF_TOKEN }}
pytest -s -vv server/tests pytest -s -vv server/tests
- name: Pre-commit checks - name: Pre-commit checks

View File

@ -16,3 +16,5 @@ jobs:
fetch-depth: 0 fetch-depth: 0
- name: Secret Scanning - name: Secret Scanning
uses: trufflesecurity/trufflehog@main uses: trufflesecurity/trufflehog@main
with:
extra_args: --results=verified,unknown

88
Cargo.lock generated
View File

@ -1,6 +1,6 @@
# This file is automatically @generated by Cargo. # This file is automatically @generated by Cargo.
# It is not intended for manual editing. # It is not intended for manual editing.
version = 3 version = 4
[[package]] [[package]]
name = "addr2line" name = "addr2line"
@ -52,6 +52,21 @@ version = "0.2.20"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "45862d1c77f2228b9e10bc609d5bc203d86ebc9b87ad8d5d5167a6c9abf739d9" checksum = "45862d1c77f2228b9e10bc609d5bc203d86ebc9b87ad8d5d5167a6c9abf739d9"
[[package]]
name = "android-tzdata"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0"
[[package]]
name = "android_system_properties"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311"
dependencies = [
"libc",
]
[[package]] [[package]]
name = "anstream" name = "anstream"
version = "0.6.18" version = "0.6.18"
@ -651,6 +666,20 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
[[package]]
name = "chrono"
version = "0.4.39"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7e36cc9d416881d2e24f9a963be5fb1cd90966419ac844274161d10488b3e825"
dependencies = [
"android-tzdata",
"iana-time-zone",
"js-sys",
"num-traits",
"wasm-bindgen",
"windows-targets 0.52.6",
]
[[package]] [[package]]
name = "clang-sys" name = "clang-sys"
version = "1.8.1" version = "1.8.1"
@ -1544,7 +1573,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
dependencies = [ dependencies = [
"ahash", "ahash",
"allocator-api2",
] ]
[[package]] [[package]]
@ -1802,6 +1830,29 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "iana-time-zone"
version = "0.1.61"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220"
dependencies = [
"android_system_properties",
"core-foundation-sys",
"iana-time-zone-haiku",
"js-sys",
"wasm-bindgen",
"windows-core",
]
[[package]]
name = "iana-time-zone-haiku"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f"
dependencies = [
"cc",
]
[[package]] [[package]]
name = "icu_collections" name = "icu_collections"
version = "1.5.0" version = "1.5.0"
@ -2187,9 +2238,9 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8"
[[package]] [[package]]
name = "libc" name = "libc"
version = "0.2.164" version = "0.2.169"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f" checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a"
[[package]] [[package]]
name = "libfuzzer-sys" name = "libfuzzer-sys"
@ -4424,14 +4475,14 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-backends-trtllm" name = "text-generation-backends-trtllm"
version = "3.0.2-dev0" version = "3.1.1-dev0"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"clap 4.5.21", "clap 4.5.21",
"cmake", "cmake",
"cxx", "cxx",
"cxx-build", "cxx-build",
"hashbrown 0.14.5", "hashbrown 0.15.1",
"hf-hub", "hf-hub",
"pkg-config", "pkg-config",
"pyo3", "pyo3",
@ -4445,7 +4496,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-benchmark" name = "text-generation-benchmark"
version = "3.0.2-dev0" version = "3.1.1-dev0"
dependencies = [ dependencies = [
"average", "average",
"clap 4.5.21", "clap 4.5.21",
@ -4465,7 +4516,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-client" name = "text-generation-client"
version = "3.0.2-dev0" version = "3.1.1-dev0"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"base64 0.22.1", "base64 0.22.1",
@ -4483,7 +4534,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-launcher" name = "text-generation-launcher"
version = "3.0.2-dev0" version = "3.1.1-dev0"
dependencies = [ dependencies = [
"clap 4.5.21", "clap 4.5.21",
"ctrlc", "ctrlc",
@ -4504,7 +4555,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-router" name = "text-generation-router"
version = "3.0.2-dev0" version = "3.1.1-dev0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"async-stream", "async-stream",
@ -4512,6 +4563,7 @@ dependencies = [
"axum 0.7.9", "axum 0.7.9",
"axum-tracing-opentelemetry", "axum-tracing-opentelemetry",
"base64 0.22.1", "base64 0.22.1",
"chrono",
"clap 4.5.21", "clap 4.5.21",
"csv", "csv",
"futures", "futures",
@ -4555,7 +4607,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-router-v2" name = "text-generation-router-v2"
version = "3.0.2-dev0" version = "3.1.1-dev0"
dependencies = [ dependencies = [
"async-stream", "async-stream",
"async-trait", "async-trait",
@ -4604,7 +4656,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-router-v3" name = "text-generation-router-v3"
version = "3.0.2-dev0" version = "3.1.1-dev0"
dependencies = [ dependencies = [
"async-stream", "async-stream",
"async-trait", "async-trait",
@ -4791,9 +4843,9 @@ dependencies = [
[[package]] [[package]]
name = "tokio" name = "tokio"
version = "1.41.1" version = "1.43.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22cfb5bee7a6a52939ca9224d6ac897bb669134078daa8735560897f69de4d33" checksum = "3d61fa4ffa3de412bfea335c6ecff681de2b609ba3c77ef3e00e521813a9ed9e"
dependencies = [ dependencies = [
"backtrace", "backtrace",
"bytes", "bytes",
@ -4819,9 +4871,9 @@ dependencies = [
[[package]] [[package]]
name = "tokio-macros" name = "tokio-macros"
version = "2.4.0" version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -4862,9 +4914,9 @@ dependencies = [
[[package]] [[package]]
name = "tokio-stream" name = "tokio-stream"
version = "0.1.16" version = "0.1.17"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4f4e6ce100d0eb49a2734f8c0812bcd324cf357d21810932c5df6b96ef2b86f1" checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047"
dependencies = [ dependencies = [
"futures-core", "futures-core",
"pin-project-lite", "pin-project-lite",

View File

@ -20,7 +20,7 @@ default-members = [
resolver = "2" resolver = "2"
[workspace.package] [workspace.package]
version = "3.0.2-dev0" version = "3.1.1-dev0"
edition = "2021" edition = "2021"
authors = ["Olivier Dehaene"] authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference" homepage = "https://github.com/huggingface/text-generation-inference"

View File

@ -1,5 +1,5 @@
# Rust builder # Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef FROM lukemathwalker/cargo-chef:latest-rust-1.84.0 AS chef
WORKDIR /usr/src WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@ -47,7 +47,7 @@ RUN cargo build --profile release-opt --frozen
FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS pytorch-install FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS pytorch-install
# NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099 # NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099
ARG PYTORCH_VERSION=2.4.0 ARG PYTORCH_VERSION=2.5.1
ARG PYTHON_VERSION=3.11 ARG PYTHON_VERSION=3.11
# Keep in sync with `server/pyproject.toml # Keep in sync with `server/pyproject.toml
@ -58,7 +58,7 @@ ARG INSTALL_CHANNEL=pytorch
# Automatically set by buildx # Automatically set by buildx
ARG TARGETPLATFORM ARG TARGETPLATFORM
ENV PATH /opt/conda/bin:$PATH ENV PATH=/opt/conda/bin:$PATH
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \ build-essential \
@ -195,48 +195,56 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
git \ git \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# Install flash-attention dependencies
# RUN pip install einops --no-cache-dir
# Copy conda with PyTorch installed # Copy conda with PyTorch installed
COPY --from=pytorch-install /opt/conda /opt/conda COPY --from=pytorch-install /opt/conda /opt/conda
# Copy build artifacts from flash attention builder
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from flash attention v2 builder
COPY --from=flash-att-v2-builder /opt/conda/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from exllama kernels builder
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from exllamav2 kernels builder
COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from awq kernels builder
COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from eetq kernels builder
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from lorax punica kernels builder
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from mamba builder
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-311/ /opt/conda/lib/python3.11/site-packages
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-311/ /opt/conda/lib/python3.11/site-packages
COPY --from=flashinfer-builder /opt/conda/lib/python3.11/site-packages/flashinfer/ /opt/conda/lib/python3.11/site-packages/flashinfer/
# Install flash-attention dependencies
RUN pip install einops --no-cache-dir
# Install server # Install server
COPY proto proto COPY proto proto
COPY server server COPY server server
COPY server/Makefile server/Makefile COPY server/Makefile server/Makefile
ENV UV_SYSTEM_PYTHON=1
RUN cd server && \ RUN cd server && \
make gen-server && \ pip install -U pip uv && \
pip install -r requirements_cuda.txt && \ uv sync --frozen --extra gen --extra attention --extra bnb --extra accelerate --extra compressed-tensors --extra marlin --extra moe --extra quantize --extra peft --extra outlines --no-install-project && \
pip install ".[attention, bnb, accelerate, compressed-tensors, marlin, moe, quantize, peft, outlines]" --no-cache-dir && \ . ./.venv/bin/activate && \
pip install nvidia-nccl-cu12==2.22.3 make gen-server-raw
ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2 RUN cd server && \
uv sync --frozen --extra gen --extra attention --extra bnb --extra accelerate --extra compressed-tensors --extra marlin --extra moe --extra quantize --extra peft --extra outlines && \
. ./.venv/bin/activate && \
pwd && \
text-generation-server --help
# Copy build artifacts from flash attention builder
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-311 /usr/src/server/.venv/lib/python3.11/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-311 /usr/src/server/.venv/lib/python3.11/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-311 /usr/src/server/.venv/lib/python3.11/site-packages
# Copy build artifacts from flash attention v2 builder
COPY --from=flash-att-v2-builder /opt/conda/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so /usr/src/server/.venv/lib/python3.11/site-packages
# Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /usr/src/server/.venv/lib/python3.11/site-packages
# Copy build artifacts from exllama kernels builder
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /usr/src/server/.venv/lib/python3.11/site-packages
# Copy build artifacts from exllamav2 kernels builder
COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-311 /usr/src/server/.venv/lib/python3.11/site-packages
# Copy build artifacts from awq kernels builder
COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-311 /usr/src/server/.venv/lib/python3.11/site-packages
# Copy build artifacts from eetq kernels builder
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-311 /usr/src/server/.venv/lib/python3.11/site-packages
# Copy build artifacts from lorax punica kernels builder
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-311 /usr/src/server/.venv/lib/python3.11/site-packages
# Copy build artifacts from mamba builder
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-311/ /usr/src/server/.venv/lib/python3.11/site-packages
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-311/ /usr/src/server/.venv/lib/python3.11/site-packages
COPY --from=flashinfer-builder /opt/conda/lib/python3.11/site-packages/flashinfer/ /usr/src/server/.venv/lib/python3.11/site-packages/flashinfer/
# ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2
# Required to find libpython within the rust binaries # Required to find libpython within the rust binaries
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/" ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
# This is needed because exl2 tries to load flash-attn # This is needed because exl2 tries to load flash-attn

View File

@ -1,5 +1,5 @@
# Rust builder # Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef FROM lukemathwalker/cargo-chef:latest-rust-1.84.0 AS chef
WORKDIR /usr/src WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@ -104,7 +104,7 @@ RUN case ${TARGETPLATFORM} in \
/opt/conda/bin/conda clean -ya /opt/conda/bin/conda clean -ya
# Install flash-attention, torch dependencies # Install flash-attention, torch dependencies
RUN python3 -m pip install --upgrade pip && pip install numpy einops ninja joblib msgpack cmake --no-cache-dir && rm -rf /var/lib/apt/lists/* RUN python3 -m pip install --upgrade pip uv && pip install numpy einops ninja joblib msgpack cmake --no-cache-dir && rm -rf /var/lib/apt/lists/*
RUN conda install mkl=2021 RUN conda install mkl=2021
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/opt/conda/lib/python3.11/site-packages/torch/lib:/opt/conda/lib/ ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/opt/conda/lib/python3.11/site-packages/torch/lib:/opt/conda/lib/
@ -279,7 +279,7 @@ RUN git clone https://github.com/danieldk/marlin-kernels.git && \
FROM kernel-builder AS moe-kernels FROM kernel-builder AS moe-kernels
WORKDIR /usr/src WORKDIR /usr/src
ENV MOE_KERNELS_BRANCH=a67b35841774b2056a73806c36661134b5054edd ENV MOE_KERNELS_BRANCH=v0.8.2
ENV VLLM_TARGET_DEVICE=rocm ENV VLLM_TARGET_DEVICE=rocm
RUN git clone https://github.com/danieldk/moe-kernels.git && \ RUN git clone https://github.com/danieldk/moe-kernels.git && \
cd moe-kernels && \ cd moe-kernels && \
@ -318,10 +318,18 @@ COPY --from=moe-kernels /usr/src/moe-kernels/build/lib.linux-x86_64-cpython-311
COPY proto proto COPY proto proto
COPY server server COPY server server
COPY server/Makefile server/Makefile COPY server/Makefile server/Makefile
ENV UV_SYSTEM_PYTHON=1
RUN cd server && \ RUN cd server && \
make gen-server && \ pip install -U pip uv && \
pip install -r requirements_rocm.txt && \ uv sync --frozen --extra gen --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines --no-install-project && \
pip install ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir . ./.venv/bin/activate && \
make gen-server-raw
RUN cd server && \
uv sync --frozen --extra gen --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines && \
. ./.venv/bin/activate && \
pwd && \
text-generation-server --help
# Install benchmarker # Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark

View File

@ -1,6 +1,6 @@
ARG PLATFORM=xpu ARG PLATFORM=xpu
FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef FROM lukemathwalker/cargo-chef:latest-rust-1.84.0 AS chef
WORKDIR /usr/src WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@ -108,17 +108,19 @@ RUN pip install triton-xpu==3.0.0b2 --no-cache-dir
COPY proto proto COPY proto proto
COPY server server COPY server server
COPY server/Makefile server/Makefile COPY server/Makefile server/Makefile
ENV UV_SYSTEM_PYTHON=1
RUN cd server && \ RUN cd server && \
make gen-server && \ make gen-server && \
pip install -r requirements_intel.txt && \ pip install -U pip uv && \
pip install ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir uv pip install -e ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/pti/0.9/lib:/opt/conda/lib ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/pti/0.9/lib:/opt/conda/lib
ENV CCL_ZE_IPC_EXCHANGE=sockets ENV CCL_ZE_IPC_EXCHANGE=sockets
#ENV TORCH_LLM_ALLREDUCE=1 #ENV TORCH_LLM_ALLREDUCE=1
#ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0 #ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
ENV TORCH_DEVICE_BACKEND_AUTOLOAD=0
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout 033af6f63745ac748cccdadee5c6140c7971edf6 RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout 1ccf72b2d11cd00b47aef6d6cd054c088aa6f083
RUN cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc,ats-m150' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch RUN cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc,ats-m150' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch
# Install benchmarker # Install benchmarker
@ -188,7 +190,7 @@ RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.5.0.dev2024
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.20.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.20.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
RUN pip install triton py-libnuma RUN pip install triton==3.1.0 py-libnuma
WORKDIR /usr/src WORKDIR /usr/src
@ -211,10 +213,11 @@ ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
COPY proto proto COPY proto proto
COPY server server COPY server server
COPY server/Makefile server/Makefile COPY server/Makefile server/Makefile
ENV UV_SYSTEM_PYTHON=1
RUN cd server && \ RUN cd server && \
make gen-server && \ make gen-server && \
pip install -r requirements_intel.txt && \ pip install -U pip uv && \
pip install ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir uv pip install -e ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir
# Install benchmarker # Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
@ -224,9 +227,9 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
FROM ${PLATFORM} AS final FROM ${PLATFORM} AS final
ENV ATTENTION=paged ENV ATTENTION=flashdecoding-ipex
ENV PREFIX_CACHING=0 ENV PREFIX_CACHING=1
ENV PREFILL_CHUNKING=0 ENV PREFILL_CHUNKING=1
ENV CUDA_GRAPHS=0 ENV CUDA_GRAPHS=0
ENTRYPOINT ["text-generation-launcher"] ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"] CMD ["--json-output"]

View File

@ -1,20 +1,16 @@
ARG CUDA_ARCH_LIST="75-real;80-real;86-real;89-real;90-real" ARG cuda_arch_list="75-real;80-real;86-real;89-real;90-real;100-real;120-real"
ARG OMPI_VERSION="4.1.7rc1" ARG cuda_base=12.8.0
ARG build_type=release
ARG ompi_version=4.1.7
ARG sccache_gha_enabled=off
ARG actions_cache_url=""
ARG actions_runtime_token=""
# Build dependencies resolver stage
FROM lukemathwalker/cargo-chef:latest AS chef
WORKDIR /usr/src/text-generation-inference/backends/trtllm
FROM chef AS planner
COPY . .
RUN cargo chef prepare --recipe-path recipe.json
# CUDA dependent dependencies resolver stage # CUDA dependent dependencies resolver stage
FROM nvidia/cuda:12.6.3-cudnn-devel-ubuntu24.04 AS cuda-builder FROM nvidia/cuda:${cuda_base}-cudnn-devel-ubuntu24.04 AS cuda-builder
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
apt update && apt install -y \
build-essential \ build-essential \
cmake \ cmake \
curl \ curl \
@ -22,8 +18,11 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
g++-14 \ g++-14 \
git \ git \
git-lfs \ git-lfs \
lld \
libssl-dev \ libssl-dev \
libucx-dev \ libucx-dev \
libasan8 \
libubsan1 \
ninja-build \ ninja-build \
pkg-config \ pkg-config \
pipx \ pipx \
@ -31,7 +30,7 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
python3-dev \ python3-dev \
python3-setuptools \ python3-setuptools \
tar \ tar \
wget && \ wget --no-install-recommends && \
pipx ensurepath pipx ensurepath
ENV TGI_INSTALL_PREFIX=/usr/local/tgi ENV TGI_INSTALL_PREFIX=/usr/local/tgi
@ -39,17 +38,19 @@ ENV TENSORRT_INSTALL_PREFIX=/usr/local/tensorrt
# Install OpenMPI # Install OpenMPI
FROM cuda-builder AS mpi-builder FROM cuda-builder AS mpi-builder
ARG OMPI_VERSION WORKDIR /opt/src/mpi
ENV OMPI_TARBALL_FILENAME="openmpi-$OMPI_VERSION.tar.bz2" ARG ompi_version
RUN wget "https://download.open-mpi.org/release/open-mpi/v4.1/$OMPI_TARBALL_FILENAME" -P /opt/src && \ ENV OMPI_VERSION=${ompi_version}
mkdir /usr/src/mpi && \ ENV OMPI_TARBALL_FILENAME=openmpi-${OMPI_VERSION}.tar.bz2
tar -xf "/opt/src/$OMPI_TARBALL_FILENAME" -C /usr/src/mpi --strip-components=1 && \ ADD --checksum=sha256:54a33cb7ad81ff0976f15a6cc8003c3922f0f3d8ceed14e1813ef3603f22cd34 \
cd /usr/src/mpi && \ https://download.open-mpi.org/release/open-mpi/v4.1/${OMPI_TARBALL_FILENAME} .
RUN tar --strip-components=1 -xf ${OMPI_TARBALL_FILENAME} &&\
./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --with-slurm && \ ./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --with-slurm && \
make -j all && \ make -j all && \
make install && \ make install && \
rm -rf "/opt/src/$OMPI_TARBALL_FILENAME" rm -rf ${OMPI_TARBALL_FILENAME}/..
# Install TensorRT # Install TensorRT
FROM cuda-builder AS trt-builder FROM cuda-builder AS trt-builder
@ -61,32 +62,54 @@ RUN chmod +x /opt/install_tensorrt.sh && \
FROM cuda-builder AS tgi-builder FROM cuda-builder AS tgi-builder
WORKDIR /usr/src/text-generation-inference WORKDIR /usr/src/text-generation-inference
# Scoped global args reuse
ARG cuda_arch_list
ARG build_type
ARG sccache_gha_enabled
ARG actions_cache_url
ARG actions_runtime_token
# Install Rust # Install Rust
ENV PATH="/root/.cargo/bin:$PATH"
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y && \ RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y && \
chmod -R a+w /root/.rustup && \ chmod -R a+w /root/.rustup && \
chmod -R a+w /root/.cargo chmod -R a+w /root/.cargo && \
cargo install sccache --locked
ENV PATH="/root/.cargo/bin:$PATH"
RUN cargo install cargo-chef
# Cache dependencies
COPY --from=planner /usr/src/text-generation-inference/backends/trtllm/recipe.json .
RUN cargo chef cook --release --recipe-path recipe.json
# Build actual TGI
ARG CUDA_ARCH_LIST
ENV CMAKE_PREFIX_PATH="/usr/local/mpi:/usr/local/tensorrt:$CMAKE_PREFIX_PATH"
ENV LD_LIBRARY_PATH="/usr/local/mpi/lib:$LD_LIBRARY_PATH" ENV LD_LIBRARY_PATH="/usr/local/mpi/lib:$LD_LIBRARY_PATH"
ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig:$PKG_CONFIG_PATH" ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig"
ENV CMAKE_PREFIX_PATH="/usr/local/mpi:/usr/local/tensorrt"
COPY . . ENV USE_LLD_LINKER=ON
ENV CUDA_ARCH_LIST=${cuda_arch_list}
# SCCACHE Specifics args - before finding a better, more generic, way...
ENV SCCACHE_GHA_ENABLED=${sccache_gha_enabled}
ENV ACTIONS_CACHE_URL=${actions_cache_url}
ENV ACTIONS_RUNTIME_TOKEN=${actions_runtime_token}
COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY router router
COPY backends backends
COPY benchmark benchmark
COPY launcher launcher
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
RUN mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$TGI_INSTALL_PREFIX/lib" && \
cd backends/trtllm && \
CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release
FROM nvidia/cuda:12.6.3-cudnn-runtime-ubuntu24.04 AS runtime ENV RUSTC_WRAPPER=sccache
ENV CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX
RUN export CC=gcc-14 \
export CXX=g++-14 \
export CMAKE_C_COMPILER_LAUNCHER=sccache && \
export CMAKE_CXX_COMPILER_LAUNCHER=sccache && \
export CMAKE_CUDA_COMPILER_LAUNCHER=sccache && \
mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$TGI_INSTALL_PREFIX/lib" && \
cargo build --profile ${build_type} --package text-generation-backends-trtllm --bin text-generation-backends-trtllm && \
sccache --show-stats
FROM nvidia/cuda:${cuda_base}-cudnn-runtime-ubuntu24.04 AS runtime
RUN apt update && apt install -y libucx0 pipx python3-minimal python3-dev python3-pip python3-venv && \ RUN apt update && apt install -y libucx0 pipx python3-minimal python3-dev python3-pip python3-venv && \
rm -rf /var/lib/{apt,dpkg,cache,log}/ && \ rm -rf /var/lib/{apt,dpkg,cache,log}/ && \
pipx ensurepath && \ pipx ensurepath && \
@ -104,10 +127,33 @@ COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
COPY --from=tgi-builder /usr/local/tgi /usr/local/tgi COPY --from=tgi-builder /usr/local/tgi /usr/local/tgi
COPY --from=tgi-builder /usr/src/text-generation-inference/target/release/text-generation-backends-trtllm /usr/local/tgi/bin/text-generation-launcher COPY --from=tgi-builder /usr/src/text-generation-inference/target/release/text-generation-backends-trtllm /usr/local/tgi/bin/text-generation-launcher
# This is used only for the CI/CD
FROM nvidia/cuda:${cuda_base}-cudnn-runtime-ubuntu24.04 AS ci-runtime
RUN apt update && apt install -y libasan8 libubsan1 libucx0 pipx python3-minimal python3-dev python3-pip python3-venv && \
rm -rf /var/lib/{apt,dpkg,cache,log}/ && \
pipx ensurepath && \
pipx install --include-deps transformers tokenizers
WORKDIR /usr/local/tgi/bin
ENV PATH=/root/.local/share/pipx/venvs/transformers/bin/:$PATH
ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/mpi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
ENV TOKENIZERS_PARALLELISM=false
ENV OMPI_MCA_plm_rsh_agent=""
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
COPY --from=tgi-builder /usr/local/tgi /usr/local/tgi
# Basically we copy from target/debug instead of target/release
COPY --from=tgi-builder /usr/src/text-generation-inference/target/debug/text-generation-backends-trtllm /usr/local/tgi/bin/text-generation-launcher
# This is the final image
FROM runtime FROM runtime
LABEL co.huggingface.vendor="Hugging Face Inc." LABEL co.huggingface.vendor="Hugging Face Inc."
LABEL org.opencontainers.image.authors="hardware@hf.co" LABEL org.opencontainers.image.authors="hardware@hf.co"
LABEL org.opencontainers.title="Text-Generation-Inference TensorRT-LLM Backend"
ENTRYPOINT ["./text-generation-launcher"] ENTRYPOINT ["./text-generation-launcher"]
CMD ["--executor-worker", "/usr/local/tgi/bin/executorWorker"] CMD ["--executor-worker", "/usr/local/tgi/bin/executorWorker"]

View File

@ -84,7 +84,7 @@ model=HuggingFaceH4/zephyr-7b-beta
volume=$PWD/data volume=$PWD/data
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
3.0.0 ghcr.io/huggingface/text-generation-inference:3.0.0 --model-id $model ghcr.io/huggingface/text-generation-inference:3.1.0 --model-id $model
``` ```
And then you can make requests like And then you can make requests like
@ -121,7 +121,7 @@ curl localhost:8080/v1/chat/completions \
**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar. **Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.0.0-rocm --model-id $model` instead of the command above. **Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.1.0-rocm --model-id $model` instead of the command above.
To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli): To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
``` ```
@ -151,7 +151,8 @@ model=meta-llama/Meta-Llama-3.1-8B-Instruct
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
token=<your cli READ token> token=<your cli READ token>
docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.0.0 --model-id $model docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.1.0 --model-id $model
``` ```
### A note on Shared Memory (shm) ### A note on Shared Memory (shm)

View File

@ -8,7 +8,7 @@ use tracing_opentelemetry::OpenTelemetrySpanExt;
/// Inject context in the metadata of a gRPC request. /// Inject context in the metadata of a gRPC request.
struct MetadataInjector<'a>(pub &'a mut tonic::metadata::MetadataMap); struct MetadataInjector<'a>(pub &'a mut tonic::metadata::MetadataMap);
impl<'a> Injector for MetadataInjector<'a> { impl Injector for MetadataInjector<'_> {
/// Set a key and value in the MetadataMap. Does nothing if the key or value are not valid inputs /// Set a key and value in the MetadataMap. Does nothing if the key or value are not valid inputs
fn set(&mut self, key: &str, value: String) { fn set(&mut self, key: &str, value: String) {
if let Ok(key) = tonic::metadata::MetadataKey::from_bytes(key.as_bytes()) { if let Ok(key) = tonic::metadata::MetadataKey::from_bytes(key.as_bytes()) {

View File

@ -1,13 +1,5 @@
cmake_minimum_required(VERSION 3.20) cmake_minimum_required(VERSION 3.20)
if (NOT DEFINED CMAKE_CXX_COMPILER_LAUNCHER AND CMAKE_BUILD_TYPE STREQUAL "Debug")
find_program(CCACHE_EXECUTABLE "ccache")
if (CCACHE_EXECUTABLE)
message(STATUS "Using ccache")
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_EXECUTABLE}" CACHE PATH "Path to ccache" FORCE)
endif ()
endif ()
if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0") if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0")
cmake_policy(SET CMP0135 NEW) cmake_policy(SET CMP0135 NEW)
endif () endif ()
@ -21,6 +13,7 @@ include(CheckCXXCompilerFlag)
option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF) option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF)
option(TGI_TRTLLM_BACKEND_BUILD_EXAMPLES "Enable building the examples suite" OFF) option(TGI_TRTLLM_BACKEND_BUILD_EXAMPLES "Enable building the examples suite" OFF)
option(TGI_TRTLLM_BACKEND_BUILD_USE_LLD "Enable lld linker instead of ld" OFF)
set(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST "89-real" CACHE STRING "List of CUDA architectures to support") set(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST "89-real" CACHE STRING "List of CUDA architectures to support")
set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING "Path where TensorRT libraries and headers are located") set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING "Path where TensorRT libraries and headers are located")
set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include" CACHE STRING "Path where TensorRT headers are located") set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include" CACHE STRING "Path where TensorRT headers are located")
@ -28,21 +21,23 @@ set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE ST
# We are using nvidia-ml to query at runtime device information to enable some architecture-specific features # We are using nvidia-ml to query at runtime device information to enable some architecture-specific features
find_package(CUDAToolkit 12.6 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml) find_package(CUDAToolkit 12.6 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml)
find_package(MPI REQUIRED)
#### External dependencies #### #### External dependencies ####
include(cmake/json.cmake) include(cmake/json.cmake)
include(cmake/spdlog.cmake) include(cmake/spdlog.cmake)
include(cmake/trtllm.cmake) include(cmake/trtllm.cmake)
if(${CMAKE_BUILD_TYPE} STREQUAL "Debug") if (CMAKE_BUILD_TYPE STREQUAL "Debug")
set(TGI_TRTLLM_BACKEND_DEBUG ON)
add_compile_definitions(TGI_TRTLLM_BACKEND_DEBUG=1) add_compile_definitions(TGI_TRTLLM_BACKEND_DEBUG=1)
endif() add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_TRACE)
endif ()
# This attempt to detect if the compiler can emit warning if it can't apply return value optimization from a function if (${TGI_TRTLLM_BACKEND_BUILD_USE_LLD})
check_cxx_compiler_flag("-Wnrvo" COMPILER_SUPPORT_WARNING_ON_NVRO) message(STATUS "Using lld linker")
if(${COMPILER_SUPPORT_WARNING_ON_NVRO}) add_link_options("-fuse-ld=lld")
set(CMAKE_CXX_FLAGS "{CMAKE_CXX_FLAGS} -Wnvro") endif ()
endif()
# Let's build TRTLLM as part of CMake # Let's build TRTLLM as part of CMake
add_subdirectory("${trtllm_SOURCE_DIR}/cpp" "${trtllm_SOURCE_DIR}/..") add_subdirectory("${trtllm_SOURCE_DIR}/cpp" "${trtllm_SOURCE_DIR}/..")
@ -55,51 +50,71 @@ add_library(tgi_trtllm_backend_impl STATIC csrc/hardware.hpp csrc/backend.hpp cs
include_directories(${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR}) include_directories(${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR})
target_include_directories(tgi_trtllm_backend_impl PRIVATE target_include_directories(tgi_trtllm_backend_impl PRIVATE
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/csrc> $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/csrc>
# $<INSTALL_INTERFACE:csrc> # $<INSTALL_INTERFACE:csrc>
) )
target_include_directories(tgi_trtllm_backend_impl PUBLIC "${trtllm_SOURCE_DIR}/cpp/include") target_include_directories(tgi_trtllm_backend_impl PUBLIC "${trtllm_SOURCE_DIR}/cpp/include")
target_link_libraries(tgi_trtllm_backend_impl PRIVATE CUDA::cudart CUDA::nvml) target_link_libraries(tgi_trtllm_backend_impl PRIVATE CUDA::cudart CUDA::nvml)
target_link_libraries(tgi_trtllm_backend_impl PUBLIC nlohmann_json::nlohmann_json spdlog::spdlog) target_link_libraries(tgi_trtllm_backend_impl PUBLIC nlohmann_json::nlohmann_json spdlog::spdlog)
target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper)
if(${CMAKE_BUILD_TYPE} STREQUAL "Debug")
target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm)
else()
target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapperm)
endif ()
# This install all the artifacts in CMAKE_INSTALL_PREFIX under include/ lib/ bin/ to make easy to link / find it back # This install all the artifacts in CMAKE_INSTALL_PREFIX under include/ lib/ bin/ to make easy to link / find it back
install(TARGETS tgi_trtllm_backend_impl tensorrt_llm nvinfer_plugin_tensorrt_llm decoder_attention executorWorker) install(TARGETS tgi_trtllm_backend_impl)
install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB) #install(TARGETS cutlass_src fb_gemm_src fpA_intB_gemm_src gemm_swiglu_sm90_src kernels_src)
install(TARGETS decoder_attention_0 decoder_attention_1)
install(TARGETS tensorrt_llm nvinfer_plugin_tensorrt_llm decoder_attention_src executorWorker)
install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} TYPE LIB)
if (NOT ${TGI_TRTLLM_BACKEND_DEBUG})
install(FILES ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB)
endif ()
#### Unit Tests #### #### Unit Tests ####
if (${TGI_TRTLLM_BACKEND_BUILD_TESTS}) if (${TGI_TRTLLM_BACKEND_BUILD_TESTS} AND CMAKE_BUILD_TYPE MATCHES "Debug")
message(STATUS "Building tests") message(STATUS "Building tests")
option(TGI_TRTLLM_BACKEND_ENABLE_ASAN "Enable AddressSanitizer")
option(TGI_TRTLLM_BACKEND_ENABLE_UBSAN "Enable UndefinedSanitizer")
FetchContent_Declare( FetchContent_Declare(
Catch2 Catch2
URL https://github.com/catchorg/Catch2/archive/refs/tags/v3.7.1.tar.gz URL https://github.com/catchorg/Catch2/archive/refs/tags/v3.7.1.tar.gz
) )
FetchContent_MakeAvailable(Catch2) FetchContent_MakeAvailable(Catch2)
# This attempt to detect if the compiler can emit warning if it can't apply return value optimization from a function
check_cxx_compiler_flag("-Wnrvo" COMPILER_SUPPORT_WARNING_ON_NVRO)
if (${COMPILER_SUPPORT_WARNING_ON_NVRO})
message(STATUS "Enabling non-NVRO detection")
target_compile_options(tgi_trtllm_backend_impl PRIVATE -Wnrvo)
endif ()
target_compile_options(tgi_trtllm_backend_impl PRIVATE -Wall)
cmake_path(GET TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH PARENT_PATH TRTLLM_NVRTC_WRAPPER_PARENT_LIBRARY_PATH)
message(STATUS "Adding linking path: ${TRTLLM_NVRTC_WRAPPER_PARENT_LIBRARY_PATH}")
add_executable(tgi_trtllm_backend_tests tests/test_hardware.cpp tests/test_backend.cpp) add_executable(tgi_trtllm_backend_tests tests/test_hardware.cpp tests/test_backend.cpp)
# target_compile_options(tgi_trtllm_backend_tests PRIVATE -Werror)
target_link_directories(tgi_trtllm_backend_tests PRIVATE "${TRTLLM_NVRTC_WRAPPER_PARENT_LIBRARY_PATH}")
target_include_directories(tgi_trtllm_backend_tests PUBLIC "${trtllm_SOURCE_DIR}/cpp/include") target_include_directories(tgi_trtllm_backend_tests PUBLIC "${trtllm_SOURCE_DIR}/cpp/include")
target_include_directories(tgi_trtllm_backend_tests PUBLIC "csrc/") target_include_directories(tgi_trtllm_backend_tests PUBLIC "csrc/")
target_link_libraries(tgi_trtllm_backend_tests PRIVATE ${TRTLLM_LIBS} CUDA::cudart CUDA::nvml) target_link_libraries(tgi_trtllm_backend_tests PRIVATE ${TRTLLM_LIBS} CUDA::cudart CUDA::nvml)
target_link_libraries(tgi_trtllm_backend_tests PUBLIC Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog tgi_trtllm_backend_impl) target_link_libraries(tgi_trtllm_backend_tests PUBLIC Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog tgi_trtllm_backend_impl)
target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper)
if(${CMAKE_BUILD_TYPE} STREQUAL "Debug") if (${TGI_TRTLLM_BACKEND_ENABLE_ASAN})
target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm) message(STATUS "Enabled AddressSanitizer")
else() target_link_options(tgi_trtllm_backend_tests BEFORE PUBLIC -fsanitize=address)
target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapperm)
endif () endif ()
if(CMAKE_BUILD_TYPE MATCHES "Debug") if (${TGI_TRTLLM_BACKEND_ENABLE_UBSAN})
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Werror -fsanitize=undefined -fsanitize=address") message(STATUS "Enabled UndefinedSanitizer")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -fsanitize=undefined -fsanitize=address") target_link_options(tgi_trtllm_backend_tests BEFORE PUBLIC -fsanitize=undefined)
target_link_options(tgi_trtllm_backend_tests BEFORE PUBLIC -fsanitize=undefined PUBLIC -fsanitize=address) endif ()
endif()
list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras) install(TARGETS tgi_trtllm_backend_tests)
include(CTest)
include(Catch) # list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras)
catch_discover_tests(tgi_trtllm_backend_tests) # include(CTest)
# include(Catch)
# catch_discover_tests(tgi_trtllm_backend_tests)
endif () endif ()

View File

@ -7,20 +7,16 @@ homepage.workspace = true
[dependencies] [dependencies]
async-trait = "0.1" async-trait = "0.1"
#async-stream = "0.3"
clap = { version = "4.5", features = ["derive"] } clap = { version = "4.5", features = ["derive"] }
cxx = "1.0" cxx = "1.0"
hashbrown = "0.14" hashbrown = "0.15"
hf-hub = { workspace = true } hf-hub = { workspace = true }
#log = { version = "0.4", features = [] }
text-generation-router = { path = "../../router" } text-generation-router = { path = "../../router" }
tokenizers = { workspace = true } tokenizers = { workspace = true }
tokio = { version = "1.39", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tokio = { version = "1.43.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.15" tokio-stream = "0.1.17"
thiserror = "1.0.63" thiserror = "1.0.63"
tracing = "0.1" tracing = "0.1"
#tracing-opentelemetry = "0.25"
#tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
pyo3 = { workspace = true } pyo3 = { workspace = true }
[build-dependencies] [build-dependencies]

View File

@ -3,24 +3,34 @@ use pkg_config;
use std::env; use std::env;
use std::env::consts::ARCH; use std::env::consts::ARCH;
use std::path::{absolute, PathBuf}; use std::path::{absolute, PathBuf};
use std::sync::LazyLock;
const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 1] = ["spdlog"]; const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 1] = ["spdlog"];
const CUDA_ARCH_LIST: Option<&str> = option_env!("CUDA_ARCH_LIST"); const CUDA_ARCH_LIST: Option<&str> = option_env!("CUDA_ARCH_LIST");
const CUDA_REQUIRED_VERSION: &str = "12.6"; const CUDA_REQUIRED_VERSION: &str = "12.8";
const MPI_REQUIRED_VERSION: &str = "4.1"; const MPI_REQUIRED_VERSION: &str = "4.1";
const INSTALL_PREFIX: Option<&str> = option_env!("CMAKE_INSTALL_PREFIX"); const INSTALL_PREFIX: Option<&str> = option_env!("CMAKE_INSTALL_PREFIX");
const TENSORRT_ROOT_DIR: Option<&str> = option_env!("TENSORRT_ROOT_DIR"); const TENSORRT_ROOT_DIR: Option<&str> = option_env!("TENSORRT_ROOT_DIR");
const NCCL_ROOT_DIR: Option<&str> = option_env!("NCCL_ROOT_DIR"); const NCCL_ROOT_DIR: Option<&str> = option_env!("NCCL_ROOT_DIR");
const IS_GHA_BUILD: LazyLock<bool> = LazyLock::new(|| {
option_env!("SCCACHE_GHA_ENABLED").map_or(false, |value| match value.to_lowercase().as_str() {
"on" => true,
"true" => true,
"1" => true,
_ => false,
})
});
// Dependencies // Dependencies
const BACKEND_DEPS: [&str; 2] = ["tgi_trtllm_backend_impl", "tgi_trtllm_backend"]; const BACKEND_DEPS: &str = "tgi_trtllm_backend_impl";
const CUDA_TRANSITIVE_DEPS: [&str; 4] = ["cuda", "cudart", "cublas", "nvidia-ml"]; const CUDA_TRANSITIVE_DEPS: [&str; 4] = ["cuda", "cudart", "cublas", "nvidia-ml"];
const TENSORRT_LLM_TRANSITIVE_DEPS: [(&str, &str); 5] = [ const TENSORRT_LLM_TRANSITIVE_DEPS: [(&str, &str); 5] = [
("dylib", "tensorrt_llm"), ("dylib", "tensorrt_llm"),
("static", "tensorrt_llm_executor_static"),
("dylib", "tensorrt_llm_nvrtc_wrapper"), ("dylib", "tensorrt_llm_nvrtc_wrapper"),
("dylib", "nvinfer_plugin_tensorrt_llm"), ("dylib", "nvinfer_plugin_tensorrt_llm"),
("dylib", "decoder_attention"), ("dylib", "decoder_attention_0"),
("dylib", "decoder_attention_1"),
]; ];
macro_rules! probe { macro_rules! probe {
@ -32,6 +42,48 @@ macro_rules! probe {
}; };
} }
fn get_compiler_flag(
switch: bool,
true_case: &'static str,
false_case: &'static str,
) -> &'static str {
match switch {
true => true_case,
false => false_case,
}
}
fn get_library_architecture() -> &'static str {
let os = env::var("CARGO_CFG_TARGET_OS").unwrap();
let arch = env::var("CARGO_CFG_TARGET_ARCH").unwrap();
let env = env::var("CARGO_CFG_TARGET_ENV").unwrap();
match os.as_str() {
"linux" => {
if env != "gnu" {
panic!("unsupported linux ABI {env}, only 'gnu' is supported")
}
match arch.as_str() {
"x86_64" => "x86_64-linux-gnu",
"aarch64" => "aarch64-linux-gnu",
_ => panic!("unsupported linux architecture {arch}"),
}
}
"windows" => {
if env != "msvc" {
panic!("unsupported windows ABI {env}, only 'msvc' is supported")
}
match arch.as_str() {
"x86_64" => "x86_64-windows-msvc",
_ => panic!("unsupported windows architecture {arch}"),
}
}
_ => panic!("unsupported OS {os}"),
}
}
fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf, PathBuf) { fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf, PathBuf) {
// Build the backend implementation through CMake // Build the backend implementation through CMake
let install_path = INSTALL_PREFIX.unwrap_or("/usr/local/tgi"); let install_path = INSTALL_PREFIX.unwrap_or("/usr/local/tgi");
@ -54,10 +106,44 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf
.env("OPT_LEVEL", opt_level) .env("OPT_LEVEL", opt_level)
.define("CMAKE_INSTALL_PREFIX", &install_path) .define("CMAKE_INSTALL_PREFIX", &install_path)
.define("CMAKE_CUDA_COMPILER", "/usr/local/cuda/bin/nvcc") .define("CMAKE_CUDA_COMPILER", "/usr/local/cuda/bin/nvcc")
.define("Python3_ROOT_DIR", "../venv") .define("CMAKE_LIBRARY_ARCHITECTURE", get_library_architecture())
.define("TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST", cuda_arch_list) .define("TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST", cuda_arch_list)
.define(
"TGI_TRTLLM_BACKEND_DEBUG",
get_compiler_flag(is_debug, "ON", "OFF"),
)
.define("TGI_TRTLLM_BACKEND_TRT_ROOT", tensorrt_path); .define("TGI_TRTLLM_BACKEND_TRT_ROOT", tensorrt_path);
if is_debug || *IS_GHA_BUILD {
config.define("TGI_TRTLLM_BACKEND_BUILD_TESTS", "ON");
}
if option_env!("USE_LLD_LINKER").is_some() {
println!("cargo:warning=Using lld linker");
config.define("TGI_TRTLLM_BACKEND_BUILD_USE_LLD", "ON");
}
if (is_debug && option_env!("ENABLE_ASAN").is_some()) || *IS_GHA_BUILD {
println!("cargo:warning=Enabling Address Sanitizer");
config.define("TGI_TRTLLM_BACKEND_ENABLE_ASAN", "ON");
}
if (is_debug && option_env!("ENABLE_UBSAN").is_some()) || *IS_GHA_BUILD {
println!("cargo:warning=Enabling Undefined Sanitizer");
config.define("TGI_TRTLLM_BACKEND_ENABLE_UBSAN", "ON");
}
if let Some(nvcc_host_compiler) = option_env!("CMAKE_CUDA_HOST_COMPILER") {
config.define("CMAKE_CUDA_HOST_COMPILER", nvcc_host_compiler);
}
if let Some(wrapper) = option_env!("RUSTC_WRAPPER") {
println!("cargo:warning=Using caching tool: {wrapper}");
config.define("CMAKE_C_COMPILER_LAUNCHER", wrapper);
config.define("CMAKE_CXX_COMPILER_LAUNCHER", wrapper);
config.define("CMAKE_CUDA_COMPILER_LAUNCHER", wrapper);
}
// Allow to override which Python to use ... // Allow to override which Python to use ...
if let Some(python3) = option_env!("Python3_EXECUTABLE") { if let Some(python3) = option_env!("Python3_EXECUTABLE") {
config.define("Python3_EXECUTABLE", python3); config.define("Python3_EXECUTABLE", python3);
@ -78,23 +164,18 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf
} }
// Emit linkage information from the artifacts we just built // Emit linkage information from the artifacts we just built
let install_lib_path = install_path.join("lib"); for path in ["lib", "lib64"] {
let install_lib_path = install_path.join(path);
println!( println!(
r"cargo:warning=Adding link search path: {}", r"cargo:warning=Adding link search path: {}",
install_lib_path.display() install_lib_path.display()
); );
println!(r"cargo:rustc-link-search={}", install_lib_path.display()); println!(r"cargo:rustc-link-search={}", install_lib_path.display());
}
(PathBuf::from(install_path), deps_folder) (PathBuf::from(install_path), deps_folder)
} }
fn build_ffi_layer(deps_folder: &PathBuf, is_debug: bool) { fn build_ffi_layer(deps_folder: &PathBuf, is_debug: bool) {
let ndebug = match is_debug {
true => "1",
false => "0",
};
CFG.include_prefix = "backends/trtllm"; CFG.include_prefix = "backends/trtllm";
cxx_build::bridge("src/lib.rs") cxx_build::bridge("src/lib.rs")
.static_flag(true) .static_flag(true)
@ -106,7 +187,10 @@ fn build_ffi_layer(deps_folder: &PathBuf, is_debug: bool) {
.include("/usr/local/tensorrt/include") .include("/usr/local/tensorrt/include")
.include("csrc/") .include("csrc/")
.file("csrc/ffi.hpp") .file("csrc/ffi.hpp")
.define("TGI_TRTLLM_BACKEND_DEBUG", ndebug) .define(
"TGI_TRTLLM_BACKEND_DEBUG",
get_compiler_flag(is_debug, "ON", "OFF"),
)
.compile("tgi_trtllm_backend"); .compile("tgi_trtllm_backend");
println!("cargo:rerun-if-changed=CMakeLists.txt"); println!("cargo:rerun-if-changed=CMakeLists.txt");
@ -125,6 +209,7 @@ fn main() {
let build_profile = env::var("PROFILE").unwrap(); let build_profile = env::var("PROFILE").unwrap();
let (is_debug, opt_level) = match build_profile.as_ref() { let (is_debug, opt_level) = match build_profile.as_ref() {
"debug" => (true, "0"), "debug" => (true, "0"),
"dev" => (true, "0"),
_ => (false, "3"), _ => (false, "3"),
}; };
@ -161,7 +246,5 @@ fn main() {
}); });
// Backend // Backend
BACKEND_DEPS.iter().for_each(|name| { println!("cargo:rustc-link-lib=static={}", &BACKEND_DEPS);
println!("cargo:rustc-link-lib=static={}", name);
});
} }

View File

@ -4,14 +4,14 @@ set(SPDLOG_FMT_EXTERNAL OFF)
# Define the level at which SPDLOG_ compilation level is defined # Define the level at which SPDLOG_ compilation level is defined
if (${CMAKE_BUILD_TYPE} STREQUAL "Debug") if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG) add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_TRACE)
else () else ()
add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_INFO) add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG)
endif () endif ()
fetchcontent_declare( fetchcontent_declare(
spdlog spdlog
# DOWNLOAD_EXTRACT_TIMESTAMP # DOWNLOAD_EXTRACT_TIMESTAMP
URL https://github.com/gabime/spdlog/archive/refs/tags/v1.14.1.tar.gz URL https://github.com/gabime/spdlog/archive/refs/tags/v1.15.0.tar.gz
) )
fetchcontent_makeavailable(spdlog) fetchcontent_makeavailable(spdlog)

View File

@ -14,19 +14,21 @@ message(STATUS "Building for CUDA Architectures: ${CMAKE_CUDA_ARCHITECTURES}")
set(ENABLE_UCX OFF) set(ENABLE_UCX OFF)
if (${CMAKE_BUILD_TYPE} STREQUAL "Debug") if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
set(FAST_BUILD ON) set(FAST_BUILD ON)
set(NVTX_DISABLE OFF) set(NVTX_DISABLE ON)
set(INDEX_RANGE_CHECK ON)
else () else ()
set(FAST_BUILD OFF) set(FAST_BUILD OFF)
set(FAST_MATH ON) set(FAST_MATH ON)
set(NVTX_DISABLE ON) set(NVTX_DISABLE OFF)
set(INDEX_RANGE_CHECK OFF)
endif () endif ()
find_package(Python3 REQUIRED Interpreter) find_package(Python3 REQUIRED Interpreter)
fetchcontent_declare( fetchcontent_declare(
trtllm trtllm
GIT_REPOSITORY https://github.com/huggingface/TensorRT-LLM.git GIT_REPOSITORY https://github.com/nvidia/TensorRT-LLM.git
GIT_TAG 1bb9ca4688805444f203647674bac1d7219d0579 GIT_TAG v0.17.0
GIT_SHALLOW ON GIT_SHALLOW ON
DOWNLOAD_EXTRACT_TIMESTAMP DOWNLOAD_EXTRACT_TIMESTAMP
) )

View File

@ -1,7 +1,6 @@
#include <ranges> #include <ranges>
#include <nlohmann/json.hpp> #include <nlohmann/json.hpp>
#include <spdlog/spdlog.h>
#include "backend.hpp" #include "backend.hpp"
#include "hardware.hpp" #include "hardware.hpp"
@ -17,7 +16,8 @@ namespace huggingface::tgi::backends::trtllm {
if (world_size > 1) { if (world_size > 1) {
SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode"); SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode");
mode = tle::CommunicationMode::kORCHESTRATOR; mode = tle::CommunicationMode::kORCHESTRATOR;
orchestratorConfig = std::make_optional<tle::OrchestratorConfig>(true, executor_worker_path_, nullptr, true); orchestratorConfig = std::make_optional<tle::OrchestratorConfig>(true, executor_worker_path_, nullptr,
true);
} else { } else {
SPDLOG_INFO("Detected single engine deployment, using leader mode"); SPDLOG_INFO("Detected single engine deployment, using leader mode");
} }
@ -51,14 +51,15 @@ namespace huggingface::tgi::backends::trtllm {
} }
std::expected<request_id_t, backend_error_t> std::expected<request_id_t, backend_error_t>
backend_t::submit(std::span<const token_id_t> token_ids, const generation_params_t generation_params, const sampling_params_t sampling_params) noexcept { backend_t::submit(std::span<const token_id_t> token_ids, const generation_params_t g_params,
SPDLOG_DEBUG("Submitting {:d} tokens to the executor for scheduling ({}, {})", token_ids.size(), generation_params, sampling_params); const sampling_params_t s_params) noexcept {
return executor_.enqueueRequest(tle::Request { SPDLOG_DEBUG("Submit {:d} tokens for scheduling ({}, {})", token_ids.size(), g_params, s_params);
return executor_.enqueueRequest(tle::Request{
{token_ids.begin(), token_ids.end()}, // Making actual copy of the tokens {token_ids.begin(), token_ids.end()}, // Making actual copy of the tokens
static_cast<tle::SizeType32>(generation_params.max_new_tokens), static_cast<tle::SizeType32>(g_params.max_new_tokens),
true, true,
(tle::SamplingConfig) sampling_params, (tle::SamplingConfig) s_params,
tle::OutputConfig { /* returnLogProbs= */ true }, tle::OutputConfig{ /* returnLogProbs= */ true},
std::nullopt, std::nullopt,
std::nullopt, std::nullopt,
std::nullopt, std::nullopt,

View File

@ -28,9 +28,53 @@ namespace huggingface::tgi::backends::trtllm {
#include "backends/trtllm/src/lib.rs.h" #include "backends/trtllm/src/lib.rs.h"
namespace huggingface::tgi::backends::trtllm { namespace huggingface::tgi::backends::trtllm {
std::once_flag backend_initialized_flag; std::once_flag backend_initialized_flag;
constexpr finish_reason_t as_finish_reason_t(const tle::FinishReason reason) noexcept {
switch (reason) {
case tle::FinishReason::kNOT_FINISHED:
return finish_reason_t::kNOT_FINISHED;
case tle::FinishReason::kSTOP_WORDS:
return finish_reason_t::kSTOP_WORDS;
case tle::FinishReason::kEND_ID:
return finish_reason_t::kEND_ID;
case tle::FinishReason::kLENGTH:
return finish_reason_t::kLENGTH;
default:
std::unreachable();
}
}
static auto as_generation_step = [](const tle::Response &r) {
const auto reqId = r.getRequestId();
if (!r.hasError()) [[likely]] {
const auto result = r.getResult();
const auto logits = result.logProbs.value()[0];
return generation_step_t{
reqId,
static_cast<uint32_t>(result.outputTokenIds[0][0]),
logits.back(),
result.isFinal,
as_finish_reason_t(result.finishReasons[0]),
false,
std::string()
};
} else {
return generation_step_t{
reqId,
0,
0.0,
true,
finish_reason_t::kNOT_FINISHED,
true,
std::move(r.getErrorMsg())
};
}
};
class tensorrt_llm_backend_t { class tensorrt_llm_backend_t {
private: private:
backend_t inner_; backend_t inner_;
@ -39,9 +83,7 @@ namespace huggingface::tgi::backends::trtllm {
tensorrt_llm_backend_t(std::filesystem::path &&engine_folder, std::filesystem::path &&executor_worker_path) tensorrt_llm_backend_t(std::filesystem::path &&engine_folder, std::filesystem::path &&executor_worker_path)
: inner_(engine_folder, executor_worker_path) {} : inner_(engine_folder, executor_worker_path) {}
size_t num_tokens_ready() const noexcept { size_t num_tokens_ready() const noexcept { return inner_.num_tokens_ready(); }
return inner_.num_tokens_ready();
}
request_id_t submit( request_id_t submit(
rust::Slice<const uint32_t> tokens, rust::Slice<const uint32_t> tokens,
@ -65,7 +107,7 @@ namespace huggingface::tgi::backends::trtllm {
); );
// If we do have a value, let's return the request_id // If we do have a value, let's return the request_id
if(maybe_request_id.has_value()) [[likely]] { if (maybe_request_id.has_value()) [[likely]] {
return *maybe_request_id; return *maybe_request_id;
} else { } else {
SPDLOG_WARN("[FFI] Failed to submit request to the executor"); SPDLOG_WARN("[FFI] Failed to submit request to the executor");
@ -74,45 +116,29 @@ namespace huggingface::tgi::backends::trtllm {
} }
std::unique_ptr<std::vector<generation_step_t>> pull_tokens() noexcept { std::unique_ptr<std::vector<generation_step_t>> pull_tokens() noexcept {
if(num_tokens_ready() > 0) [[likely]] { if (num_tokens_ready() > 0) [[likely]] {
const auto responses = inner_.pull_tokens(); const auto responses = inner_.pull_tokens();
SPDLOG_TRACE("[FFI] Successfully pulled out {:d} responses from executor", responses.size()); SPDLOG_TRACE("[FFI] Successfully pulled out {:d} responses from executor", responses.size());
// Transform tle::Response to GenerationStep
auto steps = std::make_unique<std::vector<generation_step_t>>(); // Transform tle::Response to generation_step_t
std::ranges::transform(responses.begin(), responses.end(), std::back_inserter(*steps), [](const tle::Response &r) { #ifdef __cpp_lib_ranges_to_container
const auto reqId = r.getRequestId(); auto steps = responses | std::views::transform(as_generation_step) | std::ranges::to<std::vector>();
if (!r.hasError()) [[likely]] { #else
const auto result = r.getResult(); auto steps = std::vector<generation_step_t>();
return generation_step_t{ steps.reserve(responses.size());
reqId, std::transform(responses.begin(), responses.end(), std::back_inserter(steps), as_generation_step);
static_cast<uint32_t>(result.outputTokenIds[0][0]), #endif
result.logProbs.value()[0][0], return std::make_unique<std::vector<generation_step_t>>(steps);
result.isFinal,
false,
std::string()
};
} else {
return generation_step_t{
reqId,
0,
0.0,
true,
true,
std::move(r.getErrorMsg())
};
}
});
return steps;
} else { } else {
return std::make_unique<std::vector<generation_step_t>>(); return std::make_unique<std::vector<generation_step_t>>();
} }
} }
void cancel(request_id_t requestId) noexcept { void cancel(request_id_t request_id) noexcept {
SPDLOG_DEBUG("[FFI] cancelling request {:d}", requestId); SPDLOG_DEBUG("[FFI] cancelling request {:d}", request_id);
inner_.cancel(requestId); inner_.cancel(request_id);
} }
}; };
@ -151,11 +177,14 @@ namespace huggingface::tgi::backends::trtllm {
} }
} }
std::unique_ptr<tensorrt_llm_backend_t> create_backend_from_engine_folder(const rust::Str engines_folder, const rust::Str executor_worker_path) { std::unique_ptr<tensorrt_llm_backend_t>
create_backend_from_engine_folder(const rust::Str engines_folder, const rust::Str executor_worker_path) {
std::call_once(backend_initialized_flag, initialize_tensorrt_llm_backend); std::call_once(backend_initialized_flag, initialize_tensorrt_llm_backend);
return std::make_unique<tensorrt_llm_backend_t>( return std::make_unique<tensorrt_llm_backend_t>(
std::filesystem::path(std::string_view(engines_folder.begin(), engines_folder.end()), std::filesystem::path::format::auto_format), std::filesystem::path(std::string_view(engines_folder.begin(), engines_folder.end()),
std::filesystem::path(std::string_view(executor_worker_path.begin(), executor_worker_path.end()), std::filesystem::path::format::auto_format) std::filesystem::path::format::auto_format),
std::filesystem::path(std::string_view(executor_worker_path.begin(), executor_worker_path.end()),
std::filesystem::path::format::auto_format)
); );
} }
} }

View File

@ -2,13 +2,13 @@
set -ex set -ex
TRT_VER_BASE="10.6.0" TRT_VER_BASE="10.8.0"
TRT_VER_FULL="${TRT_VER_BASE}.26" TRT_VER_FULL="${TRT_VER_BASE}.43"
CUDA_VER="12.6" CUDA_VER="12.8"
CUDNN_VER="9.5.0.50-1" CUDNN_VER="9.7.0.66-1"
NCCL_VER="2.22.3-1+cuda12.6" NCCL_VER="2.25.1-1+cuda${CUDA_VER}"
CUBLAS_VER="12.6.3.3-1" CUBLAS_VER="${CUDA_VER}.3.14-1"
NVRTC_VER="12.6.77-1" NVRTC_VER="${CUDA_VER}.61-1"
for i in "$@"; do for i in "$@"; do
case $i in case $i in
@ -73,7 +73,7 @@ install_centos_requirements() {
install_tensorrt() { install_tensorrt() {
#PY_VERSION=$(python3 -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))') #PY_VERSION=$(python3 -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))')
#PARSED_PY_VERSION=$(echo "${PY_VERSION//./}") #PARSED_PY_VERSION=$(echo "${PY_VERSION//./}")
TRT_CUDA_VERSION="12.6" TRT_CUDA_VERSION="12.8"
if [ -z "$RELEASE_URL_TRT" ];then if [ -z "$RELEASE_URL_TRT" ];then
ARCH=${TRT_TARGETARCH} ARCH=${TRT_TARGETARCH}

View File

@ -0,0 +1,51 @@
from argparse import ArgumentParser
AWS_S3_CACHING_VARIABLES = {
"AWS_ACCESS_KEY_ID": "aws_access_key_id",
"AWS_SECRET_ACCESS_KEY": "aws_secret_access_key",
"AWS_SESSION_TOKEN": "aws_session_token",
"SCCACHE_REGION": "s3_region",
"SCCACHE_BUCKET": "s3_bucket_name",
}
ALL_CACHING_STORAGE_VARIABLES = {"AWS_S3_CACHING_VARIABLES"}
def setup_sccache_locally():
from os import environ
print("Setting up Local Caching Layer")
for target in ALL_CACHING_STORAGE_VARIABLES:
for envvar in globals()[target].keys():
if envvar in environ:
print(f"Deleted {envvar} from environment variables")
del environ[envvar]
def setup_sccache_for_s3():
from os import environ
print("Setting up AWS S3 Caching Layer")
for envvar in AWS_S3_CACHING_VARIABLES.keys():
if envvar not in environ or not environ[envvar] or len(environ[envvar]) == 0:
print(f"Missing definition for environment variable {envvar}")
if __name__ == "__main__":
parser = ArgumentParser("TensorRT-LLM Build Caching Setup")
parser.add_argument(
"--is-gha-build",
type=str,
default="FALSE",
help="Indicate if the build is from Github Actions",
)
# Parse args
args = parser.parse_args()
args.is_gha_build = args.is_gha_build.lower() in {"on", "true", "1"}
if args.is_gha_build:
setup_sccache_for_s3()
else:
setup_sccache_locally()

View File

@ -6,6 +6,26 @@ mod utils;
#[cxx::bridge(namespace = "huggingface::tgi::backends::trtllm")] #[cxx::bridge(namespace = "huggingface::tgi::backends::trtllm")]
mod ffi { mod ffi {
#[cxx_name = "finish_reason_t"]
#[derive(Debug, Clone, Copy)]
pub enum FinishReason {
/// The request is not finished.
#[cxx_name = "kNOT_FINISHED"]
NotFinished = 0u8,
/// The request finished because the end id was generated.
#[cxx_name = "kEND_ID"]
EndTokenId = 1u8,
/// The request finished because a stop word was generated.
#[cxx_name = "kSTOP_WORDS"]
StopWords = 2u8,
/// The request finished because the maximum number of tokens was reached.
#[cxx_name = "kLENGTH"]
MaxLength = 3u8,
}
/// Struct used as shared type between rust and C++ to represent the result /// Struct used as shared type between rust and C++ to represent the result
/// of a single decoding iteration /// of a single decoding iteration
#[cxx_name = "generation_step_t"] #[cxx_name = "generation_step_t"]
@ -15,6 +35,7 @@ mod ffi {
token_id: u32, token_id: u32,
log_prob: f32, log_prob: f32,
is_final: bool, is_final: bool,
finish_reason: FinishReason,
has_error: bool, has_error: bool,
error_msg: String, error_msg: String,
} }
@ -66,3 +87,17 @@ mod ffi {
fn cancel(self: Pin<&mut TensorRtLlmBackendImpl>, request_id: u64); fn cancel(self: Pin<&mut TensorRtLlmBackendImpl>, request_id: u64);
} }
} }
use ffi::FinishReason;
use text_generation_router::FinishReason as InferFinishReason;
impl From<FinishReason> for InferFinishReason {
fn from(reason: FinishReason) -> Self {
match reason {
FinishReason::StopWords => InferFinishReason::StopSequence,
FinishReason::MaxLength => InferFinishReason::Length,
FinishReason::EndTokenId => InferFinishReason::EndOfSequenceToken,
_ => panic!("Cannot convert {reason:?} to text_generation_router::FinishReason"),
}
}
}

View File

@ -18,10 +18,12 @@ use text_generation_router::validation::ValidationError::{
EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality, EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality,
}; };
use text_generation_router::validation::{Chunk, ValidGenerateRequest}; use text_generation_router::validation::{Chunk, ValidGenerateRequest};
use text_generation_router::{FinishReason, Token}; use text_generation_router::Token;
use crate::errors::TensorRtLlmBackendError; use crate::errors::TensorRtLlmBackendError;
use crate::ffi::{create_backend_from_engine_folder, GenerationStep, TensorRtLlmBackendImpl}; use crate::ffi::{
create_backend_from_engine_folder, FinishReason, GenerationStep, TensorRtLlmBackendImpl,
};
use crate::utils::first_line; use crate::utils::first_line;
type InferResult<T> = Result<T, InferError>; type InferResult<T> = Result<T, InferError>;
@ -40,6 +42,7 @@ struct DecodedToken {
id: u32, id: u32,
log_prob: f32, log_prob: f32,
is_final: bool, is_final: bool,
finish_reason: FinishReason,
} }
impl<'step> TryFrom<&'step GenerationStep> for DecodedToken { impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
@ -51,6 +54,7 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
id: step.token_id, id: step.token_id,
log_prob: step.log_prob, log_prob: step.log_prob,
is_final: step.is_final, is_final: step.is_final,
finish_reason: step.finish_reason,
}) })
} else { } else {
Err(GenerationError(step.error_msg.clone())) Err(GenerationError(step.error_msg.clone()))
@ -192,7 +196,7 @@ fn post_process_decoded_token(
let generated_text = GeneratedText { let generated_text = GeneratedText {
text: text.unwrap(), text: text.unwrap(),
generated_tokens: ctx.tokens.len() as u32, generated_tokens: ctx.tokens.len() as u32,
finish_reason: FinishReason::EndOfSequenceToken, // TODO : Map FinishReason finish_reason: decoded_token.finish_reason.into(),
seed: None, seed: None,
}; };
@ -336,4 +340,8 @@ impl Backend for TensorRtLlmBackendV2 {
async fn health(&self, _: bool) -> bool { async fn health(&self, _: bool) -> bool {
true true
} }
fn name(&self) -> &'static str {
"TensorRT-LLM"
}
} }

View File

@ -11,7 +11,7 @@ use text_generation_router::server::{
get_hub_model_info, legacy_tokenizer_handle, py_resolve_tokenizer, get_hub_model_info, legacy_tokenizer_handle, py_resolve_tokenizer,
}; };
use text_generation_router::usage_stats::UsageStatsLevel; use text_generation_router::usage_stats::UsageStatsLevel;
use text_generation_router::{server, HubTokenizerConfig, Tokenizer}; use text_generation_router::{server, Tokenizer};
/// App Configuration /// App Configuration
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
@ -67,11 +67,7 @@ struct Args {
payload_limit: usize, payload_limit: usize,
} }
async fn get_tokenizer( async fn get_tokenizer(tokenizer_name: &str, revision: Option<&str>) -> Option<Tokenizer> {
tokenizer_name: &str,
tokenizer_config_path: Option<&str>,
revision: Option<&str>,
) -> Option<Tokenizer> {
// Parse Huggingface hub token // Parse Huggingface hub token
let authorization_token = std::env::var("HF_TOKEN") let authorization_token = std::env::var("HF_TOKEN")
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN")) .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
@ -182,19 +178,6 @@ async fn get_tokenizer(
} }
}; };
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
// let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
// {
// HubTokenizerConfig::from_file(filename)
// } else {
// tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
// };
// let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
// tracing::warn!("Could not find tokenizer config locally and no API specified");
// HubTokenizerConfig::default()
// });
let tokenizer: Tokenizer = { let tokenizer: Tokenizer = {
use pyo3::prelude::*; use pyo3::prelude::*;
pyo3::Python::with_gil(|py| -> PyResult<()> { pyo3::Python::with_gil(|py| -> PyResult<()> {
@ -292,11 +275,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
} }
// Create the backend // Create the backend
match get_tokenizer( match get_tokenizer(&tokenizer_name, revision.as_deref())
&tokenizer_name,
tokenizer_config_path.as_deref(),
revision.as_deref(),
)
.await .await
.expect("Failed to retrieve tokenizer implementation") .expect("Failed to retrieve tokenizer implementation")
{ {

View File

@ -8,13 +8,13 @@
#include "backend.hpp" #include "backend.hpp"
using namespace huggingface::tgi::backends::trtllm; using namespace huggingface::tgi::backends::trtllm;
TEST_CASE("parse generation_config.json all set", "[generation_config_t]") TEST_CASE("parse generation_config.json all set", "[generation_config_t]")
{ {
const json config_j = {{"temperature", 0.6}, {"top_p", 0.95}, {"eos_token_id", {1,2,3}}}; const json config_j = {{"temperature", 0.6},
{"top_p", 0.95},
{"eos_token_id", {1, 2, 3}}};
const auto generation_config = generation_config_t(config_j); const auto generation_config = generation_config_t(config_j);
REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(0.6, 1e-6)); REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(0.6, 1e-6));
@ -24,8 +24,9 @@ TEST_CASE("parse generation_config.json all set", "[generation_config_t]")
REQUIRE_FALSE(generation_config.stop_words.empty()); REQUIRE_FALSE(generation_config.stop_words.empty());
REQUIRE(generation_config.stop_words.size() == config_j["/eos_token_id"_json_pointer].size()); REQUIRE(generation_config.stop_words.size() == config_j["/eos_token_id"_json_pointer].size());
for (auto [lhs, rhs] : std::views::zip(generation_config.stop_words, std::list<std::vector<int32_t>>{{1}, {2}, {3}})) for (auto [lhs, rhs]: std::views::zip(generation_config.stop_words, std::list<std::vector<int32_t>>{{1},
{ {2},
{3}})) {
// Currently we do not support multi-tokens stop words // Currently we do not support multi-tokens stop words
REQUIRE(lhs.size() == 1); REQUIRE(lhs.size() == 1);
REQUIRE(rhs.size() == 1); REQUIRE(rhs.size() == 1);
@ -35,7 +36,7 @@ TEST_CASE("parse generation_config.json all set", "[generation_config_t]")
TEST_CASE("parse generation_config.json default", "[generation_config_t]") TEST_CASE("parse generation_config.json default", "[generation_config_t]")
{ {
const json config_j = {{"eos_token_id", {1,2,3}}}; const json config_j = {{"eos_token_id", {1, 2, 3}}};
const auto generation_config = generation_config_t(config_j); const auto generation_config = generation_config_t(config_j);
REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(1.0, 1e-6)); REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(1.0, 1e-6));
@ -44,8 +45,9 @@ TEST_CASE("parse generation_config.json default", "[generation_config_t]")
REQUIRE_FALSE(generation_config.stop_words.empty()); REQUIRE_FALSE(generation_config.stop_words.empty());
REQUIRE(generation_config.stop_words.size() == config_j["/eos_token_id"_json_pointer].size()); REQUIRE(generation_config.stop_words.size() == config_j["/eos_token_id"_json_pointer].size());
for (auto [lhs, rhs] : std::views::zip(generation_config.stop_words, std::list<std::vector<int32_t>>{{1}, {2}, {3}})) for (auto [lhs, rhs]: std::views::zip(generation_config.stop_words, std::list<std::vector<int32_t>>{{1},
{ {2},
{3}})) {
// Currently we do not support multi-tokens stop words // Currently we do not support multi-tokens stop words
REQUIRE(lhs.size() == 1); REQUIRE(lhs.size() == 1);
REQUIRE(rhs.size() == 1); REQUIRE(rhs.size() == 1);

View File

@ -108,6 +108,10 @@ impl Backend for BackendV2 {
fn start_health(&self) -> bool { fn start_health(&self) -> bool {
true true
} }
fn name(&self) -> &'static str {
"tgi-v2"
}
} }
/// Batching logic /// Batching logic

View File

@ -213,8 +213,7 @@ impl State {
} }
// Pad prefill_token_budget to be a multiple of block size // Pad prefill_token_budget to be a multiple of block size
let prefill_token_budget = let prefill_token_budget = prefill_token_budget.div_ceil(self.block_size) * self.block_size;
((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
// Create span for this batch to add context to inference calls // Create span for this batch to add context to inference calls
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty); let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
@ -245,9 +244,8 @@ impl State {
prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
} else { } else {
// pad to block size // pad to block size
prefill_tokens += ((entry.request.input_length + self.block_size - 1) prefill_tokens +=
/ self.block_size) entry.request.input_length.div_ceil(self.block_size) * self.block_size;
* self.block_size;
} }
if self.requires_padding { if self.requires_padding {
@ -262,8 +260,7 @@ impl State {
}; };
// pad to block size // pad to block size
decode_tokens += decode_tokens += max_new_tokens.div_ceil(self.block_size) * self.block_size;
((max_new_tokens + self.block_size - 1) / self.block_size) * self.block_size;
} }
if prefill_tokens > prefill_token_budget if prefill_tokens > prefill_token_budget

View File

@ -115,6 +115,10 @@ impl Backend for BackendV3 {
fn start_health(&self) -> bool { fn start_health(&self) -> bool {
true true
} }
fn name(&self) -> &'static str {
"tgi-v3"
}
} }
/// Batching logic /// Batching logic

View File

@ -165,13 +165,13 @@ impl Allocator for SimpleAllocator {
let (tokens, repeats) = match self.window_size { let (tokens, repeats) = match self.window_size {
None => (tokens, 1), None => (tokens, 1),
Some(window_size) => { Some(window_size) => {
let repeats = (tokens + window_size - 1) / window_size; let repeats = tokens.div_ceil(window_size);
let tokens = core::cmp::min(tokens, window_size); let tokens = core::cmp::min(tokens, window_size);
(tokens, repeats as usize) (tokens, repeats as usize)
} }
}; };
// Pad to a multiple of block size // Pad to a multiple of block size
let required_blocks = (tokens + self.block_size - 1) / self.block_size; let required_blocks = tokens.div_ceil(self.block_size);
(required_blocks, repeats) (required_blocks, repeats)
}; };

View File

@ -257,8 +257,7 @@ impl State {
} }
// Pad prefill_token_budget to be a multiple of block size // Pad prefill_token_budget to be a multiple of block size
let prefill_token_budget = let prefill_token_budget = prefill_token_budget.div_ceil(self.block_size) * self.block_size;
((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
// Create span for this batch to add context to inference calls // Create span for this batch to add context to inference calls
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty); let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);

View File

@ -103,7 +103,7 @@ impl Allocator for RadixAllocator {
let prefix_len = blocks.len() * self.block_size as usize; let prefix_len = blocks.len() * self.block_size as usize;
let suffix_len = tokens - prefix_len as u32; let suffix_len = tokens - prefix_len as u32;
let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size; let suffix_blocks = suffix_len.div_ceil(self.block_size);
tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}"); tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}");

View File

@ -10,7 +10,7 @@
"name": "Apache 2.0", "name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0" "url": "https://www.apache.org/licenses/LICENSE-2.0"
}, },
"version": "3.0.2-dev0" "version": "3.1.1-dev0"
}, },
"paths": { "paths": {
"/": { "/": {

View File

@ -13,6 +13,8 @@
title: Using TGI with Intel Gaudi title: Using TGI with Intel Gaudi
- local: installation_inferentia - local: installation_inferentia
title: Using TGI with AWS Inferentia title: Using TGI with AWS Inferentia
- local: installation_tpu
title: Using TGI with Google TPUs
- local: installation_intel - local: installation_intel
title: Using TGI with Intel GPUs title: Using TGI with Intel GPUs
- local: installation - local: installation

View File

@ -4,8 +4,13 @@ The NVIDIA TensorRT-LLM (TRTLLM) backend is a high-performance backend for LLMs
that uses NVIDIA's TensorRT library for inference acceleration. that uses NVIDIA's TensorRT library for inference acceleration.
It makes use of specific optimizations for NVIDIA GPUs, such as custom kernels. It makes use of specific optimizations for NVIDIA GPUs, such as custom kernels.
To use the TRTLLM backend you need to compile `engines` for the models you want to use. To use the TRTLLM backend **you need to compile** `engines` for the models you want to use.
Each `engine` must be compiled on the same GPU architecture that you will use for inference. Each `engine` must be compiled for a given set of:
- GPU architecture that you will use for inference (e.g. A100, L40, etc.)
- Maximum batch size
- Maximum input length
- Maximum output length
- Maximum beams width
## Supported models ## Supported models
@ -19,63 +24,159 @@ want to use.
```bash ```bash
MODEL_NAME="meta-llama/Llama-3.1-8B-Instruct" MODEL_NAME="meta-llama/Llama-3.1-8B-Instruct"
DESTINATION="/tmp/engines/$MODEL_NAME"
# Install huggingface_cli HF_TOKEN="hf_xxx"
python -m pip install huggingface-cli[hf_transfer]
# Login to the Hugging Face Hub
huggingface-cli login
# Create a directory to store the model
mkdir -p /tmp/models/$MODEL_NAME
# Create a directory to store the compiled engine
mkdir -p /tmp/engines/$MODEL_NAME
# Download the model
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download --local-dir /tmp/models/$MODEL_NAME $MODEL_NAME
# Compile the engine using Optimum-NVIDIA # Compile the engine using Optimum-NVIDIA
# This will create a compiled engine in the /tmp/engines/meta-llama/Llama-3.1-8B-Instruct
# directory for 1 GPU
docker run \ docker run \
--rm \ --rm \
-it \ -it \
--gpus=1 \ --gpus=1 \
-v /tmp/models/$MODEL_NAME:/model \ --shm-size=1g \
-v /tmp/engines/$MODEL_NAME:/engine \ -v "$DESTINATION":/engine \
huggingface/optimum-nvidia \ -e HF_TOKEN=$HF_TOKEN \
optimum-cli export trtllm \ -e HF_HUB_ENABLE_HF_TRANSFER=1 \
huggingface/optimum-nvidia:v0.1.0b9-py310 \
bash -c "optimum-cli export trtllm \
--tp=1 \ --tp=1 \
--pp=1 \ --pp=1 \
--max-batch-size=128 \ --max-batch-size=64 \
--max-input-length 4096 \ --max-input-length 4096 \
--max-output-length 8192 \ --max-output-length 8192 \
--max-beams-width=1 \ --max-beams-width=1 \
--destination /engine \ --destination /tmp/engine \
$MODEL_NAME $MODEL_NAME && cp -rL /tmp/engine/* /engine/"
``` ```
Your compiled engine will be saved in the `/tmp/engines/$MODEL_NAME` directory. Your compiled engine will be saved in the `/tmp/engines/$MODEL_NAME` directory, in a subfolder named after the GPU used to compile the model.
## Using the TRTLLM backend ## Using the TRTLLM backend
Run TGI-TRTLLM Docker image with the compiled engine: Run TGI-TRTLLM Docker image with the compiled engine:
```bash ```bash
MODEL_NAME="meta-llama/Llama-3.1-8B-Instruct"
DESTINATION="/tmp/engines/$MODEL_NAME"
HF_TOKEN="hf_xxx"
docker run \ docker run \
--gpus 1 \ --gpus 1 \
--shm-size=1g \
-it \ -it \
--rm \ --rm \
-p 3000:3000 \ -p 3000:3000 \
-e MODEL=$MODEL_NAME \ -e MODEL=$MODEL_NAME \
-e PORT=3000 \ -e PORT=3000 \
-e HF_TOKEN='hf_XXX' \ -e HF_TOKEN=$HF_TOKEN \
-v /tmp/engines/$MODEL_NAME:/data \ -v "$DESTINATION"/<YOUR_GPU_ARCHITECTURE>/engines:/data \
ghcr.io/huggingface/text-generation-inference:latest-trtllm \ ghcr.io/huggingface/text-generation-inference:latest-trtllm \
--executor-worker executorWorker \ --model-id /data/ \
--model-id /data/$MODEL_NAME --tokenizer-name $MODEL_NAME
``` ```
## Development ## Development
To develop TRTLLM backend, you can use [dev containers](https://containers.dev/) located in To develop TRTLLM backend, you can use [dev containers](https://containers.dev/) with the following `.devcontainer.json` file:
`.devcontainer` directory. ```json
{
"name": "CUDA",
"build": {
"dockerfile": "Dockerfile_trtllm",
"context": ".."
},
"remoteEnv": {
"PATH": "${containerEnv:PATH}:/usr/local/cuda/bin",
"LD_LIBRARY_PATH": "$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64",
"XLA_FLAGS": "--xla_gpu_cuda_data_dir=/usr/local/cuda"
},
"customizations" : {
"jetbrains" : {
"backend" : "CLion"
}
}
}
```
and `Dockerfile_trtllm`:
```Dockerfile
ARG cuda_arch_list="75-real;80-real;86-real;89-real;90-real"
ARG build_type=release
ARG ompi_version=4.1.7
# CUDA dependent dependencies resolver stage
FROM nvidia/cuda:12.6.3-cudnn-devel-ubuntu24.04 AS cuda-builder
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
build-essential \
cmake \
curl \
gcc-14 \
g++-14 \
git \
git-lfs \
lld \
libssl-dev \
libucx-dev \
libasan8 \
libubsan1 \
ninja-build \
pkg-config \
pipx \
python3 \
python3-dev \
python3-setuptools \
tar \
wget --no-install-recommends && \
pipx ensurepath
ENV TGI_INSTALL_PREFIX=/usr/local/tgi
ENV TENSORRT_INSTALL_PREFIX=/usr/local/tensorrt
# Install OpenMPI
FROM cuda-builder AS mpi-builder
WORKDIR /opt/src/mpi
ARG ompi_version
ENV OMPI_VERSION=${ompi_version}
ENV OMPI_TARBALL_FILENAME=openmpi-${OMPI_VERSION}.tar.bz2
ADD --checksum=sha256:54a33cb7ad81ff0976f15a6cc8003c3922f0f3d8ceed14e1813ef3603f22cd34 \
https://download.open-mpi.org/release/open-mpi/v4.1/${OMPI_TARBALL_FILENAME} .
RUN tar --strip-components=1 -xf ${OMPI_TARBALL_FILENAME} &&\
./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --with-slurm && \
make -j all && \
make install && \
rm -rf ${OMPI_TARBALL_FILENAME}/..
# Install TensorRT
FROM cuda-builder AS trt-builder
COPY backends/trtllm/scripts/install_tensorrt.sh /opt/install_tensorrt.sh
RUN chmod +x /opt/install_tensorrt.sh && \
/opt/install_tensorrt.sh
# Build Backend
FROM cuda-builder AS tgi-builder
WORKDIR /usr/src/text-generation-inference
# Scoped global args reuse
ARG cuda_arch_list
ARG build_type
ARG sccache_gha_enabled
ARG actions_cache_url
ARG actions_runtime_token
# Install Rust
ENV PATH="/root/.cargo/bin:$PATH"
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y && \
chmod -R a+w /root/.rustup && \
chmod -R a+w /root/.cargo && \
cargo install sccache --locked
ENV LD_LIBRARY_PATH="/usr/local/mpi/lib:$LD_LIBRARY_PATH"
ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig"
ENV CMAKE_PREFIX_PATH="/usr/local/mpi:/usr/local/tensorrt"
ENV USE_LLD_LINKER=ON
ENV CUDA_ARCH_LIST=${cuda_arch_list}
```

View File

@ -19,6 +19,6 @@ docker run --gpus all \
--shm-size 1g \ --shm-size 1g \
-e HF_TOKEN=$token \ -e HF_TOKEN=$token \
-p 8080:80 \ -p 8080:80 \
-v $volume:/data ghcr.io/huggingface/text-generation-inference:3.0.1 \ -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.1.0 \
--model-id $model --model-id $model
``` ```

View File

@ -19,7 +19,7 @@ bitsandbytes is a library used to apply 8-bit and 4-bit quantization to models.
In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below 👇 In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below 👇
```bash ```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.0.1 --model-id $model --quantize bitsandbytes docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.1.0 --model-id $model --quantize bitsandbytes
``` ```
4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load. 4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load.
@ -27,7 +27,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingf
In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below 👇 In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below 👇
```bash ```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.0.1 --model-id $model --quantize bitsandbytes-nf4 docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.1.0 --model-id $model --quantize bitsandbytes-nf4
``` ```
You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes). You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes).
@ -48,7 +48,7 @@ $$({\hat{W}_{l}}^{*} = argmin_{\hat{W_{l}}} ||W_{l}X-\hat{W}_{l}X||^{2}_{2})$$
TGI allows you to both run an already GPTQ quantized model (see available models [here](https://huggingface.co/models?search=gptq)) or quantize a model of your choice using quantization script. You can run a quantized model by simply passing --quantize like below 👇 TGI allows you to both run an already GPTQ quantized model (see available models [here](https://huggingface.co/models?search=gptq)) or quantize a model of your choice using quantization script. You can run a quantized model by simply passing --quantize like below 👇
```bash ```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.0.1 --model-id $model --quantize gptq docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.1.0 --model-id $model --quantize gptq
``` ```
Note that TGI's GPTQ implementation doesn't use [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) under the hood. However, models quantized using AutoGPTQ or Optimum can still be served by TGI. Note that TGI's GPTQ implementation doesn't use [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) under the hood. However, models quantized using AutoGPTQ or Optimum can still be served by TGI.

View File

@ -27,7 +27,7 @@ You can check a few existing fine-tunes for popular models:
- [text-generation-inference/Mistral-7B-Instruct-v0.2-medusa](https://huggingface.co/text-generation-inference/Mistral-7B-Instruct-v0.2-medusa) - [text-generation-inference/Mistral-7B-Instruct-v0.2-medusa](https://huggingface.co/text-generation-inference/Mistral-7B-Instruct-v0.2-medusa)
In order to create your own medusa heads for your own finetune, you should check own the original medusa repo. [../basic_tutorials/train_medusa.md](../basic_tutorials/train_medusa.md) In order to create your own medusa heads for your own finetune, you should check own the original medusa repo. Read for more in [Train Medusa](../basic_tutorials/train_medusa#training).
In order to use medusa models in TGI, simply point to a medusa enabled model, and everything will load automatically. In order to use medusa models in TGI, simply point to a medusa enabled model, and everything will load automatically.

View File

@ -11,7 +11,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
--device=/dev/kfd --device=/dev/dri --group-add video \ --device=/dev/kfd --device=/dev/dri --group-add video \
--ipc=host --shm-size 256g --net host -v $volume:/data \ --ipc=host --shm-size 256g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.0.1-rocm \ ghcr.io/huggingface/text-generation-inference:3.1.0-rocm \
--model-id $model --model-id $model
``` ```

View File

@ -12,7 +12,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm --privileged --cap-add=sys_nice \ docker run --rm --privileged --cap-add=sys_nice \
--device=/dev/dri \ --device=/dev/dri \
--ipc=host --shm-size 1g --net host -v $volume:/data \ --ipc=host --shm-size 1g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.0.1-intel-xpu \ ghcr.io/huggingface/text-generation-inference:3.1.0-intel-xpu \
--model-id $model --cuda-graphs 0 --model-id $model --cuda-graphs 0
``` ```
@ -29,7 +29,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm --privileged --cap-add=sys_nice \ docker run --rm --privileged --cap-add=sys_nice \
--device=/dev/dri \ --device=/dev/dri \
--ipc=host --shm-size 1g --net host -v $volume:/data \ --ipc=host --shm-size 1g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.0.1-intel-cpu \ ghcr.io/huggingface/text-generation-inference:3.1.0-intel-cpu \
--model-id $model --cuda-graphs 0 --model-id $model --cuda-graphs 0
``` ```

View File

@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \ docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.0.1 \ ghcr.io/huggingface/text-generation-inference:3.1.0 \
--model-id $model --model-id $model
``` ```

View File

@ -0,0 +1,3 @@
# Using TGI with Google TPUs
Check out this [guide](https://huggingface.co/docs/optimum-tpu) on how to serve models with TGI on TPUs.

View File

@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.0.1 \ ghcr.io/huggingface/text-generation-inference:3.1.0 \
--model-id $model --model-id $model
``` ```
@ -96,7 +96,7 @@ curl 127.0.0.1:8080/generate \
To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more. To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.
```bash ```bash
docker run ghcr.io/huggingface/text-generation-inference:3.0.1 --help docker run ghcr.io/huggingface/text-generation-inference:3.1.0 --help
``` ```
</Tip> </Tip>

View File

@ -163,7 +163,7 @@ hub = {
# create Hugging Face Model Class # create Hugging Face Model Class
huggingface_model = HuggingFaceModel( huggingface_model = HuggingFaceModel(
image_uri=get_huggingface_llm_image_uri("huggingface",version="3.0.1"), image_uri=get_huggingface_llm_image_uri("huggingface",version="3.1.0"),
env=hub, env=hub,
role=role, role=role,
) )

View File

@ -4,6 +4,7 @@
Text Generation Inference enables serving optimized models. The following sections list which models (VLMs & LLMs) are supported. Text Generation Inference enables serving optimized models. The following sections list which models (VLMs & LLMs) are supported.
- [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2) - [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2)
- [Deepseek V3](https://huggingface.co/deepseek-ai/DeepSeek-V3)
- [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal) - [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal)
- [Idefics 3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3) (Multimodal) - [Idefics 3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3) (Multimodal)
- [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal) - [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)

View File

@ -3,7 +3,7 @@
Text Generation Inference collects anonymous usage statistics to help us improve the service. The collected data is used to improve TGI and to understand what causes failures. The data is collected transparently and any sensitive information is omitted. Text Generation Inference collects anonymous usage statistics to help us improve the service. The collected data is used to improve TGI and to understand what causes failures. The data is collected transparently and any sensitive information is omitted.
Data is sent twice, once on server startup and once when server stops. Also, usage statistics are only enabled when TGI is running in docker to avoid collecting data then TGI runs directly on the host machine. Usage statistics are collected only when TGI is running in a Docker container. This prevents data collection when TGI is run directly on the host machine. The collected data includes startup and shutdown events, as well as a heartbeat signal sent every 15 minutes.
## What data is collected ## What data is collected

View File

@ -108,11 +108,11 @@
"pre-commit-hooks": "pre-commit-hooks_3" "pre-commit-hooks": "pre-commit-hooks_3"
}, },
"locked": { "locked": {
"lastModified": 1732039290, "lastModified": 1734429562,
"narHash": "sha256-LQKY7bShf2H9kJouxa9ZspfdrulnZF9o4kLTqGqCDYM=", "narHash": "sha256-V2XNs3Ir8WXNHdocfzkR/fu0FzkZ9uTDJkVecxJrGmQ=",
"owner": "nix-community", "owner": "nix-community",
"repo": "crate2nix", "repo": "crate2nix",
"rev": "9ff208ce7f5a482272b1bcefbe363c772d7ff914", "rev": "8537c2d7cb623679aaeff62c4c4c43a91566ab09",
"type": "github" "type": "github"
}, },
"original": { "original": {
@ -305,11 +305,11 @@
}, },
"flake-compat_4": { "flake-compat_4": {
"locked": { "locked": {
"lastModified": 1696426674, "lastModified": 1733328505,
"narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=", "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
"owner": "edolstra", "owner": "edolstra",
"repo": "flake-compat", "repo": "flake-compat",
"rev": "0f9255e01c2351cc7d116c072cb317785dd33b33", "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
"type": "github" "type": "github"
}, },
"original": { "original": {
@ -718,11 +718,11 @@
}, },
"nixpkgs_6": { "nixpkgs_6": {
"locked": { "locked": {
"lastModified": 1732034459, "lastModified": 1737453259,
"narHash": "sha256-Zais/zMRuJdlALidkUgEuasXOd37ZZLqkPkF9bIYSrY=", "narHash": "sha256-5LaFI9SQwCZmJDasMoYMdzNouWXNk3BvjKcO19tq1Rs=",
"owner": "danieldk", "owner": "danieldk",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "40280e7bf9743cdf563494db4ece2a43aa674fa8", "rev": "e0372dbcfd19ddd783b7c3b3868f19322f83318e",
"type": "github" "type": "github"
}, },
"original": { "original": {
@ -853,11 +853,11 @@
] ]
}, },
"locked": { "locked": {
"lastModified": 1732242723, "lastModified": 1737685583,
"narHash": "sha256-NWI8csIK0ujFlFuEXKnoc+7hWoCiEtINK9r48LUUMeU=", "narHash": "sha256-p+NVABRpGi+pT+xxf9HcLcFVxG6L+vEEy+NwzB9T0f8=",
"owner": "oxalica", "owner": "oxalica",
"repo": "rust-overlay", "repo": "rust-overlay",
"rev": "a229311fcb45b88a95fdfa5cecd8349c809a272a", "rev": "eb64cbcc8eee0fa87ebded92805280d2ec97415a",
"type": "github" "type": "github"
}, },
"original": { "original": {
@ -978,11 +978,11 @@
"nixpkgs": "nixpkgs_6" "nixpkgs": "nixpkgs_6"
}, },
"locked": { "locked": {
"lastModified": 1736436388, "lastModified": 1738323634,
"narHash": "sha256-CIyxVPpM9RrSwthNT/4DQ10YPk/uwzP7AeE83kBNsrE=", "narHash": "sha256-lKPzgEm7pEuQJVhacsxFHqg1MOtrUMZvr+9IuJzC5J4=",
"owner": "huggingface", "owner": "huggingface",
"repo": "text-generation-inference-nix", "repo": "text-generation-inference-nix",
"rev": "5103c3fb1f9ad1fd33b6e09ff05e957884b112d5", "rev": "eb5fede2756f544f75e01f55a4097f9c9a8c5005",
"type": "github" "type": "github"
}, },
"original": { "original": {

View File

@ -562,6 +562,7 @@ def launcher(event_loop):
docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]]) docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]])
] ]
client.api.timeout = 1000
container = client.containers.run( container = client.containers.run(
DOCKER_IMAGE, DOCKER_IMAGE,
command=args, command=args,
@ -573,7 +574,7 @@ def launcher(event_loop):
devices=devices, devices=devices,
volumes=volumes, volumes=volumes,
ports={"80/tcp": port}, ports={"80/tcp": port},
healthcheck={"timeout": int(60 * 1e9), "retries": 2}, # 60s healthcheck={"timeout": int(180 * 1e9), "retries": 2}, # 60s
shm_size="1G", shm_size="1G",
) )

View File

@ -1,73 +1,469 @@
{ {
"details": { "details": {
"best_of_sequences": null, "best_of_sequences": null,
"finish_reason": "length", "finish_reason": "eos_token",
"generated_tokens": 10, "generated_tokens": 76,
"prefill": [], "prefill": [],
"seed": null, "seed": null,
"tokens": [ "tokens": [
{ {
"id": 18183, "id": 18183,
"logprob": -1.6669922, "logprob": -1.5195312,
"special": false, "special": false,
"text": " Deep" "text": " Deep"
}, },
{ {
"id": 6832, "id": 6832,
"logprob": -0.08959961, "logprob": -0.06817627,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 374, "id": 374,
"logprob": -0.14685059, "logprob": -0.13122559,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 264, "id": 264,
"logprob": -0.125, "logprob": -0.13415527,
"special": false, "special": false,
"text": " a" "text": " a"
}, },
{ {
"id": 25993, "id": 25993,
"logprob": -0.81640625, "logprob": -0.8769531,
"special": false, "special": false,
"text": " subset" "text": " subset"
}, },
{ {
"id": 315, "id": 315,
"logprob": -0.0013418198, "logprob": -0.0011396408,
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{ {
"id": 5662, "id": 5662,
"logprob": -0.16027832, "logprob": -0.16442871,
"special": false, "special": false,
"text": " machine" "text": " machine"
}, },
{ {
"id": 6832, "id": 6832,
"logprob": -0.0016393661, "logprob": -0.0026416779,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 429, "id": 429,
"logprob": -0.4477539, "logprob": -0.48754883,
"special": false, "special": false,
"text": " that" "text": " that"
}, },
{ {
"id": 5711, "id": 5711,
"logprob": -1.2802734, "logprob": -1.2294922,
"special": false, "special": false,
"text": " uses" "text": " uses"
},
{
"id": 29728,
"logprob": -0.66503906,
"special": false,
"text": " neural"
},
{
"id": 14155,
"logprob": -0.02960205,
"special": false,
"text": " networks"
},
{
"id": 311,
"logprob": -0.7236328,
"special": false,
"text": " to"
},
{
"id": 3960,
"logprob": -1.1914062,
"special": false,
"text": " learn"
},
{
"id": 504,
"logprob": -0.7089844,
"special": false,
"text": " from"
},
{
"id": 821,
"logprob": -0.7729492,
"special": false,
"text": " data"
},
{
"id": 13,
"logprob": -0.7836914,
"special": false,
"text": "."
},
{
"id": 1084,
"logprob": -0.9941406,
"special": false,
"text": " It"
},
{
"id": 374,
"logprob": -0.52441406,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.9511719,
"special": false,
"text": " a"
},
{
"id": 943,
"logprob": -0.8642578,
"special": false,
"text": " type"
},
{
"id": 315,
"logprob": -0.00030231476,
"special": false,
"text": " of"
},
{
"id": 20443,
"logprob": -0.14416504,
"special": false,
"text": " artificial"
},
{
"id": 11229,
"logprob": -0.013824463,
"special": false,
"text": " intelligence"
},
{
"id": 429,
"logprob": -0.18762207,
"special": false,
"text": " that"
},
{
"id": 646,
"logprob": -1.0087891,
"special": false,
"text": " can"
},
{
"id": 3960,
"logprob": -0.90234375,
"special": false,
"text": " learn"
},
{
"id": 504,
"logprob": -0.54345703,
"special": false,
"text": " from"
},
{
"id": 323,
"logprob": -1.0400391,
"special": false,
"text": " and"
},
{
"id": 1281,
"logprob": -0.072509766,
"special": false,
"text": " make"
},
{
"id": 19898,
"logprob": -0.16516113,
"special": false,
"text": " predictions"
},
{
"id": 389,
"logprob": -0.4416504,
"special": false,
"text": " on"
},
{
"id": 3460,
"logprob": -0.5385742,
"special": false,
"text": " large"
},
{
"id": 14713,
"logprob": -0.4387207,
"special": false,
"text": " amounts"
},
{
"id": 315,
"logprob": -0.00015091896,
"special": false,
"text": " of"
},
{
"id": 821,
"logprob": -0.061431885,
"special": false,
"text": " data"
},
{
"id": 13,
"logprob": -0.71875,
"special": false,
"text": "."
},
{
"id": 18183,
"logprob": -0.23632812,
"special": false,
"text": " Deep"
},
{
"id": 6832,
"logprob": -0.0017204285,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -1.1738281,
"special": false,
"text": " is"
},
{
"id": 1483,
"logprob": -0.61083984,
"special": false,
"text": " used"
},
{
"id": 304,
"logprob": -0.035003662,
"special": false,
"text": " in"
},
{
"id": 264,
"logprob": -0.118652344,
"special": false,
"text": " a"
},
{
"id": 8045,
"logprob": -0.42016602,
"special": false,
"text": " variety"
},
{
"id": 315,
"logprob": -1.6212463e-05,
"special": false,
"text": " of"
},
{
"id": 8357,
"logprob": -0.1315918,
"special": false,
"text": " applications"
},
{
"id": 11,
"logprob": -0.12915039,
"special": false,
"text": ","
},
{
"id": 2670,
"logprob": -0.12463379,
"special": false,
"text": " including"
},
{
"id": 2168,
"logprob": -0.37402344,
"special": false,
"text": " image"
},
{
"id": 323,
"logprob": -0.1451416,
"special": false,
"text": " and"
},
{
"id": 8806,
"logprob": -0.028869629,
"special": false,
"text": " speech"
},
{
"id": 17843,
"logprob": -0.00024068356,
"special": false,
"text": " recognition"
},
{
"id": 11,
"logprob": -0.00031018257,
"special": false,
"text": ","
},
{
"id": 5810,
"logprob": -0.019821167,
"special": false,
"text": " natural"
},
{
"id": 4128,
"logprob": -0.00012528896,
"special": false,
"text": " language"
},
{
"id": 8692,
"logprob": -0.00089263916,
"special": false,
"text": " processing"
},
{
"id": 11,
"logprob": -0.00073862076,
"special": false,
"text": ","
},
{
"id": 323,
"logprob": -0.040161133,
"special": false,
"text": " and"
},
{
"id": 38193,
"logprob": -0.4519043,
"special": false,
"text": " autonomous"
},
{
"id": 11474,
"logprob": -0.39941406,
"special": false,
"text": " vehicles"
},
{
"id": 13,
"logprob": -0.21166992,
"special": false,
"text": "."
},
{
"id": 1084,
"logprob": -0.9082031,
"special": false,
"text": " It"
},
{
"id": 374,
"logprob": -0.44213867,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -1.2177734,
"special": false,
"text": " a"
},
{
"id": 18512,
"logprob": -0.5205078,
"special": false,
"text": " rapidly"
},
{
"id": 7826,
"logprob": -0.15332031,
"special": false,
"text": " growing"
},
{
"id": 2070,
"logprob": -0.0039978027,
"special": false,
"text": " field"
},
{
"id": 448,
"logprob": -0.9091797,
"special": false,
"text": " with"
},
{
"id": 1657,
"logprob": -0.17114258,
"special": false,
"text": " many"
},
{
"id": 4650,
"logprob": -0.70703125,
"special": false,
"text": " potential"
},
{
"id": 8357,
"logprob": -0.025131226,
"special": false,
"text": " applications"
},
{
"id": 304,
"logprob": -0.6699219,
"special": false,
"text": " in"
},
{
"id": 279,
"logprob": -0.35205078,
"special": false,
"text": " the"
},
{
"id": 3853,
"logprob": -0.049194336,
"special": false,
"text": " future"
},
{
"id": 13,
"logprob": -0.21972656,
"special": false,
"text": "."
},
{
"id": 151643,
"logprob": -2.0019531,
"special": true,
"text": "<|endoftext|>"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": " Deep learning is a subset of machine learning that uses" "generated_text": " Deep learning is a subset of machine learning that uses neural networks to learn from data. It is a type of artificial intelligence that can learn from and make predictions on large amounts of data. Deep learning is used in a variety of applications, including image and speech recognition, natural language processing, and autonomous vehicles. It is a rapidly growing field with many potential applications in the future."
} }

View File

@ -7,67 +7,67 @@
"seed": 0, "seed": 0,
"tokens": [ "tokens": [
{ {
"id": 1939, "id": 5267,
"logprob": -2.2460938, "logprob": -1.1464844,
"special": false, "special": false,
"text": "?\n\n" "text": "?\n"
}, },
{ {
"id": 33464, "id": 33464,
"logprob": 0.0, "logprob": -0.83203125,
"special": false, "special": false,
"text": "Deep" "text": "Deep"
}, },
{ {
"id": 20909, "id": 20909,
"logprob": -0.48608398, "logprob": -0.5625,
"special": false, "special": false,
"text": " Learning" "text": " Learning"
}, },
{
"id": 4102,
"logprob": -2.265625,
"special": false,
"text": " "
},
{
"id": 285,
"logprob": 0.0,
"special": false,
"text": "is"
},
{
"id": 458,
"logprob": -0.6328125,
"special": false,
"text": " an"
},
{
"id": 20443,
"logprob": -0.1796875,
"special": false,
"text": " artificial"
},
{
"id": 11229,
"logprob": 0.0,
"special": false,
"text": " intelligence"
},
{ {
"id": 320, "id": 320,
"logprob": -0.37695312, "logprob": -2.1464844,
"special": false, "special": false,
"text": " (" "text": " ("
}, },
{ {
"id": 15469, "id": 16524,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": "AI" "text": "DL"
},
{
"id": 701,
"logprob": -2.2089844,
"special": false,
"text": "),"
},
{
"id": 476,
"logprob": -0.27368164,
"special": false,
"text": " or"
},
{
"id": 20443,
"logprob": -0.09442139,
"special": false,
"text": " artificial"
},
{
"id": 29728,
"logprob": 0.0,
"special": false,
"text": " neural"
},
{
"id": 14155,
"logprob": 0.0,
"special": false,
"text": " networks"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "What is deep learning?\n\nDeep Learning is an artificial intelligence (AI" "generated_text": "What is deep learning?\nDeep Learning (DL), or artificial neural networks"
} }

View File

@ -9,61 +9,61 @@
"tokens": [ "tokens": [
{ {
"id": 18183, "id": 18183,
"logprob": -1.4912109, "logprob": -1.5195312,
"special": false, "special": false,
"text": " Deep" "text": " Deep"
}, },
{ {
"id": 6832, "id": 6832,
"logprob": -0.075683594, "logprob": -0.06817627,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 374, "id": 374,
"logprob": -0.12408447, "logprob": -0.13122559,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 264, "id": 264,
"logprob": -0.12768555, "logprob": -0.13415527,
"special": false, "special": false,
"text": " a" "text": " a"
}, },
{ {
"id": 25993, "id": 25993,
"logprob": -0.82128906, "logprob": -0.87353516,
"special": false, "special": false,
"text": " subset" "text": " subset"
}, },
{ {
"id": 315, "id": 315,
"logprob": -0.0012636185, "logprob": -0.0011396408,
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{ {
"id": 5662, "id": 5662,
"logprob": -0.12878418, "logprob": -0.16442871,
"special": false, "special": false,
"text": " machine" "text": " machine"
}, },
{ {
"id": 6832, "id": 6832,
"logprob": -0.0015888214, "logprob": -0.0026416779,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 429, "id": 429,
"logprob": -0.49194336, "logprob": -0.48754883,
"special": false, "special": false,
"text": " that" "text": " that"
}, },
{ {
"id": 5711, "id": 5711,
"logprob": -1.2626953, "logprob": -1.2294922,
"special": false, "special": false,
"text": " uses" "text": " uses"
} }
@ -82,61 +82,61 @@
"tokens": [ "tokens": [
{ {
"id": 18183, "id": 18183,
"logprob": -1.4912109, "logprob": -1.5195312,
"special": false, "special": false,
"text": " Deep" "text": " Deep"
}, },
{ {
"id": 6832, "id": 6832,
"logprob": -0.075683594, "logprob": -0.06817627,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 374, "id": 374,
"logprob": -0.12408447, "logprob": -0.13122559,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 264, "id": 264,
"logprob": -0.12768555, "logprob": -0.13415527,
"special": false, "special": false,
"text": " a" "text": " a"
}, },
{ {
"id": 25993, "id": 25993,
"logprob": -0.82128906, "logprob": -0.87353516,
"special": false, "special": false,
"text": " subset" "text": " subset"
}, },
{ {
"id": 315, "id": 315,
"logprob": -0.0012636185, "logprob": -0.0011396408,
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{ {
"id": 5662, "id": 5662,
"logprob": -0.12878418, "logprob": -0.16442871,
"special": false, "special": false,
"text": " machine" "text": " machine"
}, },
{ {
"id": 6832, "id": 6832,
"logprob": -0.0015888214, "logprob": -0.0026416779,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 429, "id": 429,
"logprob": -0.49194336, "logprob": -0.48754883,
"special": false, "special": false,
"text": " that" "text": " that"
}, },
{ {
"id": 5711, "id": 5711,
"logprob": -1.2626953, "logprob": -1.2294922,
"special": false, "special": false,
"text": " uses" "text": " uses"
} }
@ -155,61 +155,61 @@
"tokens": [ "tokens": [
{ {
"id": 18183, "id": 18183,
"logprob": -1.4912109, "logprob": -1.5195312,
"special": false, "special": false,
"text": " Deep" "text": " Deep"
}, },
{ {
"id": 6832, "id": 6832,
"logprob": -0.075683594, "logprob": -0.06817627,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 374, "id": 374,
"logprob": -0.12408447, "logprob": -0.13122559,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 264, "id": 264,
"logprob": -0.12768555, "logprob": -0.13415527,
"special": false, "special": false,
"text": " a" "text": " a"
}, },
{ {
"id": 25993, "id": 25993,
"logprob": -0.82128906, "logprob": -0.87353516,
"special": false, "special": false,
"text": " subset" "text": " subset"
}, },
{ {
"id": 315, "id": 315,
"logprob": -0.0012636185, "logprob": -0.0011396408,
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{ {
"id": 5662, "id": 5662,
"logprob": -0.12878418, "logprob": -0.16442871,
"special": false, "special": false,
"text": " machine" "text": " machine"
}, },
{ {
"id": 6832, "id": 6832,
"logprob": -0.0015888214, "logprob": -0.0026416779,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 429, "id": 429,
"logprob": -0.49194336, "logprob": -0.48754883,
"special": false, "special": false,
"text": " that" "text": " that"
}, },
{ {
"id": 5711, "id": 5711,
"logprob": -1.2626953, "logprob": -1.2294922,
"special": false, "special": false,
"text": " uses" "text": " uses"
} }
@ -228,61 +228,61 @@
"tokens": [ "tokens": [
{ {
"id": 18183, "id": 18183,
"logprob": -1.4912109, "logprob": -1.5195312,
"special": false, "special": false,
"text": " Deep" "text": " Deep"
}, },
{ {
"id": 6832, "id": 6832,
"logprob": -0.075683594, "logprob": -0.06817627,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 374, "id": 374,
"logprob": -0.12408447, "logprob": -0.13122559,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 264, "id": 264,
"logprob": -0.12768555, "logprob": -0.13415527,
"special": false, "special": false,
"text": " a" "text": " a"
}, },
{ {
"id": 25993, "id": 25993,
"logprob": -0.82128906, "logprob": -0.87353516,
"special": false, "special": false,
"text": " subset" "text": " subset"
}, },
{ {
"id": 315, "id": 315,
"logprob": -0.0012636185, "logprob": -0.0011396408,
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{ {
"id": 5662, "id": 5662,
"logprob": -0.12878418, "logprob": -0.16442871,
"special": false, "special": false,
"text": " machine" "text": " machine"
}, },
{ {
"id": 6832, "id": 6832,
"logprob": -0.0015888214, "logprob": -0.0026416779,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 429, "id": 429,
"logprob": -0.49194336, "logprob": -0.48754883,
"special": false, "special": false,
"text": " that" "text": " that"
}, },
{ {
"id": 5711, "id": 5711,
"logprob": -1.2626953, "logprob": -1.2294922,
"special": false, "special": false,
"text": " uses" "text": " uses"
} }

View File

@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "The image showcases a stunning cityscape, featuring the iconic Statue of Liberty in the foreground. The image displays Lady Liberty's imposing presence, with her towering base standing beside her. Behind the statue, the city's skyline extends across the horizon, adorned with numerous tall buildings, including the Empire State Building and other notable skyscrapers. The water reflecting the sun's rays creates a serene and picturesque scene, emphasizing the beauty and resilience of this global landmark. The sky is a clear, pale blue, adding to the overall tranquility of the scene.",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1738348090,
"id": "",
"model": "Qwen/Qwen2-VL-7B-Instruct",
"object": "chat.completion",
"system_fingerprint": "3.1.1-dev0-native",
"usage": {
"completion_tokens": 110,
"prompt_tokens": 8736,
"total_tokens": 8846
}
}

View File

@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "The image shows a stylized scene set in what appears to be a diner or restaurant. In the foreground, there is a table with various food items, including a burger with lettuce and tomato, a bowl of fries, and a drink in a cup with a straw. On the right side of the table, there is an owl sitting alertly, looking directly at the camera. Behind the owl and the table, there is a large, green, dinosaur-like creature resembling Godzilla, with its mouth open and tongue visible. In the background, the diner's decor includes various signs and posters, with a green sign reading \"Basta\" and another sign that says \"Tabasco.\" The setting has a retro or vintage feel, with fluorescent lighting overhead and clean, polished surfaces.",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1738348100,
"id": "",
"model": "Qwen/Qwen2-VL-7B-Instruct",
"object": "chat.completion",
"system_fingerprint": "3.1.1-dev0-native",
"usage": {
"completion_tokens": 156,
"prompt_tokens": 5375,
"total_tokens": 5531
}
}

View File

@ -5,7 +5,7 @@
"index": 0, "index": 0,
"logprobs": null, "logprobs": null,
"message": { "message": {
"content": "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape.", "content": "The image depicts an anthropomorphic rabbit, wearing a spacesuit, standing in a barren, rocky landscape that resembles the surface of another planet, possibly Mars. The rabbit has a red digestive system label on its chest, and the surrounding environment features red sandy terrain and a hazy, floating planet or moon in the background. The scene has a surreal, fantastical quality, blending elements of science fiction and space exploration with a whimsical character.",
"name": null, "name": null,
"role": "assistant", "role": "assistant",
"tool_calls": null "tool_calls": null
@ -13,14 +13,14 @@
"usage": null "usage": null
} }
], ],
"created": 1730164250, "created": 1738347908,
"id": "", "id": "",
"model": "Qwen/Qwen2-VL-7B-Instruct", "model": "Qwen/Qwen2-VL-7B-Instruct",
"object": "chat.completion", "object": "chat.completion",
"system_fingerprint": "2.4.2-dev0-native", "system_fingerprint": "3.1.1-dev0-native",
"usage": { "usage": {
"completion_tokens": 58, "completion_tokens": 89,
"prompt_tokens": 349, "prompt_tokens": 1364,
"total_tokens": 407 "total_tokens": 1453
} }
} }

View File

@ -11,10 +11,10 @@
"logprobs": null "logprobs": null
} }
], ],
"created": 1730416361, "created": 1737646031,
"id": "", "id": "",
"model": "Qwen/Qwen2-VL-7B-Instruct", "model": "Qwen/Qwen2-VL-7B-Instruct",
"object": "chat.completion.chunk", "object": "chat.completion.chunk",
"system_fingerprint": "2.4.2-dev0-native", "system_fingerprint": "3.0.2-dev0-native",
"usage": null "usage": null
} }

View File

@ -0,0 +1,73 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 2284,
"logprob": -0.9355469,
"special": false,
"text": "():"
},
{
"id": 303,
"logprob": -0.40795898,
"special": false,
"text": "\n "
},
{
"id": 1489,
"logprob": -0.27954102,
"special": false,
"text": " print"
},
{
"id": 459,
"logprob": -0.6142578,
"special": false,
"text": "(\""
},
{
"id": 8302,
"logprob": -0.68310547,
"special": false,
"text": "Hello"
},
{
"id": 10914,
"logprob": -1.4599609,
"special": false,
"text": " World"
},
{
"id": 16013,
"logprob": -0.80126953,
"special": false,
"text": "!\")"
},
{
"id": 222,
"logprob": -0.625,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -0.23242188,
"special": false,
"text": "\n"
},
{
"id": 610,
"logprob": -1.2294922,
"special": false,
"text": "def"
}
],
"top_tokens": null
},
"generated_text": "():\n print(\"Hello World!\")\n\ndef"
}

View File

@ -0,0 +1,373 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 60,
"prefill": [],
"seed": 0,
"tokens": [
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": -0.7944336,
"special": false,
"text": "#"
},
{
"id": 494,
"logprob": 0.0,
"special": false,
"text": " +"
},
{
"id": 447,
"logprob": -0.1796875,
"special": false,
"text": " ["
},
{
"id": 9009,
"logprob": 0.0,
"special": false,
"text": "markdown"
},
{
"id": 98,
"logprob": 0.0,
"special": false,
"text": "]"
},
{
"id": 37402,
"logprob": 0.0,
"special": false,
"text": " slideshow"
},
{
"id": 8492,
"logprob": 0.0,
"special": false,
"text": "={\""
},
{
"id": 7277,
"logprob": 0.0,
"special": false,
"text": "slide"
},
{
"id": 100,
"logprob": 0.0,
"special": false,
"text": "_"
},
{
"id": 700,
"logprob": 0.0,
"special": false,
"text": "type"
},
{
"id": 582,
"logprob": 0.0,
"special": false,
"text": "\":"
},
{
"id": 332,
"logprob": 0.0,
"special": false,
"text": " \""
},
{
"id": 7277,
"logprob": -0.06994629,
"special": false,
"text": "slide"
},
{
"id": 3667,
"logprob": 0.0,
"special": false,
"text": "\"}"
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": 0.0,
"special": false,
"text": "#"
},
{
"id": 607,
"logprob": -0.8261719,
"special": false,
"text": " #"
},
{
"id": 244,
"logprob": -1.8574219,
"special": false,
"text": " "
},
{
"id": 55,
"logprob": -1.4541016,
"special": false,
"text": "2"
},
{
"id": 51,
"logprob": 0.0,
"special": false,
"text": "."
},
{
"id": 6208,
"logprob": -0.9794922,
"special": false,
"text": " What"
},
{
"id": 458,
"logprob": 0.0,
"special": false,
"text": " is"
},
{
"id": 341,
"logprob": 0.0,
"special": false,
"text": " the"
},
{
"id": 10609,
"logprob": -0.69189453,
"special": false,
"text": " difference"
},
{
"id": 3761,
"logprob": 0.0,
"special": false,
"text": " between"
},
{
"id": 331,
"logprob": 0.0,
"special": false,
"text": " a"
},
{
"id": 1168,
"logprob": -0.27172852,
"special": false,
"text": " list"
},
{
"id": 480,
"logprob": 0.0,
"special": false,
"text": " and"
},
{
"id": 331,
"logprob": 0.0,
"special": false,
"text": " a"
},
{
"id": 8871,
"logprob": 0.0,
"special": false,
"text": " tuple"
},
{
"id": 68,
"logprob": 0.0,
"special": false,
"text": "?"
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": -1.3359375,
"special": false,
"text": "#"
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": 0.0,
"special": false,
"text": "#"
},
{
"id": 449,
"logprob": -0.03164673,
"special": false,
"text": " -"
},
{
"id": 418,
"logprob": -1.0947266,
"special": false,
"text": " A"
},
{
"id": 1168,
"logprob": 0.0,
"special": false,
"text": " list"
},
{
"id": 458,
"logprob": 0.0,
"special": false,
"text": " is"
},
{
"id": 331,
"logprob": -0.3305664,
"special": false,
"text": " a"
},
{
"id": 14792,
"logprob": 0.0,
"special": false,
"text": " mutable"
},
{
"id": 6645,
"logprob": -0.40478516,
"special": false,
"text": " sequence"
},
{
"id": 451,
"logprob": 0.0,
"special": false,
"text": " of"
},
{
"id": 4725,
"logprob": -0.50390625,
"special": false,
"text": " elements"
},
{
"id": 49,
"logprob": -2.1269531,
"special": false,
"text": ","
},
{
"id": 2236,
"logprob": -0.1427002,
"special": false,
"text": " while"
},
{
"id": 331,
"logprob": 0.0,
"special": false,
"text": " a"
},
{
"id": 8871,
"logprob": 0.0,
"special": false,
"text": " tuple"
},
{
"id": 458,
"logprob": 0.0,
"special": false,
"text": " is"
},
{
"id": 619,
"logprob": 0.0,
"special": false,
"text": " an"
},
{
"id": 26079,
"logprob": 0.0,
"special": false,
"text": " immutable"
},
{
"id": 6645,
"logprob": 0.0,
"special": false,
"text": " sequence"
},
{
"id": 451,
"logprob": 0.0,
"special": false,
"text": " of"
},
{
"id": 4725,
"logprob": 0.0,
"special": false,
"text": " elements"
},
{
"id": 51,
"logprob": 0.0,
"special": false,
"text": "."
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": 0.0,
"special": false,
"text": "#"
},
{
"id": 449,
"logprob": 0.0,
"special": false,
"text": " -"
}
],
"top_tokens": null
},
"generated_text": "\n\n# + [markdown] slideshow={\"slide_type\": \"slide\"}\n# # 2. What is the difference between a list and a tuple?\n#\n# - A list is a mutable sequence of elements, while a tuple is an immutable sequence of elements.\n# -"
}

View File

@ -0,0 +1,294 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 222,
"logprob": -1.9091797,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -1.0478516,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": -3.015625,
"special": false,
"text": "#"
},
{
"id": 494,
"logprob": -1.4228516,
"special": false,
"text": " +"
},
{
"id": 447,
"logprob": -1.1025391,
"special": false,
"text": " ["
},
{
"id": 9009,
"logprob": -0.0008444786,
"special": false,
"text": "markdown"
},
{
"id": 98,
"logprob": -8.8095665e-05,
"special": false,
"text": "]"
},
{
"id": 37402,
"logprob": -0.5810547,
"special": false,
"text": " slideshow"
},
{
"id": 8492,
"logprob": -0.00022864342,
"special": false,
"text": "={\""
},
{
"id": 7277,
"logprob": -0.00030994415,
"special": false,
"text": "slide"
}
],
"top_tokens": null
},
"generated_text": "\n\n# + [markdown] slideshow={\"slide"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 222,
"logprob": -1.9091797,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -1.0478516,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": -3.015625,
"special": false,
"text": "#"
},
{
"id": 494,
"logprob": -1.4228516,
"special": false,
"text": " +"
},
{
"id": 447,
"logprob": -1.1025391,
"special": false,
"text": " ["
},
{
"id": 9009,
"logprob": -0.0008444786,
"special": false,
"text": "markdown"
},
{
"id": 98,
"logprob": -8.8095665e-05,
"special": false,
"text": "]"
},
{
"id": 37402,
"logprob": -0.5810547,
"special": false,
"text": " slideshow"
},
{
"id": 8492,
"logprob": -0.00022864342,
"special": false,
"text": "={\""
},
{
"id": 7277,
"logprob": -0.00030994415,
"special": false,
"text": "slide"
}
],
"top_tokens": null
},
"generated_text": "\n\n# + [markdown] slideshow={\"slide"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 222,
"logprob": -1.9091797,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -1.0478516,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": -3.015625,
"special": false,
"text": "#"
},
{
"id": 494,
"logprob": -1.4228516,
"special": false,
"text": " +"
},
{
"id": 447,
"logprob": -1.1025391,
"special": false,
"text": " ["
},
{
"id": 9009,
"logprob": -0.0008444786,
"special": false,
"text": "markdown"
},
{
"id": 98,
"logprob": -8.8095665e-05,
"special": false,
"text": "]"
},
{
"id": 37402,
"logprob": -0.5810547,
"special": false,
"text": " slideshow"
},
{
"id": 8492,
"logprob": -0.00022864342,
"special": false,
"text": "={\""
},
{
"id": 7277,
"logprob": -0.00030994415,
"special": false,
"text": "slide"
}
],
"top_tokens": null
},
"generated_text": "\n\n# + [markdown] slideshow={\"slide"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 222,
"logprob": -1.9091797,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -1.0478516,
"special": false,
"text": "\n"
},
{
"id": 40,
"logprob": -3.015625,
"special": false,
"text": "#"
},
{
"id": 494,
"logprob": -1.4228516,
"special": false,
"text": " +"
},
{
"id": 447,
"logprob": -1.1025391,
"special": false,
"text": " ["
},
{
"id": 9009,
"logprob": -0.0008444786,
"special": false,
"text": "markdown"
},
{
"id": 98,
"logprob": -8.8095665e-05,
"special": false,
"text": "]"
},
{
"id": 37402,
"logprob": -0.5810547,
"special": false,
"text": " slideshow"
},
{
"id": 8492,
"logprob": -0.00022864342,
"special": false,
"text": "={\""
},
{
"id": 7277,
"logprob": -0.00030994415,
"special": false,
"text": "slide"
}
],
"top_tokens": null
},
"generated_text": "\n\n# + [markdown] slideshow={\"slide"
}
]

View File

@ -0,0 +1,71 @@
{
"details": {
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 100,
"logprob": -0.9824219,
"special": false,
"text": "_"
},
{
"id": 5879,
"logprob": -0.3017578,
"special": false,
"text": "world"
},
{
"id": 2284,
"logprob": -0.68652344,
"special": false,
"text": "():"
},
{
"id": 303,
"logprob": -0.27734375,
"special": false,
"text": "\n "
},
{
"id": 1489,
"logprob": -0.4482422,
"special": false,
"text": " print"
},
{
"id": 459,
"logprob": -0.54248047,
"special": false,
"text": "(\""
},
{
"id": 8302,
"logprob": -0.4296875,
"special": false,
"text": "Hello"
},
{
"id": 10914,
"logprob": -0.8544922,
"special": false,
"text": " World"
},
{
"id": 16013,
"logprob": -0.7573242,
"special": false,
"text": "!\")"
},
{
"id": 222,
"logprob": -0.81347656,
"special": false,
"text": "\n"
}
]
},
"generated_text": "_world():\n print(\"Hello World!\")\n"
}

View File

@ -6,7 +6,7 @@
"index": 0, "index": 0,
"logprobs": null, "logprobs": null,
"message": { "message": {
"content": "In a bustling city, a chicken named Cluck", "content": "In a small town, a chicken named Cluck",
"name": null, "name": null,
"role": "assistant", "role": "assistant",
"tool_calls": null "tool_calls": null
@ -14,11 +14,11 @@
"usage": null "usage": null
} }
], ],
"created": 1727773835, "created": 1738753835,
"id": "", "id": "",
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct", "model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion", "object": "chat.completion",
"system_fingerprint": "2.4.2-dev0-native", "system_fingerprint": "3.1.1-dev0-native",
"usage": { "usage": {
"completion_tokens": 10, "completion_tokens": 10,
"prompt_tokens": 50, "prompt_tokens": 50,
@ -32,7 +32,7 @@
"index": 0, "index": 0,
"logprobs": null, "logprobs": null,
"message": { "message": {
"content": "In a world where even chickens could dream big,", "content": "In a small town, a chicken named Cluck",
"name": null, "name": null,
"role": "assistant", "role": "assistant",
"tool_calls": null "tool_calls": null
@ -40,63 +40,11 @@
"usage": null "usage": null
} }
], ],
"created": 1727773835, "created": 1738753835,
"id": "", "id": "",
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct", "model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion", "object": "chat.completion",
"system_fingerprint": "2.4.2-dev0-native", "system_fingerprint": "3.1.1-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 50,
"total_tokens": 60
}
},
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "In a world where even chickens could dream big,",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1727773835,
"id": "",
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion",
"system_fingerprint": "2.4.2-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 50,
"total_tokens": 60
}
},
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "In a world where even chickens could dream big,",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1727773835,
"id": "",
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion",
"system_fingerprint": "2.4.2-dev0-native",
"usage": { "usage": {
"completion_tokens": 10, "completion_tokens": 10,
"prompt_tokens": 50, "prompt_tokens": 50,

View File

@ -5,7 +5,7 @@
"index": 0, "index": 0,
"logprobs": null, "logprobs": null,
"message": { "message": {
"content": "In a bustling city, a chicken named Cluck", "content": "In a small town, a chicken named Cluck",
"name": null, "name": null,
"role": "assistant", "role": "assistant",
"tool_calls": null "tool_calls": null
@ -13,11 +13,11 @@
"usage": null "usage": null
} }
], ],
"created": 1727556016, "created": 1738753833,
"id": "", "id": "",
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct", "model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion", "object": "chat.completion",
"system_fingerprint": "2.4.2-dev0-native", "system_fingerprint": "3.1.1-dev0-native",
"usage": { "usage": {
"completion_tokens": 10, "completion_tokens": 10,
"prompt_tokens": 50, "prompt_tokens": 50,

View File

@ -27,15 +27,16 @@ async def test_compressed_tensors_w8a8_int_dynamic_weight(
): ):
response = await compressed_tensors_w8a8_int_dynamic_weight.generate( response = await compressed_tensors_w8a8_int_dynamic_weight.generate(
"What is deep learning?", "What is deep learning?",
max_new_tokens=10, # prefer a longer response than the default, allow the llm to end generation
max_new_tokens=1000,
decoder_input_details=True, decoder_input_details=True,
) )
assert ( assert (
response.generated_text response.generated_text
== " Deep learning is a subset of machine learning that uses" == " Deep learning is a subset of machine learning that uses neural networks to learn from data. It is a type of artificial intelligence that can learn from and make predictions on large amounts of data. Deep learning is used in a variety of applications, including image and speech recognition, natural language processing, and autonomous vehicles. It is a rapidly growing field with many potential applications in the future."
) )
assert response.details.generated_tokens == 10 assert response.details.generated_tokens == 76
assert response == response_snapshot assert response == response_snapshot
@ -64,7 +65,7 @@ async def test_compressed_tensors_w8a8_int_dynamic_weight_all_params(
assert response.details.generated_tokens == 10 assert response.details.generated_tokens == 10
assert ( assert (
response.generated_text response.generated_text
== "What is deep learning?\n\nDeep Learning is an artificial intelligence (AI" == "What is deep learning?\nDeep Learning (DL), or artificial neural networks"
) )
assert response == response_snapshot assert response == response_snapshot

View File

@ -1,81 +1,122 @@
# Disabled because it's broken. import pytest
# import pytest
#
# @pytest.fixture(scope="module")
# @pytest.fixture(scope="module") def flash_qwen2_vl_handle(launcher):
# def flash_qwen2_vl_handle(launcher): with launcher("Qwen/Qwen2-VL-7B-Instruct") as handle:
# with launcher("Qwen/Qwen2-VL-7B-Instruct") as handle: yield handle
# yield handle
#
# @pytest.fixture(scope="module")
# @pytest.fixture(scope="module") async def flash_qwen2(flash_qwen2_vl_handle):
# async def flash_qwen2(flash_qwen2_vl_handle): await flash_qwen2_vl_handle.health(300)
# await flash_qwen2_vl_handle.health(300) return flash_qwen2_vl_handle.client
# return flash_qwen2_vl_handle.client
#
# @pytest.mark.private
# @pytest.mark.private async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot):
# async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot): response = await flash_qwen2.chat(
# response = await flash_qwen2.chat( seed=42,
# max_tokens=100, messages=[
# seed=42, {
# messages=[ "role": "user",
# { "content": [
# "role": "user", {
# "content": [ "type": "image_url",
# { "image_url": {
# "type": "image_url", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
# "image_url": { },
# "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png" },
# }, {"type": "text", "text": "Describe this image."},
# }, ],
# {"type": "text", "text": "Describe this image."}, },
# ], ],
# }, )
# ],
# ) assert (
# response.choices[0].message.content
# assert ( == "The image depicts an anthropomorphic rabbit, wearing a spacesuit, standing in a barren, rocky landscape that resembles the surface of another planet, possibly Mars. The rabbit has a red digestive system label on its chest, and the surrounding environment features red sandy terrain and a hazy, floating planet or moon in the background. The scene has a surreal, fantastical quality, blending elements of science fiction and space exploration with a whimsical character."
# response.choices[0].message.content )
# == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape."
# ) assert response == response_snapshot
#
# assert response == response_snapshot
# @pytest.mark.private
# async def test_flash_qwen2_vl_simple_streaming(flash_qwen2, response_snapshot):
# @pytest.mark.private responses = await flash_qwen2.chat(
# async def test_flash_qwen2_vl_simple_streaming(flash_qwen2, response_snapshot): seed=42,
# responses = await flash_qwen2.chat( messages=[
# max_tokens=100, {
# seed=42, "role": "user",
# messages=[ "content": [
# { {
# "role": "user", "type": "image_url",
# "content": [ "image_url": {
# { "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
# "type": "image_url", },
# "image_url": { },
# "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png" {"type": "text", "text": "Describe this image."},
# }, ],
# }, },
# {"type": "text", "text": "Describe this image."}, ],
# ], stream=True,
# }, )
# ],
# stream=True, count = 0
# ) generated = ""
# last_response = None
# count = 0 async for response in responses:
# generated = "" count += 1
# last_response = None generated += response.choices[0].delta.content
# async for response in responses: last_response = response
# count += 1
# generated += response.choices[0].delta.content assert (
# last_response = response generated
# == "The image depicts an anthropomorphic rabbit, wearing a spacesuit, standing in a barren, rocky landscape that resembles the surface of another planet, possibly Mars. The rabbit has a red digestive system label on its chest, and the surrounding environment features red sandy terrain and a hazy, floating planet or moon in the background. The scene has a surreal, fantastical quality, blending elements of science fiction and space exploration with a whimsical character."
# assert ( )
# generated assert count == 89
# == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape." assert last_response == response_snapshot
# )
# assert count == 58
# assert last_response == response_snapshot @pytest.mark.private
async def test_flash_qwen2_vl_bay(flash_qwen2, response_snapshot):
response = await flash_qwen2.chat(
seed=42,
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
},
},
{"type": "text", "text": "Describe the image"},
],
},
],
)
assert response == response_snapshot
@pytest.mark.private
async def test_flash_qwen2_vl_inpaint(flash_qwen2, response_snapshot):
response = await flash_qwen2.chat(
seed=42,
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-inpaint.png"
},
},
{"type": "text", "text": "Describe the image"},
],
},
],
)
assert response == response_snapshot

View File

@ -0,0 +1,79 @@
import pytest
import requests
@pytest.fixture(scope="module")
def flash_starcoder2_handle(launcher):
with launcher(
"bigcode/starcoder2-3b", lora_adapters=["smangrul/starcoder-3b-hugcoder"]
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_starcoder2(flash_starcoder2_handle):
await flash_starcoder2_handle.health(300)
return flash_starcoder2_handle.client
@pytest.mark.asyncio
async def test_flash_starcoder2(flash_starcoder2, response_snapshot):
response = await flash_starcoder2.generate(
"def print_hello", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapshot):
response = await flash_starcoder2.generate(
"who are you?",
max_new_tokens=60,
temperature=0.2,
top_p=0.95,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 60
assert response == response_snapshot
@pytest.mark.asyncio
async def test_flash_starcoder2_load(
flash_starcoder2, generate_load, response_snapshot
):
responses = await generate_load(
flash_starcoder2, "who are you?", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == response_snapshot
@pytest.mark.asyncio
async def test_flash_starcoder2_with_hugcode_adapter(
flash_starcoder2, response_snapshot
):
response = requests.post(
f"{flash_starcoder2.base_url}/generate",
headers=flash_starcoder2.headers,
json={
"inputs": "def print_hello",
"parameters": {
"max_new_tokens": 10,
"adapter_id": "smangrul/starcoder-3b-hugcoder",
"details": True,
},
},
)
assert response.status_code == 200
data = response.json()
assert data["generated_text"] == '_world():\n print("Hello World!")\n'
assert data == response_snapshot

View File

@ -25,21 +25,23 @@ async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snap
assert response == generous_response_snapshot assert response == generous_response_snapshot
@pytest.mark.release # Deactivated because it's flaky
@pytest.mark.asyncio # Only this model seems affected and it's only a logprob precision issue.
async def test_flash_starcoder_gptq_default_params( # @pytest.mark.release
flash_starcoder_gptq, generous_response_snapshot # @pytest.mark.asyncio
): # async def test_flash_starcoder_gptq_default_params(
response = await flash_starcoder_gptq.generate( # flash_starcoder_gptq, generous_response_snapshot
"def geometric_mean(L: List[float]):", # ):
max_new_tokens=20, # response = await flash_starcoder_gptq.generate(
temperature=0.2, # "def geometric_mean(L: List[float]):",
top_p=0.95, # max_new_tokens=20,
decoder_input_details=True, # temperature=0.2,
seed=0, # top_p=0.95,
) # decoder_input_details=True,
assert response.details.generated_tokens == 2 # seed=0,
assert response == generous_response_snapshot # )
# assert response.details.generated_tokens == 2
# assert response == generous_response_snapshot
@pytest.mark.release @pytest.mark.release

View File

@ -47,8 +47,7 @@ async def test_mllama_simpl(mllama, response_snapshot):
"total_tokens": 60, "total_tokens": 60,
} }
assert ( assert (
response.choices[0].message.content response.choices[0].message.content == "In a small town, a chicken named Cluck"
== "In a bustling city, a chicken named Cluck"
) )
assert response == response_snapshot assert response == response_snapshot
@ -84,12 +83,12 @@ async def test_mllama_load(mllama, generate_load, response_snapshot):
] ]
responses = await asyncio.gather(*futures) responses = await asyncio.gather(*futures)
_ = [response.choices[0].message.content for response in responses] generated_texts = [response.choices[0].message.content for response in responses]
# XXX: TODO: Fix this test. # XXX: TODO: Fix this test.
# assert generated_texts[0] == "In a bustling city, a chicken named Cluck" assert generated_texts[0] == "In a small town, a chicken named Cluck"
# assert len(generated_texts) == 4 assert len(generated_texts) == 2
# assert generated_texts, all( assert generated_texts, all(
# [text == generated_texts[0] for text in generated_texts] [text == generated_texts[0] for text in generated_texts]
# ) )
# assert responses == response_snapshot assert responses == response_snapshot

View File

@ -5,7 +5,6 @@ use hf_hub::{
}; };
use nix::sys::signal::{self, Signal}; use nix::sys::signal::{self, Signal};
use nix::unistd::Pid; use nix::unistd::Pid;
use regex::Regex;
use serde::Deserialize; use serde::Deserialize;
use std::env; use std::env;
use std::ffi::OsString; use std::ffi::OsString;
@ -144,7 +143,9 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
} }
} }
let fallback_attention = if matches!(compute_capability, Some((major, _)) if major < 8) { let fallback_attention = if compute_capability.is_none()
|| matches!(compute_capability, Some((major, _)) if major < 8)
{
"paged" "paged"
} else { } else {
"flashdecoding" "flashdecoding"
@ -1631,8 +1632,10 @@ enum Gpu {
L40, L40,
L40S, L40S,
A10G, A10G,
A40,
H100, H100,
A100, A100,
H200,
Unknown(String), Unknown(String),
} }
@ -1651,6 +1654,7 @@ impl From<&str> for Gpu {
"nvidia-l40" => Gpu::L40, "nvidia-l40" => Gpu::L40,
"nvidia-l40s" => Gpu::L40S, "nvidia-l40s" => Gpu::L40S,
"nvidia-a10g" => Gpu::A10G, "nvidia-a10g" => Gpu::A10G,
"nvidia-a40" => Gpu::A40,
"nvidia-h100-80gb-hbm3" => Gpu::H100, "nvidia-h100-80gb-hbm3" => Gpu::H100,
"nvidia-h100-nvl" => Gpu::H100, "nvidia-h100-nvl" => Gpu::H100,
"nvidia-h100" => Gpu::H100, "nvidia-h100" => Gpu::H100,
@ -1658,6 +1662,7 @@ impl From<&str> for Gpu {
"nvidia-a100-sxm4-40gb" => Gpu::A100, "nvidia-a100-sxm4-40gb" => Gpu::A100,
"nvidia-a100-80gb-pcie" => Gpu::A100, "nvidia-a100-80gb-pcie" => Gpu::A100,
"nvidia-a100" => Gpu::A100, "nvidia-a100" => Gpu::A100,
"nvidia-h200" => Gpu::H200,
card => Gpu::Unknown(card.to_string()), card => Gpu::Unknown(card.to_string()),
} }
} }
@ -1672,8 +1677,10 @@ impl std::fmt::Display for Gpu {
Gpu::L40 => write!(f, "nvida-l40"), Gpu::L40 => write!(f, "nvida-l40"),
Gpu::L40S => write!(f, "nvida-l40s"), Gpu::L40S => write!(f, "nvida-l40s"),
Gpu::A10G => write!(f, "nvidia-a10g"), Gpu::A10G => write!(f, "nvidia-a10g"),
Gpu::A40 => write!(f, "nvidia-a40"),
Gpu::H100 => write!(f, "nvidia-h100-80fb-hbm3"), Gpu::H100 => write!(f, "nvidia-h100-80fb-hbm3"),
Gpu::A100 => write!(f, "nvida-a100-sxm4-80gb"), Gpu::A100 => write!(f, "nvida-a100-sxm4-80gb"),
Gpu::H200 => write!(f, "nvida-h200"),
Gpu::Unknown(card) => write!(f, "{}", card), Gpu::Unknown(card) => write!(f, "{}", card),
} }
} }
@ -1695,11 +1702,16 @@ impl ComputeType {
Gpu::L40S => Some(363 * 10u64.pow(12)), Gpu::L40S => Some(363 * 10u64.pow(12)),
// https://www.nvidia.com/en-us/data-center/products/a10-gpu/ // https://www.nvidia.com/en-us/data-center/products/a10-gpu/
Gpu::A10G => Some(125 * 10u64.pow(12)), Gpu::A10G => Some(125 * 10u64.pow(12)),
// https://www.nvidia.com/en-us/data-center/a40/
// https://images.nvidia.com/content/Solutions/data-center/a40/nvidia-a40-datasheet.pdf
Gpu::A40 => Some(149 * 10u64.pow(12)),
// https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
Gpu::A100 => Some(312 * 10u64.pow(12)),
// https://www.nvidia.com/en-us/data-center/h100/ // https://www.nvidia.com/en-us/data-center/h100/
// https://www.techpowerup.com/gpu-specs/docs/nvidia-gh100-architecture.pdf // https://www.techpowerup.com/gpu-specs/docs/nvidia-gh100-architecture.pdf
Gpu::H100 => Some(900 * 10u64.pow(12)), Gpu::H100 => Some(900 * 10u64.pow(12)),
// https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf // https://www.nvidia.com/en-us/data-center/h200/
Gpu::A100 => Some(312 * 10u64.pow(12)), Gpu::H200 => Some(989 * 10u64.pow(12)),
Gpu::Unknown(card) => { Gpu::Unknown(card) => {
tracing::warn!("Unkown compute for card {card}"); tracing::warn!("Unkown compute for card {card}");
None None
@ -2037,7 +2049,16 @@ fn main() -> Result<(), LauncherError> {
None => { None => {
let compute_type = compute_type(num_shard); let compute_type = compute_type(num_shard);
let compute_optimal = compute_optimal(config.as_ref(), compute_type.as_ref()); let compute_optimal = compute_optimal(config.as_ref(), compute_type.as_ref());
let default = compute_optimal.unwrap_or(4096); // TODO: remove this when we correctly esimate the flops for VLMs
// this is a short term temporary fix to enable vlms to avoid rejecting images
let default_optimal = match config {
Some(ref config) => match config.model_type.as_deref() {
Some("qwen2_vl") => 10_000,
_ => 4096,
},
None => 4096,
};
let default = compute_optimal.unwrap_or(default_optimal);
let vram_maximum = vram_maximum( let vram_maximum = vram_maximum(
config.as_ref(), config.as_ref(),
compute_type.as_ref(), compute_type.as_ref(),
@ -2079,14 +2100,7 @@ fn main() -> Result<(), LauncherError> {
let cuda_graphs = match (&args.cuda_graphs, &quantize) { let cuda_graphs = match (&args.cuda_graphs, &quantize) {
(Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(), (Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(),
#[allow(deprecated)] #[allow(deprecated)]
( (None, Some(Quantization::Bitsandbytes)) => {
None,
Some(
Quantization::Bitsandbytes
| Quantization::BitsandbytesNf4
| Quantization::BitsandbytesFp4,
),
) => {
tracing::warn!("Bitsandbytes doesn't work with cuda graphs, deactivating them"); tracing::warn!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
vec![] vec![]
} }
@ -2176,11 +2190,12 @@ fn main() -> Result<(), LauncherError> {
} }
// capture adapter_id, path, revision in format of adapter_id=path@revision // capture adapter_id, path, revision in format of adapter_id=path@revision
let re = Regex::new(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$").unwrap(); // path is disabled beforehand.
if let Some(caps) = re.captures(adapter) { let mut splits = adapter.split("@");
let adapter_id = caps.get(1).map_or("", |m| m.as_str()); let adapter_id = splits.next().ok_or_else(|| {
let revision = caps.get(3).map(|m| m.as_str()); LauncherError::ArgumentValidation("Missing adapter id".to_string())
})?;
let revision = splits.next();
download_convert_model( download_convert_model(
adapter_id, adapter_id,
revision, revision,
@ -2190,12 +2205,6 @@ fn main() -> Result<(), LauncherError> {
running.clone(), running.clone(),
false, // avoid merging lora adapters if using multi-lora false, // avoid merging lora adapters if using multi-lora
)?; )?;
} else {
return Err(LauncherError::ArgumentValidation(format!(
"Invalid LoRA adapter format: {}",
adapter
)));
}
} }
} }

View File

@ -94,8 +94,18 @@ mkShell {
( cd clients/python ; python -m pip install --no-dependencies -e . ) ( cd clients/python ; python -m pip install --no-dependencies -e . )
''; '';
postShellHook = '' postShellHook =
''
unset SOURCE_DATE_EPOCH unset SOURCE_DATE_EPOCH
export PATH=$PATH:~/.cargo/bin export PATH=${cudaPackages.backendStdenv.cc}/bin:$PATH:~/.cargo/bin
''
# At various points in time, the latest gcc supported by CUDA differs
# from the default version in nixpkgs. A lot of the dependencies in
# the impure environment pull in the default gcc from nixpkgs, so we
# end up with the CUDA-supported gcc and the nixpkgs default gcc in
# the path. To ensure that we can build CUDA kernels, put the CUDA
# first in the path. It's a hack, but it works.
+ lib.optionalString withCuda ''
export PATH=${cudaPackages.backendStdenv.cc}/bin:$PATH
''; '';
} }

View File

@ -64,6 +64,7 @@ uuid = { version = "1.9.1", default-features = false, features = [
csv = "1.3.0" csv = "1.3.0"
ureq = "=2.9" ureq = "=2.9"
pyo3 = { workspace = true } pyo3 = { workspace = true }
chrono = "0.4.39"
[build-dependencies] [build-dependencies]

View File

@ -224,6 +224,8 @@ pub enum Config {
Qwen2, Qwen2,
Opt, Opt,
T5, T5,
DeepseekV2,
DeepseekV3,
} }
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]

View File

@ -1,5 +1,6 @@
use crate::infer::InferError; use crate::infer::InferError;
use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool}; use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool};
use chrono::Local;
use minijinja::{Environment, ErrorKind, Template}; use minijinja::{Environment, ErrorKind, Template};
use minijinja_contrib::pycompat; use minijinja_contrib::pycompat;
@ -8,6 +9,11 @@ pub(crate) fn raise_exception(err_text: String) -> Result<String, minijinja::Err
Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text)) Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
} }
/// Get the current date in a specific format (custom function), similar to `datetime.now().strftime()` in Python
pub(crate) fn strftime_now(format_str: String) -> Result<String, minijinja::Error> {
Ok(Local::now().format(&format_str).to_string())
}
#[derive(Clone)] #[derive(Clone)]
pub(crate) struct ChatTemplate { pub(crate) struct ChatTemplate {
template: Template<'static, 'static>, template: Template<'static, 'static>,
@ -27,6 +33,7 @@ impl ChatTemplate {
env.set_unknown_method_callback(pycompat::unknown_method_callback); env.set_unknown_method_callback(pycompat::unknown_method_callback);
let template_str = template.into_boxed_str(); let template_str = template.into_boxed_str();
env.add_function("raise_exception", raise_exception); env.add_function("raise_exception", raise_exception);
env.add_function("strftime_now", strftime_now);
tracing::debug!("Loading template: {}", template_str); tracing::debug!("Loading template: {}", template_str);
// leaking env and template_str as read-only, static resources for performance. // leaking env and template_str as read-only, static resources for performance.
@ -109,11 +116,12 @@ impl ChatTemplate {
// tests // tests
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::infer::chat_template::raise_exception; use crate::infer::chat_template::{raise_exception, strftime_now};
use crate::infer::ChatTemplate; use crate::infer::ChatTemplate;
use crate::{ use crate::{
ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken, Tool, ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken, Tool,
}; };
use chrono::Local;
use minijinja::Environment; use minijinja::Environment;
#[test] #[test]
@ -182,6 +190,7 @@ mod tests {
fn test_chat_template_invalid_with_raise() { fn test_chat_template_invalid_with_raise() {
let mut env = Environment::new(); let mut env = Environment::new();
env.add_function("raise_exception", raise_exception); env.add_function("raise_exception", raise_exception);
env.add_function("strftime_now", strftime_now);
let source = r#" let source = r#"
{{ bos_token }} {{ bos_token }}
@ -253,6 +262,7 @@ mod tests {
fn test_chat_template_valid_with_raise() { fn test_chat_template_valid_with_raise() {
let mut env = Environment::new(); let mut env = Environment::new();
env.add_function("raise_exception", raise_exception); env.add_function("raise_exception", raise_exception);
env.add_function("strftime_now", strftime_now);
let source = r#" let source = r#"
{{ bos_token }} {{ bos_token }}
@ -307,10 +317,79 @@ mod tests {
assert_eq!(result, "[BOS][INST] Hi! [/INST]Hello how can I help?[EOS][INST] What is Deep Learning? [/INST]magic![EOS]"); assert_eq!(result, "[BOS][INST] Hi! [/INST]Hello how can I help?[EOS][INST] What is Deep Learning? [/INST]magic![EOS]");
} }
#[test]
fn test_chat_template_valid_with_strftime_now() {
let mut env = Environment::new();
env.add_function("raise_exception", raise_exception);
env.add_function("strftime_now", strftime_now);
let source = r#"
{% set today = strftime_now("%Y-%m-%d") %}
{% set default_system_message = "The current date is " + today + "." %}
{{ bos_token }}
{% if messages[0]['role'] == 'system' %}
{ set system_message = messages[0]['content'] %}
{%- set loop_messages = messages[1:] %}
{% else %}
{%- set system_message = default_system_message %}
{%- set loop_messages = messages %}
{% endif %}
{{ '[SYSTEM_PROMPT]' + system_message + '[/SYSTEM_PROMPT]' }}
{% for message in loop_messages %}
{% if message['role'] == 'user' %}
{{ '[INST]' + message['content'] + '[/INST]' }}
{% elif message['role'] == 'assistant' %}
{{ message['content'] + eos_token }}
{% else %}
{{ raise_exception('Only user and assistant roles are supported!') }}
{% endif %}
{% endfor %}
"#;
// trim all the whitespace
let source = source
.lines()
.map(|line| line.trim())
.collect::<Vec<&str>>()
.join("");
let tmpl = env.template_from_str(&source);
let chat_template_inputs = ChatTemplateInputs {
messages: vec![
TextMessage {
role: "user".to_string(),
content: "Hi!".to_string(),
},
TextMessage {
role: "assistant".to_string(),
content: "Hello how can I help?".to_string(),
},
TextMessage {
role: "user".to_string(),
content: "What is Deep Learning?".to_string(),
},
TextMessage {
role: "assistant".to_string(),
content: "magic!".to_string(),
},
],
bos_token: Some("[BOS]"),
eos_token: Some("[EOS]"),
add_generation_prompt: true,
..Default::default()
};
let current_date = Local::now().format("%Y-%m-%d").to_string();
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
assert_eq!(result, format!("[BOS][SYSTEM_PROMPT]The current date is {}.[/SYSTEM_PROMPT][INST]Hi![/INST]Hello how can I help?[EOS][INST]What is Deep Learning?[/INST]magic![EOS]", current_date));
}
#[test] #[test]
fn test_chat_template_valid_with_add_generation_prompt() { fn test_chat_template_valid_with_add_generation_prompt() {
let mut env = Environment::new(); let mut env = Environment::new();
env.add_function("raise_exception", raise_exception); env.add_function("raise_exception", raise_exception);
env.add_function("strftime_now", strftime_now);
let source = r#" let source = r#"
{% for message in messages %} {% for message in messages %}
@ -502,6 +581,7 @@ mod tests {
{ {
let mut env = Environment::new(); let mut env = Environment::new();
env.add_function("raise_exception", raise_exception); env.add_function("raise_exception", raise_exception);
env.add_function("strftime_now", strftime_now);
let tmpl = env.template_from_str(chat_template); let tmpl = env.template_from_str(chat_template);
let result = tmpl.unwrap().render(input).unwrap(); let result = tmpl.unwrap().render(input).unwrap();
assert_eq!(result, target); assert_eq!(result, target);
@ -776,6 +856,7 @@ mod tests {
{ {
let mut env = Environment::new(); let mut env = Environment::new();
env.add_function("raise_exception", raise_exception); env.add_function("raise_exception", raise_exception);
env.add_function("strftime_now", strftime_now);
// trim all the whitespace // trim all the whitespace
let chat_template = chat_template let chat_template = chat_template
.lines() .lines()

View File

@ -40,6 +40,8 @@ pub trait Backend {
fn start_health(&self) -> bool { fn start_health(&self) -> bool {
false false
} }
fn name(&self) -> &'static str;
} }
/// Inference struct /// Inference struct

View File

@ -79,7 +79,7 @@ impl TokenizerTrait for tokenizers::Tokenizer {
} }
} }
impl<'a> TokenizerTrait for PyTokenizer<'a> { impl TokenizerTrait for PyTokenizer<'_> {
fn encode_trait( fn encode_trait(
&self, &self,
query: String, query: String,
@ -460,7 +460,7 @@ pub struct CompletionRequest {
pub prompt: Prompt, pub prompt: Prompt,
/// The maximum number of tokens that can be generated in the chat completion. /// The maximum number of tokens that can be generated in the chat completion.
#[serde(default, alias = "max_completion_tokens")] #[serde(default)]
#[schema(default = "1024", example = "32")] #[schema(default = "1024", example = "32")]
pub max_tokens: Option<u32>, pub max_tokens: Option<u32>,
@ -837,7 +837,7 @@ pub(crate) struct ChatRequest {
pub top_logprobs: Option<u32>, pub top_logprobs: Option<u32>,
/// The maximum number of tokens that can be generated in the chat completion. /// The maximum number of tokens that can be generated in the chat completion.
#[serde(default)] #[serde(default, alias = "max_completion_tokens")]
#[schema(default = "1024", example = "32")] #[schema(default = "1024", example = "32")]
pub max_tokens: Option<u32>, pub max_tokens: Option<u32>,

View File

@ -54,6 +54,9 @@ use std::fs::File;
use std::io::BufReader; use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error; use thiserror::Error;
use tokio::select; use tokio::select;
use tokio::signal; use tokio::signal;
@ -1846,9 +1849,9 @@ pub async fn run(
HubTokenizerConfig::default() HubTokenizerConfig::default()
}); });
let tokenizer: Tokenizer = { let tokenizer: Result<Tokenizer, WebServerError> = {
use pyo3::prelude::*; use pyo3::prelude::*;
pyo3::Python::with_gil(|py| -> PyResult<()> { Python::with_gil(|py| -> PyResult<()> {
py_resolve_tokenizer(py, &tokenizer_name, revision.as_deref(), trust_remote_code)?; py_resolve_tokenizer(py, &tokenizer_name, revision.as_deref(), trust_remote_code)?;
Ok(()) Ok(())
}) })
@ -1859,16 +1862,16 @@ pub async fn run(
let out = legacy_tokenizer_handle(config_filename.as_ref()); let out = legacy_tokenizer_handle(config_filename.as_ref());
out.ok_or(err) out.ok_or(err)
}) })
.expect("We cannot load a tokenizer"); .map_err(|_| WebServerError::Tokenizer("Unable to load tokenizer.".to_string()))?;
let filename = "out/tokenizer.json"; let filename = "out/tokenizer.json";
if let Ok(tok) = tokenizers::Tokenizer::from_file(filename) { if let Ok(tok) = tokenizers::Tokenizer::from_file(filename) {
Tokenizer::Rust(tok) Ok(Tokenizer::Rust(tok))
} else { } else {
Tokenizer::Python { Ok(Tokenizer::Python {
tokenizer_name: tokenizer_name.clone(), tokenizer_name: tokenizer_name.clone(),
revision: revision.clone(), revision: revision.clone(),
trust_remote_code, trust_remote_code,
} })
} }
}; };
@ -1922,17 +1925,34 @@ pub async fn run(
disable_grammar_support, disable_grammar_support,
max_client_batch_size, max_client_batch_size,
usage_stats_level, usage_stats_level,
backend.name(),
); );
Some(usage_stats::UserAgent::new(reduced_args)) Some(usage_stats::UserAgent::new(reduced_args))
} }
_ => None, _ => None,
}; };
if let Some(ref ua) = user_agent { let stop_usage_thread = Arc::new(AtomicBool::new(false));
let stop_usage_thread_clone = stop_usage_thread.clone();
if let Some(ua) = user_agent.clone() {
let start_event = let start_event =
usage_stats::UsageStatsEvent::new(ua.clone(), usage_stats::EventType::Start, None); usage_stats::UsageStatsEvent::new(ua.clone(), usage_stats::EventType::Start, None);
tokio::spawn(async move { tokio::spawn(async move {
// send start event
start_event.send().await; start_event.send().await;
let mut last_report = Instant::now();
while !stop_usage_thread_clone.load(Ordering::Relaxed) {
if last_report.elapsed() > Duration::from_secs(900) {
let report_event = usage_stats::UsageStatsEvent::new(
ua.clone(),
usage_stats::EventType::Ping,
None,
);
report_event.send().await;
last_report = Instant::now();
}
tokio::time::sleep(Duration::from_secs(1)).await;
}
}); });
}; };
let compat_return_full_text = match &model_info.pipeline_tag { let compat_return_full_text = match &model_info.pipeline_tag {
@ -1953,7 +1973,7 @@ pub async fn run(
validation_workers, validation_workers,
api_key, api_key,
config, config,
(tokenizer, tokenizer_config), (tokenizer?, tokenizer_config),
(preprocessor_config, processor_config), (preprocessor_config, processor_config),
hostname, hostname,
port, port,
@ -1970,6 +1990,7 @@ pub async fn run(
.await; .await;
if let Some(ua) = user_agent { if let Some(ua) = user_agent {
stop_usage_thread.store(true, Ordering::Relaxed);
match result { match result {
Ok(_) => { Ok(_) => {
let stop_event = usage_stats::UsageStatsEvent::new( let stop_event = usage_stats::UsageStatsEvent::new(
@ -2446,8 +2467,13 @@ async fn start(
} }
} else { } else {
// Run server // Run server
let listener = match tokio::net::TcpListener::bind(&addr).await {
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); Ok(listener) => listener,
Err(e) => {
tracing::error!("Failed to bind to {addr}: {e}");
return Err(WebServerError::Axum(Box::new(e)));
}
};
axum::serve(listener, app) axum::serve(listener, app)
.with_graceful_shutdown(shutdown_signal()) .with_graceful_shutdown(shutdown_signal())
.await .await
@ -2562,4 +2588,6 @@ impl From<InferError> for Event {
pub enum WebServerError { pub enum WebServerError {
#[error("Axum error: {0}")] #[error("Axum error: {0}")]
Axum(#[from] axum::BoxError), Axum(#[from] axum::BoxError),
#[error("Tokenizer error: {0}")]
Tokenizer(String),
} }

View File

@ -43,6 +43,7 @@ pub enum EventType {
Start, Start,
Stop, Stop,
Error, Error,
Ping,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
@ -70,7 +71,7 @@ impl UsageStatsEvent {
.post(TELEMETRY_URL) .post(TELEMETRY_URL)
.headers(headers) .headers(headers)
.body(body) .body(body)
.timeout(Duration::from_secs(5)) .timeout(Duration::from_secs(10))
.send() .send()
.await; .await;
} }
@ -96,6 +97,7 @@ pub struct Args {
disable_grammar_support: bool, disable_grammar_support: bool,
max_client_batch_size: usize, max_client_batch_size: usize,
usage_stats_level: UsageStatsLevel, usage_stats_level: UsageStatsLevel,
backend_name: &'static str,
} }
impl Args { impl Args {
@ -119,6 +121,7 @@ impl Args {
disable_grammar_support: bool, disable_grammar_support: bool,
max_client_batch_size: usize, max_client_batch_size: usize,
usage_stats_level: UsageStatsLevel, usage_stats_level: UsageStatsLevel,
backend_name: &'static str,
) -> Self { ) -> Self {
Self { Self {
model_config, model_config,
@ -139,6 +142,7 @@ impl Args {
disable_grammar_support, disable_grammar_support,
max_client_batch_size, max_client_batch_size,
usage_stats_level, usage_stats_level,
backend_name,
} }
} }
} }

View File

@ -1229,12 +1229,11 @@ mod tests {
assert!( assert!(
chunks chunks
== vec![ == vec![
Chunk::Text("test".to_string()).into(), Chunk::Text("test".to_string()),
Chunk::Image(Image { Chunk::Image(Image {
data: pixel_data.clone(), data: pixel_data.clone(),
mimetype: "image/gif".to_string() mimetype: "image/gif".to_string()
}) })
.into()
], ],
"Failed to process images", "Failed to process images",
); );
@ -1289,17 +1288,15 @@ mod tests {
assert!( assert!(
chunks chunks
== vec![ == vec![
Chunk::Text("test".to_string()).into(), Chunk::Text("test".to_string()),
Chunk::Image(Image {
data: pixel_data.clone(),
mimetype: "image/gif".to_string()
}),
Chunk::Image(Image { Chunk::Image(Image {
data: pixel_data.clone(), data: pixel_data.clone(),
mimetype: "image/gif".to_string() mimetype: "image/gif".to_string()
}) })
.into(),
Chunk::Image(Image {
data: pixel_data.clone(),
mimetype: "image/gif".to_string()
})
.into()
], ],
"Failed to process images", "Failed to process images",
); );

View File

@ -1,5 +1,5 @@
[toolchain] [toolchain]
# Released on: June 13, 2024 # Released on: June 13, 2024
# https://releases.rs/docs/1.79.0/ # https://releases.rs/docs/1.79.0/
channel = "1.80.1" channel = "1.84.0"
components = ["rustfmt", "clippy"] components = ["rustfmt", "clippy"]

View File

@ -9,11 +9,21 @@ include Makefile-exllamav2
include Makefile-flashinfer include Makefile-flashinfer
unit-tests: unit-tests:
pip install -U pip uv
uv pip install -e ".[dev]"
pytest -s -vv -m "not private" tests pytest -s -vv -m "not private" tests
gen-server: gen-server:
# Compile protos # Compile protos
pip install grpcio-tools==1.62.2 mypy-protobuf==3.6.0 'types-protobuf' --no-cache-dir pip install -U pip uv
uv pip install -r requirements_gen.txt
mkdir text_generation_server/pb || true
python -m grpc_tools.protoc -I../proto/v3 --python_out=text_generation_server/pb \
--grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb ../proto/v3/generate.proto
find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
touch text_generation_server/pb/__init__.py
gen-server-raw:
mkdir text_generation_server/pb || true mkdir text_generation_server/pb || true
python -m grpc_tools.protoc -I../proto/v3 --python_out=text_generation_server/pb \ python -m grpc_tools.protoc -I../proto/v3 --python_out=text_generation_server/pb \
--grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb ../proto/v3/generate.proto --grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb ../proto/v3/generate.proto
@ -21,24 +31,20 @@ gen-server:
touch text_generation_server/pb/__init__.py touch text_generation_server/pb/__init__.py
install-server: gen-server install-server: gen-server
pip install pip --upgrade uv pip install -e ".[accelerate, compressed-tensors, quantize, peft, outlines]"
pip install -r requirements_cuda.txt
pip install -e ".[accelerate, compressed-tensors, quantize, peft, outlines]"
install: install-cuda install: install-cuda
echo "Installed server" echo "Installed server"
install-cuda: install-server install-flash-attention-v2-cuda install-flash-attention install-cuda: install-server install-flash-attention-v2-cuda install-flash-attention
pip install -e ".[attention,bnb,marlin,moe]" uv pip install -e ".[attention,bnb,marlin,moe]"
pip install nvidia-nccl-cu12==2.22.3 uv pip install nvidia-nccl-cu12==2.22.3
install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm
run-dev:
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
export-requirements: export-requirements:
poetry export -o requirements_cuda.txt --without-hashes uv pip compile pyproject.toml --extra gen -o requirements_gen.txt --python-version 3.11
poetry export -o requirements_rocm.txt --without-hashes uv pip compile pyproject.toml --extra attention --extra bnb --extra accelerate --extra compressed-tensors --extra marlin --extra moe --extra quantize --extra peft --extra outlines -o requirements_cuda.txt --python-version 3.11
poetry export -o requirements_intel.txt --without-hashes uv pip compile pyproject.toml --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines -o requirements_intel.txt --python-version 3.11
uv pip compile pyproject.toml --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines -o requirements_rocm.txt --python-version 3.11

View File

@ -1,5 +1,6 @@
install-flashinfer: install-flashinfer:
# We need fsspec as an additional dependency, but # We need fsspec as an additional dependency, but
# `pip install flashinfer` cannot resolve it. # `pip install flashinfer` cannot resolve it.
pip install fsspec pip install fsspec sympy==1.13.1 numpy
pip install flashinfer==0.2.0.post1 -i https://flashinfer.ai/whl/cu124/torch2.4 pip install -U setuptools
TORCH_CUDA_ARCH_LIST="8.0;8.6;8.9;9.0+PTX" FLASHINFER_ENABLE_AOT=1 pip install git+https://github.com/flashinfer-ai/flashinfer.git@v0.2.0.post1#egg=flashinfer --no-build-isolation

4100
server/poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,96 +1,91 @@
[tool.poetry] [project]
name = "text-generation-server" name = "text-generation-server"
version = "2.0.5-dev0" version = "2.0.5-dev0"
description = "Text Generation Inference Python gRPC Server" description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"] readme = "README.md"
requires-python = ">=3.9"
[tool.poetry.scripts] authors = [
text-generation-server = 'text_generation_server.cli:app' {name = "Olivier Dehaene", email = "olivier@huggingface.co"},
{name = "Nicolas Patry", email = "nicolas@huggingface.co"},
[tool.poetry.dependencies]
python = ">=3.9,<3.13"
protobuf = ">=4.25.3,<6"
grpcio = "^1.51.1"
grpcio-status = "^1.51.1"
grpcio-reflection = "^1.51.1"
grpc-interceptor = "^0.15.4"
typer = "^0.12.5"
accelerate = {version = "^1.1.0", optional = true}
bitsandbytes = { version = "^0.43.0", optional = true }
safetensors = "^0.4.5"
loguru = "^0.7.2"
opentelemetry-api = "^1.27.0"
opentelemetry-exporter-otlp = "^1.27.0"
opentelemetry-instrumentation-grpc = "^0.48b0"
hf-transfer = "^0.1.2"
sentencepiece = "^0.2.0"
tokenizers = "^0.20.3"
huggingface-hub = "^0.23"
transformers = "^4.46.2"
einops = "^0.8.0"
texttable = { version = "^1.6.7", optional = true }
datasets = {version = "^2.21.0", optional = true}
peft = {version = "^0.13.2", optional = true}
torch = {version = "^2.4.1", optional = true}
scipy = "^1.13.1"
pillow = "^11.0.0"
outlines= {version = "^0.1.3", optional = true}
prometheus-client = ">=0.20.0,<0.22"
py-cpuinfo = "^9.0.0"
compressed-tensors = {version = "^0.7.1", optional = true}
# Remove later, temporary workaround for outlines.
numpy = "^1.26.4"
attention-kernels = [
{ url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
{ url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
{ url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
{ url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
] ]
dependencies = [
"einops>=0.8.0",
"grpc-interceptor>=0.15.4",
"grpcio>=1.67.0",
"grpcio-reflection>=1.67.0",
"grpcio-status>=1.67.0",
"hf-transfer>=0.1.8",
"loguru>=0.7.3",
"numpy>=1.26,<3",
"opentelemetry-api>=1.27.0",
"opentelemetry-exporter-otlp>=1.27.0",
"opentelemetry-instrumentation-grpc>=0.50b0",
"pillow>=11.1.0",
"prometheus-client>=0.21.0",
"protobuf>=5.28.3",
"py-cpuinfo>=9.0.0",
"rich>=13.8.1",
"safetensors>=0.4.5",
"scipy>=1.13.1",
"sentencepiece>=0.2.0",
"tokenizers>=0.20.3",
"typer>=0.15.1",
"transformers>=4.48.0"
]
[project.scripts]
text-generation-server = "text_generation_server.cli:app"
[project.optional-dependencies]
accelerate = [
"accelerate>=1.2.1,<2",
]
bnb = [
"bitsandbytes>=0.45.0",
]
compressed-tensors = [
"compressed-tensors>=0.9.0",
]
peft = [
"peft>=0.14.0",
]
outlines = [
"outlines>=0.1.13",
]
dev = [
"grpcio-tools>=1.51.1,<2.0",
"pytest>=7.3.0,<8"
]
quantize = [
"texttable>=1.6.7,<2",
"datasets>=2.21,<3",
]
moe = [ "moe-kernels" ]
attention = [ "attention-kernels" ]
marlin = [ "marlin-kernels" ]
gen = [
"grpcio-tools>=1.69.0",
"mypy-protobuf>=3.6.0",
]
[tool.uv.sources]
attention-kernels.url = "https://github.com/danieldk/attention-kernels/releases/download/v0.2.0.post2/attention_kernels-0.2.0.post2+cu123torch2.5-cp39-abi3-linux_x86_64.whl"
marlin-kernels = [ marlin-kernels = [
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp39-cp39-linux_x86_64.whl", marker = "python_version == '3.9'" },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true }, { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp310-cp310-linux_x86_64.whl", marker = "python_version == '3.10'" },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true }, { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp311-cp311-linux_x86_64.whl", marker = "python_version == '3.11'" },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp312-cp312-linux_x86_64.whl", marker = "python_version == '3.12'" },
] ]
moe-kernels = [ moe-kernels.url = "https://github.com/danieldk/moe-kernels/releases/download/v0.8.2/moe_kernels-0.8.2+cu123torch2.5-cp39-abi3-linux_x86_64.whl"
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.7.0/moe_kernels-0.7.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.7.0/moe_kernels-0.7.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.7.0/moe_kernels-0.7.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.7.0/moe_kernels-0.7.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
]
rich = "^13.8.1"
[tool.poetry.extras]
torch = ["torch"]
accelerate = ["accelerate"]
attention = ["attention-kernels"]
bnb = ["bitsandbytes"]
compressed-tensors = ["compressed-tensors"]
marlin = ["marlin-kernels"]
moe = ["moe-kernels"]
peft = ["peft"]
quantize = ["texttable", "datasets", "accelerate"]
outlines = ["outlines"]
[tool.poetry.group.dev.dependencies]
grpcio-tools = "^1.51.1"
pytest = "^7.3.0"
[[tool.poetry.source]]
name = "pytorch-gpu-src"
url = "https://download.pytorch.org/whl/cu121"
priority = "explicit"
[tool.pytest.ini_options] [tool.pytest.ini_options]
markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"] markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]
[build-system]
requires = [
"poetry-core>=1.0.0",
]
build-backend = "poetry.core.masonry.api"
[tool.isort] [tool.isort]
profile = "black" profile = "black"
[tool.uv]
package = true
[tool.setuptools.packages.find]
include = ["text_generation_server*"]

212
server/req.txt Normal file
View File

@ -0,0 +1,212 @@
# This file was autogenerated by uv via the following command:
# uv pip compile pyproject.toml --extra attention --extra bnb -o req.txt
attention-kernels @ https://github.com/danieldk/attention-kernels/releases/download/v0.2.0.post2/attention_kernels-0.2.0.post2+cu123torch2.5-cp39-abi3-linux_x86_64.whl
# via text-generation-server (pyproject.toml)
bitsandbytes==0.45.1
# via text-generation-server (pyproject.toml)
certifi==2025.1.31
# via requests
charset-normalizer==3.4.1
# via requests
click==8.1.8
# via typer
deprecated==1.2.18
# via
# opentelemetry-api
# opentelemetry-exporter-otlp-proto-grpc
# opentelemetry-exporter-otlp-proto-http
# opentelemetry-semantic-conventions
einops==0.8.0
# via text-generation-server (pyproject.toml)
filelock==3.17.0
# via
# huggingface-hub
# torch
fsspec==2025.2.0
# via
# huggingface-hub
# torch
googleapis-common-protos==1.66.0
# via
# grpcio-status
# opentelemetry-exporter-otlp-proto-grpc
# opentelemetry-exporter-otlp-proto-http
grpc-interceptor==0.15.4
# via text-generation-server (pyproject.toml)
grpcio==1.70.0
# via
# text-generation-server (pyproject.toml)
# grpc-interceptor
# grpcio-reflection
# grpcio-status
# opentelemetry-exporter-otlp-proto-grpc
grpcio-reflection==1.70.0
# via text-generation-server (pyproject.toml)
grpcio-status==1.70.0
# via text-generation-server (pyproject.toml)
hf-transfer==0.1.9
# via text-generation-server (pyproject.toml)
huggingface-hub==0.28.1
# via tokenizers
idna==3.10
# via requests
importlib-metadata==8.5.0
# via opentelemetry-api
jinja2==3.1.5
# via torch
loguru==0.7.3
# via text-generation-server (pyproject.toml)
markdown-it-py==3.0.0
# via rich
markupsafe==3.0.2
# via jinja2
mdurl==0.1.2
# via markdown-it-py
mpmath==1.3.0
# via sympy
networkx==3.4.2
# via torch
numpy==2.2.2
# via
# text-generation-server (pyproject.toml)
# bitsandbytes
# scipy
nvidia-cublas-cu12==12.4.5.8
# via
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
# torch
nvidia-cuda-cupti-cu12==12.4.127
# via torch
nvidia-cuda-nvrtc-cu12==12.4.127
# via torch
nvidia-cuda-runtime-cu12==12.4.127
# via torch
nvidia-cudnn-cu12==9.1.0.70
# via torch
nvidia-cufft-cu12==11.2.1.3
# via torch
nvidia-curand-cu12==10.3.5.147
# via torch
nvidia-cusolver-cu12==11.6.1.9
# via torch
nvidia-cusparse-cu12==12.3.1.170
# via
# nvidia-cusolver-cu12
# torch
nvidia-cusparselt-cu12==0.6.2
# via torch
nvidia-nccl-cu12==2.21.5
# via torch
nvidia-nvjitlink-cu12==12.4.127
# via
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
# torch
nvidia-nvtx-cu12==12.4.127
# via torch
opentelemetry-api==1.30.0
# via
# text-generation-server (pyproject.toml)
# opentelemetry-exporter-otlp-proto-grpc
# opentelemetry-exporter-otlp-proto-http
# opentelemetry-instrumentation
# opentelemetry-instrumentation-grpc
# opentelemetry-sdk
# opentelemetry-semantic-conventions
opentelemetry-exporter-otlp==1.30.0
# via text-generation-server (pyproject.toml)
opentelemetry-exporter-otlp-proto-common==1.30.0
# via
# opentelemetry-exporter-otlp-proto-grpc
# opentelemetry-exporter-otlp-proto-http
opentelemetry-exporter-otlp-proto-grpc==1.30.0
# via opentelemetry-exporter-otlp
opentelemetry-exporter-otlp-proto-http==1.30.0
# via opentelemetry-exporter-otlp
opentelemetry-instrumentation==0.51b0
# via opentelemetry-instrumentation-grpc
opentelemetry-instrumentation-grpc==0.51b0
# via text-generation-server (pyproject.toml)
opentelemetry-proto==1.30.0
# via
# opentelemetry-exporter-otlp-proto-common
# opentelemetry-exporter-otlp-proto-grpc
# opentelemetry-exporter-otlp-proto-http
opentelemetry-sdk==1.30.0
# via
# opentelemetry-exporter-otlp-proto-grpc
# opentelemetry-exporter-otlp-proto-http
opentelemetry-semantic-conventions==0.51b0
# via
# opentelemetry-instrumentation
# opentelemetry-instrumentation-grpc
# opentelemetry-sdk
packaging==24.2
# via
# huggingface-hub
# opentelemetry-instrumentation
pillow==11.1.0
# via text-generation-server (pyproject.toml)
prometheus-client==0.21.1
# via text-generation-server (pyproject.toml)
protobuf==5.29.3
# via
# text-generation-server (pyproject.toml)
# googleapis-common-protos
# grpcio-reflection
# grpcio-status
# opentelemetry-proto
py-cpuinfo==9.0.0
# via text-generation-server (pyproject.toml)
pygments==2.19.1
# via rich
pyyaml==6.0.2
# via huggingface-hub
requests==2.32.3
# via
# huggingface-hub
# opentelemetry-exporter-otlp-proto-http
rich==13.9.4
# via
# text-generation-server (pyproject.toml)
# typer
safetensors==0.5.2
# via text-generation-server (pyproject.toml)
scipy==1.15.1
# via text-generation-server (pyproject.toml)
sentencepiece==0.2.0
# via text-generation-server (pyproject.toml)
setuptools==75.8.0
# via torch
shellingham==1.5.4
# via typer
sympy==1.13.1
# via torch
tokenizers==0.21.0
# via text-generation-server (pyproject.toml)
torch==2.6.0
# via
# attention-kernels
# bitsandbytes
tqdm==4.67.1
# via huggingface-hub
triton==3.2.0
# via torch
typer==0.15.1
# via text-generation-server (pyproject.toml)
typing-extensions==4.12.2
# via
# huggingface-hub
# opentelemetry-sdk
# torch
# typer
urllib3==2.3.0
# via requests
wrapt==1.17.2
# via
# deprecated
# opentelemetry-instrumentation
# opentelemetry-instrumentation-grpc
zipp==3.21.0
# via importlib-metadata

View File

@ -1,55 +1,384 @@
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13" # This file was autogenerated by uv via the following command:
charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13" # uv pip compile pyproject.toml --extra attention --extra bnb --extra accelerate --extra compressed-tensors --extra marlin --extra moe --extra quantize --extra peft --extra outlines -o requirements_cuda.txt --python-version 3.11
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" accelerate==1.3.0
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") # via
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" # text-generation-server (pyproject.toml)
einops==0.8.0 ; python_version >= "3.9" and python_version < "3.13" # peft
filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13" aiohappyeyeballs==2.4.4
fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13" # via aiohttp
googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13" aiohttp==3.11.11
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" # via
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13" # datasets
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13" # fsspec
grpcio==1.68.0 ; python_version >= "3.9" and python_version < "3.13" aiosignal==1.3.2
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13" # via aiohttp
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" airportsdata==20241001
idna==3.10 ; python_version >= "3.9" and python_version < "3.13" # via outlines
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" annotated-types==0.7.0
loguru==0.7.2 ; python_version >= "3.9" and python_version < "3.13" # via pydantic
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13" attention-kernels @ https://github.com/danieldk/attention-kernels/releases/download/v0.2.0.post2/attention_kernels-0.2.0.post2+cu123torch2.5-cp39-abi3-linux_x86_64.whl
mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13" # via text-generation-server (pyproject.toml)
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" attrs==25.1.0
opentelemetry-api==1.27.0 ; python_version >= "3.9" and python_version < "3.13" # via
opentelemetry-exporter-otlp-proto-common==1.27.0 ; python_version >= "3.9" and python_version < "3.13" # aiohttp
opentelemetry-exporter-otlp-proto-grpc==1.27.0 ; python_version >= "3.9" and python_version < "3.13" # jsonschema
opentelemetry-exporter-otlp-proto-http==1.27.0 ; python_version >= "3.9" and python_version < "3.13" # referencing
opentelemetry-exporter-otlp==1.27.0 ; python_version >= "3.9" and python_version < "3.13" bitsandbytes==0.45.1
opentelemetry-instrumentation-grpc==0.48b0 ; python_version >= "3.9" and python_version < "3.13" # via text-generation-server (pyproject.toml)
opentelemetry-instrumentation==0.48b0 ; python_version >= "3.9" and python_version < "3.13" certifi==2024.8.30
opentelemetry-proto==1.27.0 ; python_version >= "3.9" and python_version < "3.13" # via requests
opentelemetry-sdk==1.27.0 ; python_version >= "3.9" and python_version < "3.13" charset-normalizer==3.4.0
opentelemetry-semantic-conventions==0.48b0 ; python_version >= "3.9" and python_version < "3.13" # via requests
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13" click==8.1.7
pillow==11.0.0 ; python_version >= "3.9" and python_version < "3.13" # via typer
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" cloudpickle==3.1.1
protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13" # via outlines
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" compressed-tensors==0.9.1
pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13" # via text-generation-server (pyproject.toml)
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13" datasets==2.21.0
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13" # via text-generation-server (pyproject.toml)
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" deprecated==1.2.14
rich==13.9.4 ; python_version >= "3.9" and python_version < "3.13" # via
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13" # opentelemetry-api
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" # opentelemetry-exporter-otlp-proto-grpc
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13" # opentelemetry-exporter-otlp-proto-http
setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13" # opentelemetry-semantic-conventions
shellingham==1.5.4 ; python_version >= "3.9" and python_version < "3.13" dill==0.3.8
tokenizers==0.20.3 ; python_version >= "3.9" and python_version < "3.13" # via
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13" # datasets
transformers==4.46.3 ; python_version >= "3.9" and python_version < "3.13" # multiprocess
typer==0.12.5 ; python_version >= "3.9" and python_version < "3.13" diskcache==5.6.3
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13" # via outlines
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13" einops==0.8.0
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" # via text-generation-server (pyproject.toml)
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13" filelock==3.16.1
zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13" # via
# datasets
# huggingface-hub
# torch
# transformers
frozenlist==1.5.0
# via
# aiohttp
# aiosignal
fsspec==2024.6.1
# via
# datasets
# huggingface-hub
# torch
genson==1.3.0
# via outlines
googleapis-common-protos==1.65.0
# via
# grpcio-status
# opentelemetry-exporter-otlp-proto-grpc
# opentelemetry-exporter-otlp-proto-http
grpc-interceptor==0.15.4
# via text-generation-server (pyproject.toml)
grpcio==1.68.0
# via
# text-generation-server (pyproject.toml)
# grpc-interceptor
# grpcio-reflection
# grpcio-status
# opentelemetry-exporter-otlp-proto-grpc
grpcio-reflection==1.68.0
# via text-generation-server (pyproject.toml)
grpcio-status==1.68.0
# via text-generation-server (pyproject.toml)
hf-transfer==0.1.8
# via text-generation-server (pyproject.toml)
huggingface-hub==0.28.1
# via
# accelerate
# datasets
# peft
# tokenizers
# transformers
idna==3.10
# via
# requests
# yarl
importlib-metadata==7.1.0
# via opentelemetry-api
interegular==0.3.3
# via
# outlines
# outlines-core
jinja2==3.1.5
# via
# outlines
# torch
jsonschema==4.23.0
# via
# outlines
# outlines-core
jsonschema-specifications==2024.10.1
# via jsonschema
lark==1.2.2
# via outlines
loguru==0.7.3
# via text-generation-server (pyproject.toml)
markdown-it-py==3.0.0
# via rich
markupsafe==3.0.2
# via jinja2
marlin-kernels @ https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp311-cp311-linux_x86_64.whl
# via text-generation-server (pyproject.toml)
mdurl==0.1.2
# via markdown-it-py
moe-kernels @ https://github.com/danieldk/moe-kernels/releases/download/v0.8.2/moe_kernels-0.8.2+cu123torch2.5-cp39-abi3-linux_x86_64.whl
# via text-generation-server (pyproject.toml)
mpmath==1.3.0
# via sympy
multidict==6.1.0
# via
# aiohttp
# yarl
multiprocess==0.70.16
# via datasets
nest-asyncio==1.6.0
# via outlines
networkx==3.4.2
# via torch
numpy==1.26.4
# via
# text-generation-server (pyproject.toml)
# accelerate
# bitsandbytes
# datasets
# outlines
# pandas
# peft
# scipy
# transformers
nvidia-cublas-cu12==12.4.5.8
# via
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
# torch
nvidia-cuda-cupti-cu12==12.4.127
# via torch
nvidia-cuda-nvrtc-cu12==12.4.127
# via torch
nvidia-cuda-runtime-cu12==12.4.127
# via torch
nvidia-cudnn-cu12==9.1.0.70
# via torch
nvidia-cufft-cu12==11.2.1.3
# via torch
nvidia-curand-cu12==10.3.5.147
# via torch
nvidia-cusolver-cu12==11.6.1.9
# via torch
nvidia-cusparse-cu12==12.3.1.170
# via
# nvidia-cusolver-cu12
# torch
nvidia-cusparselt-cu12==0.6.2
# via torch
nvidia-ml-py==12.570.86
# via moe-kernels
nvidia-nccl-cu12==2.21.5
# via torch
nvidia-nvjitlink-cu12==12.4.127
# via
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
# torch
nvidia-nvtx-cu12==12.4.127
# via torch
opentelemetry-api==1.30.0
# via
# text-generation-server (pyproject.toml)
# opentelemetry-exporter-otlp-proto-grpc
# opentelemetry-exporter-otlp-proto-http
# opentelemetry-instrumentation
# opentelemetry-instrumentation-grpc
# opentelemetry-sdk
# opentelemetry-semantic-conventions
opentelemetry-exporter-otlp==1.30.0
# via text-generation-server (pyproject.toml)
opentelemetry-exporter-otlp-proto-common==1.30.0
# via
# opentelemetry-exporter-otlp-proto-grpc
# opentelemetry-exporter-otlp-proto-http
opentelemetry-exporter-otlp-proto-grpc==1.30.0
# via opentelemetry-exporter-otlp
opentelemetry-exporter-otlp-proto-http==1.30.0
# via opentelemetry-exporter-otlp
opentelemetry-instrumentation==0.51b0
# via opentelemetry-instrumentation-grpc
opentelemetry-instrumentation-grpc==0.51b0
# via text-generation-server (pyproject.toml)
opentelemetry-proto==1.30.0
# via
# opentelemetry-exporter-otlp-proto-common
# opentelemetry-exporter-otlp-proto-grpc
# opentelemetry-exporter-otlp-proto-http
opentelemetry-sdk==1.30.0
# via
# opentelemetry-exporter-otlp-proto-grpc
# opentelemetry-exporter-otlp-proto-http
opentelemetry-semantic-conventions==0.51b0
# via
# opentelemetry-instrumentation
# opentelemetry-instrumentation-grpc
# opentelemetry-sdk
outlines==0.1.14
# via text-generation-server (pyproject.toml)
outlines-core==0.1.26
# via outlines
packaging==24.1
# via
# accelerate
# datasets
# huggingface-hub
# opentelemetry-instrumentation
# peft
# transformers
pandas==2.2.3
# via datasets
peft==0.14.0
# via text-generation-server (pyproject.toml)
pillow==11.1.0
# via text-generation-server (pyproject.toml)
prometheus-client==0.21.1
# via text-generation-server (pyproject.toml)
propcache==0.2.1
# via
# aiohttp
# yarl
protobuf==5.29.3
# via
# text-generation-server (pyproject.toml)
# googleapis-common-protos
# grpcio-reflection
# grpcio-status
# opentelemetry-proto
psutil==6.1.1
# via
# accelerate
# peft
py-cpuinfo==9.0.0
# via text-generation-server (pyproject.toml)
pyarrow==19.0.0
# via datasets
pycountry==24.6.1
# via outlines
pydantic==2.10.6
# via
# compressed-tensors
# outlines
pydantic-core==2.27.2
# via pydantic
pygments==2.18.0
# via rich
python-dateutil==2.9.0.post0
# via pandas
pytz==2025.1
# via pandas
pyyaml==6.0.2
# via
# accelerate
# datasets
# huggingface-hub
# peft
# transformers
referencing==0.36.2
# via
# jsonschema
# jsonschema-specifications
# outlines
regex==2024.9.11
# via transformers
requests==2.32.3
# via
# datasets
# huggingface-hub
# opentelemetry-exporter-otlp-proto-http
# outlines
# transformers
rich==13.9.4
# via
# text-generation-server (pyproject.toml)
# typer
rpds-py==0.22.3
# via
# jsonschema
# referencing
safetensors==0.4.5
# via
# text-generation-server (pyproject.toml)
# accelerate
# peft
# transformers
scipy==1.13.1
# via text-generation-server (pyproject.toml)
sentencepiece==0.2.0
# via text-generation-server (pyproject.toml)
shellingham==1.5.4
# via typer
six==1.17.0
# via python-dateutil
sympy==1.13.1
# via torch
texttable==1.7.0
# via text-generation-server (pyproject.toml)
tokenizers==0.21.0
# via
# text-generation-server (pyproject.toml)
# transformers
torch==2.6.0
# via
# accelerate
# attention-kernels
# bitsandbytes
# compressed-tensors
# marlin-kernels
# moe-kernels
# outlines
# peft
tqdm==4.66.5
# via
# datasets
# huggingface-hub
# outlines
# peft
# transformers
transformers==4.48.2
# via
# text-generation-server (pyproject.toml)
# compressed-tensors
# peft
triton==3.2.0
# via
# moe-kernels
# torch
typer==0.15.1
# via text-generation-server (pyproject.toml)
typing-extensions==4.12.2
# via
# huggingface-hub
# opentelemetry-sdk
# outlines
# pydantic
# pydantic-core
# referencing
# torch
# typer
tzdata==2025.1
# via pandas
urllib3==2.2.3
# via requests
wrapt==1.16.0
# via
# deprecated
# opentelemetry-instrumentation
# opentelemetry-instrumentation-grpc
xxhash==3.5.0
# via datasets
yarl==1.18.3
# via aiohttp
zipp==3.20.2
# via importlib-metadata

180
server/requirements_gen.txt Normal file
View File

@ -0,0 +1,180 @@
# This file was autogenerated by uv via the following command:
# uv pip compile pyproject.toml --extra gen -o requirements_gen.txt --python-version 3.11
certifi==2025.1.31
# via requests
charset-normalizer==3.4.1
# via requests
click==8.1.8
# via typer
deprecated==1.2.18
# via
# opentelemetry-api
# opentelemetry-exporter-otlp-proto-grpc
# opentelemetry-exporter-otlp-proto-http
# opentelemetry-semantic-conventions
einops==0.8.0
# via text-generation-server (pyproject.toml)
filelock==3.17.0
# via
# huggingface-hub
# transformers
fsspec==2025.2.0
# via huggingface-hub
googleapis-common-protos==1.66.0
# via
# grpcio-status
# opentelemetry-exporter-otlp-proto-grpc
# opentelemetry-exporter-otlp-proto-http
grpc-interceptor==0.15.4
# via text-generation-server (pyproject.toml)
grpcio==1.70.0
# via
# text-generation-server (pyproject.toml)
# grpc-interceptor
# grpcio-reflection
# grpcio-status
# grpcio-tools
# opentelemetry-exporter-otlp-proto-grpc
grpcio-reflection==1.70.0
# via text-generation-server (pyproject.toml)
grpcio-status==1.70.0
# via text-generation-server (pyproject.toml)
grpcio-tools==1.70.0
# via text-generation-server (pyproject.toml)
hf-transfer==0.1.9
# via text-generation-server (pyproject.toml)
huggingface-hub==0.28.1
# via
# tokenizers
# transformers
idna==3.10
# via requests
importlib-metadata==8.5.0
# via opentelemetry-api
loguru==0.7.3
# via text-generation-server (pyproject.toml)
markdown-it-py==3.0.0
# via rich
mdurl==0.1.2
# via markdown-it-py
mypy-protobuf==3.6.0
# via text-generation-server (pyproject.toml)
numpy==2.2.2
# via
# text-generation-server (pyproject.toml)
# scipy
# transformers
opentelemetry-api==1.30.0
# via
# text-generation-server (pyproject.toml)
# opentelemetry-exporter-otlp-proto-grpc
# opentelemetry-exporter-otlp-proto-http
# opentelemetry-instrumentation
# opentelemetry-instrumentation-grpc
# opentelemetry-sdk
# opentelemetry-semantic-conventions
opentelemetry-exporter-otlp==1.30.0
# via text-generation-server (pyproject.toml)
opentelemetry-exporter-otlp-proto-common==1.30.0
# via
# opentelemetry-exporter-otlp-proto-grpc
# opentelemetry-exporter-otlp-proto-http
opentelemetry-exporter-otlp-proto-grpc==1.30.0
# via opentelemetry-exporter-otlp
opentelemetry-exporter-otlp-proto-http==1.30.0
# via opentelemetry-exporter-otlp
opentelemetry-instrumentation==0.51b0
# via opentelemetry-instrumentation-grpc
opentelemetry-instrumentation-grpc==0.51b0
# via text-generation-server (pyproject.toml)
opentelemetry-proto==1.30.0
# via
# opentelemetry-exporter-otlp-proto-common
# opentelemetry-exporter-otlp-proto-grpc
# opentelemetry-exporter-otlp-proto-http
opentelemetry-sdk==1.30.0
# via
# opentelemetry-exporter-otlp-proto-grpc
# opentelemetry-exporter-otlp-proto-http
opentelemetry-semantic-conventions==0.51b0
# via
# opentelemetry-instrumentation
# opentelemetry-instrumentation-grpc
# opentelemetry-sdk
packaging==24.2
# via
# huggingface-hub
# opentelemetry-instrumentation
# transformers
pillow==11.1.0
# via text-generation-server (pyproject.toml)
prometheus-client==0.21.1
# via text-generation-server (pyproject.toml)
protobuf==5.29.3
# via
# text-generation-server (pyproject.toml)
# googleapis-common-protos
# grpcio-reflection
# grpcio-status
# grpcio-tools
# mypy-protobuf
# opentelemetry-proto
py-cpuinfo==9.0.0
# via text-generation-server (pyproject.toml)
pygments==2.19.1
# via rich
pyyaml==6.0.2
# via
# huggingface-hub
# transformers
regex==2024.11.6
# via transformers
requests==2.32.3
# via
# huggingface-hub
# opentelemetry-exporter-otlp-proto-http
# transformers
rich==13.9.4
# via
# text-generation-server (pyproject.toml)
# typer
safetensors==0.5.2
# via
# text-generation-server (pyproject.toml)
# transformers
scipy==1.15.1
# via text-generation-server (pyproject.toml)
sentencepiece==0.2.0
# via text-generation-server (pyproject.toml)
setuptools==75.8.0
# via grpcio-tools
shellingham==1.5.4
# via typer
tokenizers==0.21.0
# via
# text-generation-server (pyproject.toml)
# transformers
tqdm==4.67.1
# via
# huggingface-hub
# transformers
transformers==4.48.2
# via text-generation-server (pyproject.toml)
typer==0.15.1
# via text-generation-server (pyproject.toml)
types-protobuf==5.29.1.20241207
# via mypy-protobuf
typing-extensions==4.12.2
# via
# huggingface-hub
# opentelemetry-sdk
# typer
urllib3==2.3.0
# via requests
wrapt==1.17.2
# via
# deprecated
# opentelemetry-instrumentation
# opentelemetry-instrumentation-grpc
zipp==3.21.0
# via importlib-metadata

View File

@ -1,55 +1,367 @@
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13" # This file was autogenerated by uv via the following command:
charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13" # uv pip compile pyproject.toml --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines -o requirements_intel.txt --python-version 3.11
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" accelerate==1.3.0
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") # via
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" # text-generation-server (pyproject.toml)
einops==0.8.0 ; python_version >= "3.9" and python_version < "3.13" # peft
filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13" aiohappyeyeballs==2.4.4
fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13" # via aiohttp
googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13" aiohttp==3.11.11
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" # via
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13" # datasets
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13" # fsspec
grpcio==1.68.0 ; python_version >= "3.9" and python_version < "3.13" aiosignal==1.3.2
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13" # via aiohttp
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" airportsdata==20241001
idna==3.10 ; python_version >= "3.9" and python_version < "3.13" # via outlines
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" annotated-types==0.7.0
loguru==0.7.2 ; python_version >= "3.9" and python_version < "3.13" # via pydantic
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13" attrs==25.1.0
mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13" # via
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" # aiohttp
opentelemetry-api==1.27.0 ; python_version >= "3.9" and python_version < "3.13" # jsonschema
opentelemetry-exporter-otlp-proto-common==1.27.0 ; python_version >= "3.9" and python_version < "3.13" # referencing
opentelemetry-exporter-otlp-proto-grpc==1.27.0 ; python_version >= "3.9" and python_version < "3.13" certifi==2024.8.30
opentelemetry-exporter-otlp-proto-http==1.27.0 ; python_version >= "3.9" and python_version < "3.13" # via requests
opentelemetry-exporter-otlp==1.27.0 ; python_version >= "3.9" and python_version < "3.13" charset-normalizer==3.4.0
opentelemetry-instrumentation-grpc==0.48b0 ; python_version >= "3.9" and python_version < "3.13" # via requests
opentelemetry-instrumentation==0.48b0 ; python_version >= "3.9" and python_version < "3.13" click==8.1.7
opentelemetry-proto==1.27.0 ; python_version >= "3.9" and python_version < "3.13" # via typer
opentelemetry-sdk==1.27.0 ; python_version >= "3.9" and python_version < "3.13" cloudpickle==3.1.1
opentelemetry-semantic-conventions==0.48b0 ; python_version >= "3.9" and python_version < "3.13" # via outlines
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13" compressed-tensors==0.9.1
pillow==11.0.0 ; python_version >= "3.9" and python_version < "3.13" # via text-generation-server (pyproject.toml)
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" datasets==2.21.0
protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13" # via text-generation-server (pyproject.toml)
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" deprecated==1.2.14
pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13" # via
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13" # opentelemetry-api
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13" # opentelemetry-exporter-otlp-proto-grpc
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" # opentelemetry-exporter-otlp-proto-http
rich==13.9.4 ; python_version >= "3.9" and python_version < "3.13" # opentelemetry-semantic-conventions
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13" dill==0.3.8
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" # via
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13" # datasets
setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13" # multiprocess
shellingham==1.5.4 ; python_version >= "3.9" and python_version < "3.13" diskcache==5.6.3
tokenizers==0.20.3 ; python_version >= "3.9" and python_version < "3.13" # via outlines
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13" einops==0.8.0
transformers==4.46.3 ; python_version >= "3.9" and python_version < "3.13" # via text-generation-server (pyproject.toml)
typer==0.12.5 ; python_version >= "3.9" and python_version < "3.13" filelock==3.16.1
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13" # via
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13" # datasets
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" # huggingface-hub
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13" # torch
zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13" # transformers
frozenlist==1.5.0
# via
# aiohttp
# aiosignal
fsspec==2024.6.1
# via
# datasets
# huggingface-hub
# torch
genson==1.3.0
# via outlines
googleapis-common-protos==1.65.0
# via
# grpcio-status
# opentelemetry-exporter-otlp-proto-grpc
# opentelemetry-exporter-otlp-proto-http
grpc-interceptor==0.15.4
# via text-generation-server (pyproject.toml)
grpcio==1.68.0
# via
# text-generation-server (pyproject.toml)
# grpc-interceptor
# grpcio-reflection
# grpcio-status
# opentelemetry-exporter-otlp-proto-grpc
grpcio-reflection==1.68.0
# via text-generation-server (pyproject.toml)
grpcio-status==1.68.0
# via text-generation-server (pyproject.toml)
hf-transfer==0.1.8
# via text-generation-server (pyproject.toml)
huggingface-hub==0.28.1
# via
# accelerate
# datasets
# peft
# tokenizers
# transformers
idna==3.10
# via
# requests
# yarl
importlib-metadata==7.1.0
# via opentelemetry-api
interegular==0.3.3
# via
# outlines
# outlines-core
jinja2==3.1.5
# via
# outlines
# torch
jsonschema==4.23.0
# via
# outlines
# outlines-core
jsonschema-specifications==2024.10.1
# via jsonschema
lark==1.2.2
# via outlines
loguru==0.7.3
# via text-generation-server (pyproject.toml)
markdown-it-py==3.0.0
# via rich
markupsafe==3.0.2
# via jinja2
mdurl==0.1.2
# via markdown-it-py
mpmath==1.3.0
# via sympy
multidict==6.1.0
# via
# aiohttp
# yarl
multiprocess==0.70.16
# via datasets
nest-asyncio==1.6.0
# via outlines
networkx==3.4.2
# via torch
numpy==1.26.4
# via
# text-generation-server (pyproject.toml)
# accelerate
# datasets
# outlines
# pandas
# peft
# scipy
# transformers
nvidia-cublas-cu12==12.4.5.8
# via
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
# torch
nvidia-cuda-cupti-cu12==12.4.127
# via torch
nvidia-cuda-nvrtc-cu12==12.4.127
# via torch
nvidia-cuda-runtime-cu12==12.4.127
# via torch
nvidia-cudnn-cu12==9.1.0.70
# via torch
nvidia-cufft-cu12==11.2.1.3
# via torch
nvidia-curand-cu12==10.3.5.147
# via torch
nvidia-cusolver-cu12==11.6.1.9
# via torch
nvidia-cusparse-cu12==12.3.1.170
# via
# nvidia-cusolver-cu12
# torch
nvidia-cusparselt-cu12==0.6.2
# via torch
nvidia-nccl-cu12==2.21.5
# via torch
nvidia-nvjitlink-cu12==12.4.127
# via
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
# torch
nvidia-nvtx-cu12==12.4.127
# via torch
opentelemetry-api==1.30.0
# via
# text-generation-server (pyproject.toml)
# opentelemetry-exporter-otlp-proto-grpc
# opentelemetry-exporter-otlp-proto-http
# opentelemetry-instrumentation
# opentelemetry-instrumentation-grpc
# opentelemetry-sdk
# opentelemetry-semantic-conventions
opentelemetry-exporter-otlp==1.30.0
# via text-generation-server (pyproject.toml)
opentelemetry-exporter-otlp-proto-common==1.30.0
# via
# opentelemetry-exporter-otlp-proto-grpc
# opentelemetry-exporter-otlp-proto-http
opentelemetry-exporter-otlp-proto-grpc==1.30.0
# via opentelemetry-exporter-otlp
opentelemetry-exporter-otlp-proto-http==1.30.0
# via opentelemetry-exporter-otlp
opentelemetry-instrumentation==0.51b0
# via opentelemetry-instrumentation-grpc
opentelemetry-instrumentation-grpc==0.51b0
# via text-generation-server (pyproject.toml)
opentelemetry-proto==1.30.0
# via
# opentelemetry-exporter-otlp-proto-common
# opentelemetry-exporter-otlp-proto-grpc
# opentelemetry-exporter-otlp-proto-http
opentelemetry-sdk==1.30.0
# via
# opentelemetry-exporter-otlp-proto-grpc
# opentelemetry-exporter-otlp-proto-http
opentelemetry-semantic-conventions==0.51b0
# via
# opentelemetry-instrumentation
# opentelemetry-instrumentation-grpc
# opentelemetry-sdk
outlines==0.1.14
# via text-generation-server (pyproject.toml)
outlines-core==0.1.26
# via outlines
packaging==24.1
# via
# accelerate
# datasets
# huggingface-hub
# opentelemetry-instrumentation
# peft
# transformers
pandas==2.2.3
# via datasets
peft==0.14.0
# via text-generation-server (pyproject.toml)
pillow==11.1.0
# via text-generation-server (pyproject.toml)
prometheus-client==0.21.1
# via text-generation-server (pyproject.toml)
propcache==0.2.1
# via
# aiohttp
# yarl
protobuf==5.29.3
# via
# text-generation-server (pyproject.toml)
# googleapis-common-protos
# grpcio-reflection
# grpcio-status
# opentelemetry-proto
psutil==6.1.1
# via
# accelerate
# peft
py-cpuinfo==9.0.0
# via text-generation-server (pyproject.toml)
pyarrow==19.0.0
# via datasets
pycountry==24.6.1
# via outlines
pydantic==2.10.6
# via
# compressed-tensors
# outlines
pydantic-core==2.27.2
# via pydantic
pygments==2.18.0
# via rich
python-dateutil==2.9.0.post0
# via pandas
pytz==2025.1
# via pandas
pyyaml==6.0.2
# via
# accelerate
# datasets
# huggingface-hub
# peft
# transformers
referencing==0.36.2
# via
# jsonschema
# jsonschema-specifications
# outlines
regex==2024.9.11
# via transformers
requests==2.32.3
# via
# datasets
# huggingface-hub
# opentelemetry-exporter-otlp-proto-http
# outlines
# transformers
rich==13.9.4
# via
# text-generation-server (pyproject.toml)
# typer
rpds-py==0.22.3
# via
# jsonschema
# referencing
safetensors==0.4.5
# via
# text-generation-server (pyproject.toml)
# accelerate
# peft
# transformers
scipy==1.13.1
# via text-generation-server (pyproject.toml)
sentencepiece==0.2.0
# via text-generation-server (pyproject.toml)
shellingham==1.5.4
# via typer
six==1.17.0
# via python-dateutil
sympy==1.13.1
# via torch
texttable==1.7.0
# via text-generation-server (pyproject.toml)
tokenizers==0.21.0
# via
# text-generation-server (pyproject.toml)
# transformers
torch==2.6.0
# via
# accelerate
# compressed-tensors
# outlines
# peft
tqdm==4.66.5
# via
# datasets
# huggingface-hub
# outlines
# peft
# transformers
transformers==4.48.2
# via
# text-generation-server (pyproject.toml)
# compressed-tensors
# peft
triton==3.2.0
# via torch
typer==0.15.1
# via text-generation-server (pyproject.toml)
typing-extensions==4.12.2
# via
# huggingface-hub
# opentelemetry-sdk
# outlines
# pydantic
# pydantic-core
# referencing
# torch
# typer
tzdata==2025.1
# via pandas
urllib3==2.2.3
# via requests
wrapt==1.16.0
# via
# deprecated
# opentelemetry-instrumentation
# opentelemetry-instrumentation-grpc
xxhash==3.5.0
# via datasets
yarl==1.18.3
# via aiohttp
zipp==3.20.2
# via importlib-metadata

View File

@ -1,55 +1,367 @@
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13" # This file was autogenerated by uv via the following command:
charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13" # uv pip compile pyproject.toml --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines -o requirements_rocm.txt --python-version 3.11
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" accelerate==1.3.0
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") # via
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" # text-generation-server (pyproject.toml)
einops==0.8.0 ; python_version >= "3.9" and python_version < "3.13" # peft
filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13" aiohappyeyeballs==2.4.4
fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13" # via aiohttp
googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13" aiohttp==3.11.11
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" # via
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13" # datasets
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13" # fsspec
grpcio==1.68.0 ; python_version >= "3.9" and python_version < "3.13" aiosignal==1.3.2
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13" # via aiohttp
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" airportsdata==20241001
idna==3.10 ; python_version >= "3.9" and python_version < "3.13" # via outlines
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" annotated-types==0.7.0
loguru==0.7.2 ; python_version >= "3.9" and python_version < "3.13" # via pydantic
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13" attrs==25.1.0
mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13" # via
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" # aiohttp
opentelemetry-api==1.27.0 ; python_version >= "3.9" and python_version < "3.13" # jsonschema
opentelemetry-exporter-otlp-proto-common==1.27.0 ; python_version >= "3.9" and python_version < "3.13" # referencing
opentelemetry-exporter-otlp-proto-grpc==1.27.0 ; python_version >= "3.9" and python_version < "3.13" certifi==2024.8.30
opentelemetry-exporter-otlp-proto-http==1.27.0 ; python_version >= "3.9" and python_version < "3.13" # via requests
opentelemetry-exporter-otlp==1.27.0 ; python_version >= "3.9" and python_version < "3.13" charset-normalizer==3.4.0
opentelemetry-instrumentation-grpc==0.48b0 ; python_version >= "3.9" and python_version < "3.13" # via requests
opentelemetry-instrumentation==0.48b0 ; python_version >= "3.9" and python_version < "3.13" click==8.1.7
opentelemetry-proto==1.27.0 ; python_version >= "3.9" and python_version < "3.13" # via typer
opentelemetry-sdk==1.27.0 ; python_version >= "3.9" and python_version < "3.13" cloudpickle==3.1.1
opentelemetry-semantic-conventions==0.48b0 ; python_version >= "3.9" and python_version < "3.13" # via outlines
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13" compressed-tensors==0.9.1
pillow==11.0.0 ; python_version >= "3.9" and python_version < "3.13" # via text-generation-server (pyproject.toml)
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" datasets==2.21.0
protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13" # via text-generation-server (pyproject.toml)
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" deprecated==1.2.14
pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13" # via
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13" # opentelemetry-api
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13" # opentelemetry-exporter-otlp-proto-grpc
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" # opentelemetry-exporter-otlp-proto-http
rich==13.9.4 ; python_version >= "3.9" and python_version < "3.13" # opentelemetry-semantic-conventions
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13" dill==0.3.8
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" # via
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13" # datasets
setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13" # multiprocess
shellingham==1.5.4 ; python_version >= "3.9" and python_version < "3.13" diskcache==5.6.3
tokenizers==0.20.3 ; python_version >= "3.9" and python_version < "3.13" # via outlines
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13" einops==0.8.0
transformers==4.46.3 ; python_version >= "3.9" and python_version < "3.13" # via text-generation-server (pyproject.toml)
typer==0.12.5 ; python_version >= "3.9" and python_version < "3.13" filelock==3.16.1
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13" # via
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13" # datasets
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" # huggingface-hub
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13" # torch
zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13" # transformers
frozenlist==1.5.0
# via
# aiohttp
# aiosignal
fsspec==2024.6.1
# via
# datasets
# huggingface-hub
# torch
genson==1.3.0
# via outlines
googleapis-common-protos==1.65.0
# via
# grpcio-status
# opentelemetry-exporter-otlp-proto-grpc
# opentelemetry-exporter-otlp-proto-http
grpc-interceptor==0.15.4
# via text-generation-server (pyproject.toml)
grpcio==1.68.0
# via
# text-generation-server (pyproject.toml)
# grpc-interceptor
# grpcio-reflection
# grpcio-status
# opentelemetry-exporter-otlp-proto-grpc
grpcio-reflection==1.68.0
# via text-generation-server (pyproject.toml)
grpcio-status==1.68.0
# via text-generation-server (pyproject.toml)
hf-transfer==0.1.8
# via text-generation-server (pyproject.toml)
huggingface-hub==0.28.1
# via
# accelerate
# datasets
# peft
# tokenizers
# transformers
idna==3.10
# via
# requests
# yarl
importlib-metadata==7.1.0
# via opentelemetry-api
interegular==0.3.3
# via
# outlines
# outlines-core
jinja2==3.1.5
# via
# outlines
# torch
jsonschema==4.23.0
# via
# outlines
# outlines-core
jsonschema-specifications==2024.10.1
# via jsonschema
lark==1.2.2
# via outlines
loguru==0.7.3
# via text-generation-server (pyproject.toml)
markdown-it-py==3.0.0
# via rich
markupsafe==3.0.2
# via jinja2
mdurl==0.1.2
# via markdown-it-py
mpmath==1.3.0
# via sympy
multidict==6.1.0
# via
# aiohttp
# yarl
multiprocess==0.70.16
# via datasets
nest-asyncio==1.6.0
# via outlines
networkx==3.4.2
# via torch
numpy==1.26.4
# via
# text-generation-server (pyproject.toml)
# accelerate
# datasets
# outlines
# pandas
# peft
# scipy
# transformers
nvidia-cublas-cu12==12.4.5.8
# via
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
# torch
nvidia-cuda-cupti-cu12==12.4.127
# via torch
nvidia-cuda-nvrtc-cu12==12.4.127
# via torch
nvidia-cuda-runtime-cu12==12.4.127
# via torch
nvidia-cudnn-cu12==9.1.0.70
# via torch
nvidia-cufft-cu12==11.2.1.3
# via torch
nvidia-curand-cu12==10.3.5.147
# via torch
nvidia-cusolver-cu12==11.6.1.9
# via torch
nvidia-cusparse-cu12==12.3.1.170
# via
# nvidia-cusolver-cu12
# torch
nvidia-cusparselt-cu12==0.6.2
# via torch
nvidia-nccl-cu12==2.21.5
# via torch
nvidia-nvjitlink-cu12==12.4.127
# via
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
# torch
nvidia-nvtx-cu12==12.4.127
# via torch
opentelemetry-api==1.30.0
# via
# text-generation-server (pyproject.toml)
# opentelemetry-exporter-otlp-proto-grpc
# opentelemetry-exporter-otlp-proto-http
# opentelemetry-instrumentation
# opentelemetry-instrumentation-grpc
# opentelemetry-sdk
# opentelemetry-semantic-conventions
opentelemetry-exporter-otlp==1.30.0
# via text-generation-server (pyproject.toml)
opentelemetry-exporter-otlp-proto-common==1.30.0
# via
# opentelemetry-exporter-otlp-proto-grpc
# opentelemetry-exporter-otlp-proto-http
opentelemetry-exporter-otlp-proto-grpc==1.30.0
# via opentelemetry-exporter-otlp
opentelemetry-exporter-otlp-proto-http==1.30.0
# via opentelemetry-exporter-otlp
opentelemetry-instrumentation==0.51b0
# via opentelemetry-instrumentation-grpc
opentelemetry-instrumentation-grpc==0.51b0
# via text-generation-server (pyproject.toml)
opentelemetry-proto==1.30.0
# via
# opentelemetry-exporter-otlp-proto-common
# opentelemetry-exporter-otlp-proto-grpc
# opentelemetry-exporter-otlp-proto-http
opentelemetry-sdk==1.30.0
# via
# opentelemetry-exporter-otlp-proto-grpc
# opentelemetry-exporter-otlp-proto-http
opentelemetry-semantic-conventions==0.51b0
# via
# opentelemetry-instrumentation
# opentelemetry-instrumentation-grpc
# opentelemetry-sdk
outlines==0.1.14
# via text-generation-server (pyproject.toml)
outlines-core==0.1.26
# via outlines
packaging==24.1
# via
# accelerate
# datasets
# huggingface-hub
# opentelemetry-instrumentation
# peft
# transformers
pandas==2.2.3
# via datasets
peft==0.14.0
# via text-generation-server (pyproject.toml)
pillow==11.1.0
# via text-generation-server (pyproject.toml)
prometheus-client==0.21.1
# via text-generation-server (pyproject.toml)
propcache==0.2.1
# via
# aiohttp
# yarl
protobuf==5.29.3
# via
# text-generation-server (pyproject.toml)
# googleapis-common-protos
# grpcio-reflection
# grpcio-status
# opentelemetry-proto
psutil==6.1.1
# via
# accelerate
# peft
py-cpuinfo==9.0.0
# via text-generation-server (pyproject.toml)
pyarrow==19.0.0
# via datasets
pycountry==24.6.1
# via outlines
pydantic==2.10.6
# via
# compressed-tensors
# outlines
pydantic-core==2.27.2
# via pydantic
pygments==2.18.0
# via rich
python-dateutil==2.9.0.post0
# via pandas
pytz==2025.1
# via pandas
pyyaml==6.0.2
# via
# accelerate
# datasets
# huggingface-hub
# peft
# transformers
referencing==0.36.2
# via
# jsonschema
# jsonschema-specifications
# outlines
regex==2024.9.11
# via transformers
requests==2.32.3
# via
# datasets
# huggingface-hub
# opentelemetry-exporter-otlp-proto-http
# outlines
# transformers
rich==13.9.4
# via
# text-generation-server (pyproject.toml)
# typer
rpds-py==0.22.3
# via
# jsonschema
# referencing
safetensors==0.4.5
# via
# text-generation-server (pyproject.toml)
# accelerate
# peft
# transformers
scipy==1.13.1
# via text-generation-server (pyproject.toml)
sentencepiece==0.2.0
# via text-generation-server (pyproject.toml)
shellingham==1.5.4
# via typer
six==1.17.0
# via python-dateutil
sympy==1.13.1
# via torch
texttable==1.7.0
# via text-generation-server (pyproject.toml)
tokenizers==0.21.0
# via
# text-generation-server (pyproject.toml)
# transformers
torch==2.6.0
# via
# accelerate
# compressed-tensors
# outlines
# peft
tqdm==4.66.5
# via
# datasets
# huggingface-hub
# outlines
# peft
# transformers
transformers==4.48.2
# via
# text-generation-server (pyproject.toml)
# compressed-tensors
# peft
triton==3.2.0
# via torch
typer==0.15.1
# via text-generation-server (pyproject.toml)
typing-extensions==4.12.2
# via
# huggingface-hub
# opentelemetry-sdk
# outlines
# pydantic
# pydantic-core
# referencing
# torch
# typer
tzdata==2025.1
# via pandas
urllib3==2.2.3
# via requests
wrapt==1.16.0
# via
# deprecated
# opentelemetry-instrumentation
# opentelemetry-instrumentation-grpc
xxhash==3.5.0
# via datasets
yarl==1.18.3
# via aiohttp
zipp==3.20.2
# via importlib-metadata

View File

@ -94,6 +94,8 @@ def test_get_mlp_weights_with_gate_up_proj():
# assert the result # assert the result
expected = { expected = {
(3, "c_fc"): ("model.layers.3.mlp.c_fc", mock_layer.mlp.c_fc),
(3, "c_proj"): ("model.layers.3.mlp.c_proj", mock_layer.mlp.c_proj),
(3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj), (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
(3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj), (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
(3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj), (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
@ -188,6 +190,8 @@ def test_get_mlp_weights_llama_compatibility():
result = get_mlp_weights(3, mock_layer) result = get_mlp_weights(3, mock_layer)
expected = { expected = {
(3, "c_fc"): ("model.layers.3.mlp.c_fc", mock_layer.mlp.c_fc),
(3, "c_proj"): ("model.layers.3.mlp.c_proj", mock_layer.mlp.c_proj),
(3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj), (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
(3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj), (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
(3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj), (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
@ -240,6 +244,8 @@ def test_get_mlp_weights_gemma_compatibility():
result = get_mlp_weights(3, mock_layer) result = get_mlp_weights(3, mock_layer)
expected = { expected = {
(3, "c_fc"): ("model.layers.3.mlp.c_fc", mock_layer.mlp.c_fc),
(3, "c_proj"): ("model.layers.3.mlp.c_proj", mock_layer.mlp.c_proj),
(3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_proj), (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_proj),
(3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.up_proj), (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.up_proj),
(3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj), (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),

View File

@ -6,9 +6,11 @@ from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple, Type, Union from typing import Dict, List, Optional, Set, Tuple, Type, Union
from loguru import logger
import torch import torch
from peft import LoraConfig as _LoraConfig from peft import LoraConfig as _LoraConfig
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from text_generation_server.utils.log import log_master
from text_generation_server.adapters.config import AdapterConfig, ModuleMap from text_generation_server.adapters.config import AdapterConfig, ModuleMap
@ -203,8 +205,17 @@ class LoraWeights(AdapterWeights):
lora_a_list = [None] * nlayers lora_a_list = [None] * nlayers
lora_b_list = [None] * nlayers lora_b_list = [None] * nlayers
# import ipdb; ipdb.set_trace()
for layer_id in range(nlayers): for layer_id in range(nlayers):
key = (layer_id, layer_type) key = (layer_id, layer_type)
if key not in target_to_layer:
# There is no layer of this type in the model
log_master(
logger.warning,
f"Key specified in lora weights but not found in base model: {key}",
)
return None
weight_name, layer = target_to_layer[key] weight_name, layer = target_to_layer[key]
base_weight = layer.base_layer.linear.weight base_weight = layer.base_layer.linear.weight
base_device = base_weight.device base_device = base_weight.device

View File

@ -9,6 +9,8 @@ from enum import Enum
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from text_generation_server.utils.adapter import parse_lora_adapters from text_generation_server.utils.adapter import parse_lora_adapters
# Dummy change should cache hit.
app = typer.Typer() app = typer.Typer()

View File

@ -111,6 +111,8 @@ def paged_attention(
out = torch.empty_like(query) out = torch.empty_like(query)
kv_cache_dtype = "fp8" if kv_cache.dtype == torch.float8_e4m3fn else "auto"
use_v1 = max_s <= 8192 and ( use_v1 = max_s <= 8192 and (
max_num_partitions == 1 or num_seqs * num_heads > 512 max_num_partitions == 1 or num_seqs * num_heads > 512
) )
@ -120,15 +122,16 @@ def paged_attention(
query, query,
kv_cache.key, kv_cache.key,
kv_cache.value, kv_cache.value,
kv_head_mapping, kv_cache.key.shape[1],
softmax_scale, softmax_scale,
block_tables, block_tables,
input_lengths, input_lengths,
block_size, block_size,
max_s, max_s,
None, None,
"auto", kv_cache_dtype,
1.0, kv_scales.key_scale_cpu,
kv_scales.value_scale_cpu,
) )
else: else:
# Run PagedAttention V2. # Run PagedAttention V2.
@ -153,15 +156,16 @@ def paged_attention(
query, query,
kv_cache.key, kv_cache.key,
kv_cache.value, kv_cache.value,
kv_head_mapping, kv_cache.key.shape[1],
softmax_scale, softmax_scale,
block_tables, block_tables,
input_lengths, input_lengths,
block_size, block_size,
max_s, max_s,
None, None,
"auto", kv_cache_dtype,
1.0, kv_scales.key_scale_cpu,
kv_scales.value_scale_cpu,
) )
return out return out
@ -235,7 +239,6 @@ def attention(
paged_kv_cache=(kv_cache.key, kv_cache.value), paged_kv_cache=(kv_cache.key, kv_cache.value),
logits_soft_cap=softcap, logits_soft_cap=softcap,
sm_scale=softmax_scale, sm_scale=softmax_scale,
window_left=window_size_left,
k_scale=kv_scales.key_scale_cpu if can_scale else 1.0, k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
v_scale=kv_scales.value_scale_cpu if can_scale else 1.0, v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
) )

View File

@ -84,7 +84,7 @@ def use_prefill_with_paged_kv_state(
token = prefill_with_paged_kv_state.set(state) token = prefill_with_paged_kv_state.set(state)
try: try:
state.begin_forward( state.plan(
qo_indptr=cu_seqlens, qo_indptr=cu_seqlens,
paged_kv_indptr=indptr, paged_kv_indptr=indptr,
paged_kv_indices=block_tables, paged_kv_indices=block_tables,
@ -99,7 +99,6 @@ def use_prefill_with_paged_kv_state(
) )
yield yield
finally: finally:
state.end_forward()
if token is not None: if token is not None:
prefill_with_paged_kv_state.reset(token) prefill_with_paged_kv_state.reset(token)
@ -200,7 +199,7 @@ def use_decode_state(
token = decode_state.set(state) token = decode_state.set(state)
try: try:
state.begin_forward( state.plan(
indptr=indptr, indptr=indptr,
indices=block_tables, indices=block_tables,
last_page_len=last_page_len, last_page_len=last_page_len,
@ -214,6 +213,5 @@ def use_decode_state(
) )
yield yield
finally: finally:
state.end_forward()
if token is not None: if token is not None:
decode_state.reset(token) decode_state.reset(token)

View File

@ -1,9 +1,12 @@
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex
import torch import torch
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen from text_generation_server.layers.attention import Seqlen
from typing import Optional from typing import Optional
from text_generation_server.models.globals import (
ATTENTION,
BLOCK_SIZE,
)
SUPPORTS_WINDOWING = False SUPPORTS_WINDOWING = False
@ -28,6 +31,22 @@ def attention(
out = torch.empty_like(query) out = torch.empty_like(query)
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
if ATTENTION == "flashdecoding-ipex":
ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
out,
query.contiguous() if query.device.type == "xpu" else query,
kv_cache.key,
kv_cache.value,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
seqlen.max_q,
seqlen.max_k,
softmax_scale,
causal,
block_tables,
None,
)
else:
ipex.llm.functional.varlen_attention( ipex.llm.functional.varlen_attention(
query.contiguous() if query.device.type == "xpu" else query, query.contiguous() if query.device.type == "xpu" else query,
key.contiguous() if key.device.type == "xpu" else key, key.contiguous() if key.device.type == "xpu" else key,
@ -64,6 +83,23 @@ def paged_attention(
raise NotImplementedError("softcap is not available in IPEX") raise NotImplementedError("softcap is not available in IPEX")
out = torch.empty_like(query) out = torch.empty_like(query)
if ATTENTION == "flashdecoding-ipex":
ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
out,
query.contiguous() if query.device.type == "xpu" else query,
kv_cache.key,
kv_cache.value,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
seqlen.max_q,
seqlen.max_k,
softmax_scale,
True,
block_tables,
None,
)
else:
input_lengths = seqlen.input_lengths + seqlen.cache_lengths input_lengths = seqlen.input_lengths + seqlen.cache_lengths
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
out, out,

View File

@ -52,12 +52,17 @@ class KVCache:
device: torch.device, device: torch.device,
): ):
"""Construct the key-value cache for a layer.""" """Construct the key-value cache for a layer."""
if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
if dtype in {torch.float8_e5m2, torch.float8_e4m3fn} and ( if not (
ATTENTION != "flashinfer" or SYSTEM != "cuda" (ATTENTION == "flashinfer" and SYSTEM == "cuda")
or (ATTENTION == "paged" and SYSTEM in ("cuda", "rocm"))
): ):
raise ValueError( raise ValueError(
"FP8 KV cache is currently only supported for flashinfer on CUDA" "FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on CUDA and ROCm. "
)
if SYSTEM == "rocm" and dtype == torch.float8_e5m2:
raise ValueError(
"float8_e5m2 FP8 KV cache is not supported on AMD ROCm"
) )
element_size = torch.tensor([], dtype=dtype).element_size() element_size = torch.tensor([], dtype=dtype).element_size()
@ -66,7 +71,9 @@ class KVCache:
else: else:
x = BLOCK_SIZE // element_size x = BLOCK_SIZE // element_size
if ATTENTION in {"flashdecoding", "flashinfer"}: if ATTENTION in {"flashdecoding", "flashinfer"} or (
ATTENTION == "flashdecoding-ipex" and device.type == "xpu"
):
self.kv_cache = ( self.kv_cache = (
torch.empty( torch.empty(
(num_blocks, BLOCK_SIZE, num_heads, head_size), (num_blocks, BLOCK_SIZE, num_heads, head_size),
@ -80,6 +87,7 @@ class KVCache:
), ),
) )
elif SYSTEM == "ipex" and device == torch.device("cpu"): elif SYSTEM == "ipex" and device == torch.device("cpu"):
# ipex cpu flashdecoding kernel and paged attention kernel share same layout
self.kv_cache = ( self.kv_cache = (
torch.empty( torch.empty(
(num_blocks, num_heads, BLOCK_SIZE, head_size), (num_blocks, num_heads, BLOCK_SIZE, head_size),
@ -110,21 +118,17 @@ class KVCache:
"""Check if the cache can be scaled by the given scales.""" """Check if the cache can be scaled by the given scales."""
if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0: if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0:
return False return False
elif ( elif self.dtype == torch.float8_e4m3fn and (
self.dtype == torch.float8_e4m3fn (ATTENTION == "flashinfer" and SYSTEM == "cuda")
and ATTENTION == "flashinfer" or (ATTENTION == "paged" and SYSTEM == "rocm")
and SYSTEM == "cuda"
): ):
log_once( log_once(logger.info, "Using FP8 KV cache scales")
logger.info,
"Using FP8 KV cache scales",
)
return True return True
else: else:
# We have scales, but not the correct FP8 cache type, so warn once. # We have scales, but not the correct FP8 cache type, so warn once.
log_once( log_once(
logger.info, logger.info,
"Ignoring FP8 KV cache scales, only float8_e4m3fn KV cache on flashinfer is supported", "Ignoring FP8 KV cache scales, supported only for float8_e4m3fn KV cache with flashinfer on CUDA and paged attention on ROCm",
) )
return False return False
@ -158,7 +162,7 @@ class KVCache:
key_cache = self.kv_cache[0] key_cache = self.kv_cache[0]
value_cache = self.kv_cache[1] value_cache = self.kv_cache[1]
if self.can_scale(kv_scales): if self.can_scale(kv_scales) and SYSTEM == "cuda":
if kv_scales.key_scale_cpu != 1.0: if kv_scales.key_scale_cpu != 1.0:
key = fp8_quantize( key = fp8_quantize(
key.float(), key.float(),
@ -187,8 +191,22 @@ class KVCache:
shape = key_cache.shape shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value value_cache.view(-1, shape[-2], shape[-1])[slots] = value
elif ATTENTION == "flashdecoding-ipex" and key.device.type == "xpu":
import intel_extension_for_pytorch as ipex
ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
key, value, key_cache, value_cache, slots
)
else: else:
paged_reshape_and_cache(key, value, key_cache, value_cache, slots) paged_reshape_and_cache(
key,
value,
key_cache,
value_cache,
slots,
kv_scales.key_scale_cpu,
kv_scales.value_scale_cpu,
)
def paged_reshape_and_cache( def paged_reshape_and_cache(
@ -197,7 +215,10 @@ def paged_reshape_and_cache(
key_cache: torch.Tensor, key_cache: torch.Tensor,
value_cache: torch.Tensor, value_cache: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
k_scale: float = 1.0,
v_scale: float = 1.0,
): ):
if SYSTEM == "cuda": if SYSTEM == "cuda":
try: try:
import attention_kernels import attention_kernels
@ -205,8 +226,13 @@ def paged_reshape_and_cache(
raise ImportError( raise ImportError(
f"Could not import attention_kernels. Make sure your installation is correct. Complete error: {e}" f"Could not import attention_kernels. Make sure your installation is correct. Complete error: {e}"
) )
kv_cache_dtype = "auto"
if key_cache.dtype == torch.float8_e4m3fn:
kv_cache_dtype = "fp8"
attention_kernels.reshape_and_cache( attention_kernels.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0 key, value, key_cache, value_cache, slots, kv_cache_dtype, k_scale, v_scale
) )
elif SYSTEM == "rocm": elif SYSTEM == "rocm":
try: try:
@ -215,8 +241,15 @@ def paged_reshape_and_cache(
raise ImportError( raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}" f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
) )
kv_cache_dtype = "auto"
if key_cache.dtype == torch.float8_e4m3fn:
key_cache = key_cache.view(torch.uint8)
value_cache = value_cache.view(torch.uint8)
kv_cache_dtype = "fp8"
ops.reshape_and_cache( ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0, 1.0 key, value, key_cache, value_cache, slots, kv_cache_dtype, k_scale, v_scale
) )
elif SYSTEM == "ipex": elif SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex

View File

@ -133,6 +133,15 @@ def paged_attention(
out = torch.empty_like(query) out = torch.empty_like(query)
if kv_cache.dtype == torch.float8_e4m3fn:
key = kv_cache.key.view(torch.uint8)
value = kv_cache.value.view(torch.uint8)
kv_cache_dtype = "fp8"
else:
key = kv_cache.key
value = kv_cache.value
kv_cache_dtype = "auto"
# NOTE(woosuk): We use a simple heuristic to decide whether to use # NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use # PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of # V1 to avoid the overhead of reduction. Also, if the number of
@ -147,8 +156,8 @@ def paged_attention(
ops.paged_attention_v1( ops.paged_attention_v1(
out, out,
query, query,
kv_cache.key, key,
kv_cache.value, value,
num_kv_heads, num_kv_heads,
softmax_scale, softmax_scale,
block_tables, block_tables,
@ -156,9 +165,9 @@ def paged_attention(
block_size, block_size,
max_s, max_s,
None, None,
"auto", kv_cache_dtype,
1.0, kv_scales.key_scale_cpu,
1.0, kv_scales.value_scale_cpu,
) )
else: else:
# Run PagedAttention V2. # Run PagedAttention V2.
@ -182,8 +191,8 @@ def paged_attention(
max_logits, max_logits,
tmp_output, tmp_output,
query, query,
kv_cache.key, key,
kv_cache.value, value,
num_kv_heads, num_kv_heads,
softmax_scale, softmax_scale,
block_tables, block_tables,
@ -191,9 +200,9 @@ def paged_attention(
block_size, block_size,
max_s, max_s,
None, None,
"auto", kv_cache_dtype,
1.0, kv_scales.key_scale_cpu,
1.0, kv_scales.value_scale_cpu,
) )
else: else:
ops.paged_attention_rocm( ops.paged_attention_rocm(
@ -202,8 +211,8 @@ def paged_attention(
max_logits, max_logits,
tmp_output, tmp_output,
query, query,
kv_cache.key, key,
kv_cache.value, value,
num_kv_heads, num_kv_heads,
softmax_scale, softmax_scale,
block_tables, block_tables,
@ -211,9 +220,9 @@ def paged_attention(
block_size, block_size,
max_s, max_s,
None, None,
"auto", kv_cache_dtype,
1.0, kv_scales.key_scale_cpu,
1.0, kv_scales.value_scale_cpu,
None, None,
_PARTITION_SIZE, _PARTITION_SIZE,
) )

View File

@ -7,7 +7,7 @@ from text_generation_server.layers.fp8 import (
Fp8Weight, Fp8Weight,
_load_scalar_or_matrix_scale, _load_scalar_or_matrix_scale,
requantize_with_max_scale, requantize_with_max_scale,
normalize_e4m3fn_to_e4m3fnuz, normalize_e4m3fn_to_native_float8,
) )
from text_generation_server.utils.weights import Weights, WeightsLoader from text_generation_server.utils.weights import Weights, WeightsLoader
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
@ -147,8 +147,8 @@ class W8ANFpLoader(WeightsLoader):
else None else None
) )
if self.load_weight_scale or SYSTEM == "rocm": if self.load_weight_scale and SYSTEM == "rocm":
w, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( w, weight_scale, input_scale = normalize_e4m3fn_to_native_float8(
w, weight_scale, input_scale w, weight_scale, input_scale
) )

View File

@ -19,6 +19,12 @@ try:
except ImportError: except ImportError:
marlin_kernels = None marlin_kernels = None
try:
from moe_kernels.fp8_utils import w8a8_block_fp8_matmul, per_token_group_quant_fp8
except ImportError:
w8a8_block_fp8_matmul = None
per_token_group_quant_fp8 = None
quant_dtype: torch.dtype = ( quant_dtype: torch.dtype = (
torch.float8_e4m3fnuz if SYSTEM == "rocm" else torch.float8_e4m3fn torch.float8_e4m3fnuz if SYSTEM == "rocm" else torch.float8_e4m3fn
) )
@ -38,7 +44,6 @@ def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
""" """
if SYSTEM == "cuda": if SYSTEM == "cuda":
major, _ = torch.cuda.get_device_capability() major, _ = torch.cuda.get_device_capability()
# Marlin is W8A16, use it when: # Marlin is W8A16, use it when:
# #
@ -52,18 +57,28 @@ def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
# gives better decoding throughput on L4 and L40. # gives better decoding throughput on L4 and L40.
from text_generation_server.layers.marlin import GPTQMarlinFP8Linear from text_generation_server.layers.marlin import GPTQMarlinFP8Linear
if major == 8 and minor == 9:
log_once(
logger.info,
"GPU supports FP8, but using Marlin FP8 kernel for better performance",
)
else:
log_once(
logger.info, "GPU does not support FP8, using Marlin FP8 kernel"
)
return GPTQMarlinFP8Linear return GPTQMarlinFP8Linear
# On other systems let Torch decide if the hardware supports FP8. # On other systems let Torch decide if the hardware supports FP8.
return Fp8Linear return Fp8Linear
def normalize_e4m3fn_to_e4m3fnuz( def normalize_e4m3fn_to_native_float8(
weight: torch.Tensor, weight: torch.Tensor,
weight_scale: torch.Tensor, weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None, input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
if weight.dtype == torch.float8_e4m3fn: if weight.dtype == torch.float8_e4m3fn and SYSTEM == "rocm":
# The bits pattern 10000000(-128) represents zero in e4m3fn # The bits pattern 10000000(-128) represents zero in e4m3fn
# but NaN in e4m3fnuz. So here we set it to 0. # but NaN in e4m3fnuz. So here we set it to 0.
# https://onnx.ai/onnx/technical/float8.html # https://onnx.ai/onnx/technical/float8.html
@ -162,7 +177,7 @@ def fp8_quantize(
qweight = qweight.to(qdtype) qweight = qweight.to(qdtype)
if SYSTEM == "rocm": if SYSTEM == "rocm":
qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale) qweight, scale, _ = normalize_e4m3fn_to_native_float8(qweight, scale)
return qweight, scale return qweight, scale
@ -170,14 +185,29 @@ def fp8_quantize(
class HybridFP8UnquantLoader(WeightsLoader): class HybridFP8UnquantLoader(WeightsLoader):
"""Weight loader that loads FP8 and unquantized Torch tensors.""" """Weight loader that loads FP8 and unquantized Torch tensors."""
def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool): def __init__(
self,
activation_scale_ub: Optional[float],
to_fp8: bool,
weight_block_size: Optional[List[int]] = None,
):
self.activation_scale_ub = activation_scale_ub self.activation_scale_ub = activation_scale_ub
self.to_fp8 = to_fp8 self.to_fp8 = to_fp8
self.weight_block_size = weight_block_size
def get_weights(self, weights: "Weights", prefix: str): def get_weights(self, weights: "Weights", prefix: str):
w = weights.get_tensor(f"{prefix}.weight") w = weights.get_tensor(f"{prefix}.weight")
if w.dtype == torch.float8_e4m3fn: if w.dtype == torch.float8_e4m3fn:
if self.weight_block_size is not None:
scale = weights.get_tensor(f"{prefix}.weight_scale_inv")
return Fp8Weight(
weight=w,
weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
weight_block_size=self.weight_block_size,
)
# FP8 branch # FP8 branch
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
@ -266,6 +296,21 @@ class HybridFP8UnquantLoader(WeightsLoader):
# FP8 branch # FP8 branch
if w.dtype == torch.float8_e4m3fn: if w.dtype == torch.float8_e4m3fn:
if self.weight_block_size is not None:
scale = [
weights.get_sharded(f"{p}.weight_scale_inv", dim=0, to_device=False)
for p in prefixes
]
scale = torch.cat(scale, dim=dim)
scale = scale.to(weights.device)
return Fp8Weight(
weight=w,
weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
weight_block_size=self.weight_block_size,
)
scale = [ scale = [
_load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape) _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
for p, shape in zip(prefixes, shapes) for p, shape in zip(prefixes, shapes)
@ -285,7 +330,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
) )
if SYSTEM == "rocm": if SYSTEM == "rocm":
w, scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( w, scale, input_scale = normalize_e4m3fn_to_native_float8(
w, scale, input_scale w, scale, input_scale
) )
@ -311,6 +356,18 @@ class HybridFP8UnquantLoader(WeightsLoader):
w = weights.get_sharded(f"{prefix}.weight", dim=1) w = weights.get_sharded(f"{prefix}.weight", dim=1)
# FP8 branch # FP8 branch
if w.dtype == torch.float8_e4m3fn: if w.dtype == torch.float8_e4m3fn:
if self.weight_block_size is not None:
# XXX: Yes the weights is named scale_inv, but corresponds to scale it seems.
scale = weights.get_sharded(f"{prefix}.weight_scale_inv", dim=1)
return Fp8Weight(
weight=w,
weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
weight_block_size=self.weight_block_size,
)
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
if SYSTEM == "cuda": if SYSTEM == "cuda":
@ -345,6 +402,7 @@ class Fp8Weight(Weight):
input_scale: Optional[torch.Tensor] = None input_scale: Optional[torch.Tensor] = None
activation_scale_ub: Optional[float] = None activation_scale_ub: Optional[float] = None
force_w8a16: bool = False force_w8a16: bool = False
weight_block_size: Optional[List[int]] = None
def get_linear(self, bias: torch.Tensor): def get_linear(self, bias: torch.Tensor):
if self.weight_scale is None: if self.weight_scale is None:
@ -361,6 +419,7 @@ class Fp8Weight(Weight):
bias=bias, bias=bias,
input_scale=self.input_scale, input_scale=self.input_scale,
scale_upper_bound=self.activation_scale_ub, scale_upper_bound=self.activation_scale_ub,
weight_block_size=self.weight_block_size,
) )
@ -375,19 +434,21 @@ class Fp8Linear(torch.nn.Module):
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
input_scale: Optional[torch.Tensor] = None, input_scale: Optional[torch.Tensor] = None,
scale_upper_bound: Optional[float] = None, scale_upper_bound: Optional[float] = None,
weight_block_size: Optional[List[int]] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
if CUTLASS_FP8_AVAILABLE: if CUTLASS_FP8_AVAILABLE:
log_once(logger.info, "Using cutlass w8a8 kernels") log_once(logger.info, "Using cutlass w8a8 kernels")
if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn: if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
qweight, scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( qweight, scale, input_scale = normalize_e4m3fn_to_native_float8(
weight=qweight, weight_scale=scale weight=qweight, weight_scale=scale, input_scale=input_scale
) )
self.dtype = dtype self.dtype = dtype
self.qweight = qweight self.qweight = qweight
self.scale = scale.float() self.scale = scale.float()
self.input_scale = input_scale.float() if input_scale is not None else None self.input_scale = input_scale.float() if input_scale is not None else None
self.weight_block_size = weight_block_size
if CUTLASS_FP8_AVAILABLE and scale_upper_bound is not None: if CUTLASS_FP8_AVAILABLE and scale_upper_bound is not None:
self.scale_upper_bound = torch.tensor( self.scale_upper_bound = torch.tensor(
@ -421,6 +482,7 @@ class Fp8Linear(torch.nn.Module):
) -> "Fp8Linear": ) -> "Fp8Linear":
input_scale = kwargs.get("input_scale", None) input_scale = kwargs.get("input_scale", None)
scale_upper_bound = kwargs.get("scale_upper_bound", None) scale_upper_bound = kwargs.get("scale_upper_bound", None)
weight_block_size = kwargs.get("weight_block_size", None)
return cls( return cls(
qweight=weight, qweight=weight,
@ -429,6 +491,7 @@ class Fp8Linear(torch.nn.Module):
scale_upper_bound=scale_upper_bound, scale_upper_bound=scale_upper_bound,
bias=bias, bias=bias,
dtype=dtype, dtype=dtype,
weight_block_size=weight_block_size,
) )
@classmethod @classmethod
@ -440,6 +503,25 @@ class Fp8Linear(torch.nn.Module):
return cls._device_identity_cache[device] return cls._device_identity_cache[device]
def forward(self, input: torch.Tensor) -> torch.Tensor: def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.weight_block_size is not None:
# https://arxiv.org/pdf/2412.19437
# At a more granular level. As illustrated in Figure 7 (a), (1) for activations, we group and
# scale elements on a 1x128 tile basis (i.e., per token per 128 channels); and (2) for weights, we
# group and scale elements on a 128x128 block basis (i.e., per 128 input channels per 128 output
# channels).
qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1])
output = w8a8_block_fp8_matmul(
qinput,
self.qweight,
scale,
self.scale,
self.weight_block_size,
output_dtype=input.dtype,
)
if self.bias is not None:
output = output + self.bias
return output.to(dtype=input.dtype)
if CUTLASS_FP8_AVAILABLE: if CUTLASS_FP8_AVAILABLE:
# cutlass FP8 supports per-token scales, so get non-scalar scales. # cutlass FP8 supports per-token scales, so get non-scalar scales.
qinput, scale = fp8_quantize( qinput, scale = fp8_quantize(

View File

@ -956,15 +956,24 @@ def quantize(
pack(model, quantizers, bits, groupsize) pack(model, quantizers, bits, groupsize)
from safetensors.torch import save_file from safetensors.torch import save_file
from transformers.modeling_utils import shard_checkpoint from huggingface_hub import split_torch_state_dict_into_shards
state_dict = model.state_dict() state_dict = model.state_dict()
state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()} state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
max_shard_size = "10GB" max_shard_size = "10GB"
shards, index = shard_checkpoint( state_dict_split = split_torch_state_dict_into_shards(
state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors" state_dict,
filename_pattern="model.safetensors",
max_shard_size=max_shard_size,
) )
index = None
if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
shards = state_dict_split.filename_to_tensors
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
for shard_file, shard in shards.items(): for shard_file, shard in shards.items():
save_file( save_file(

View File

@ -2,14 +2,12 @@ from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from loguru import logger
from text_generation_server.layers.fp8 import fp8_quantize from text_generation_server.layers.fp8 import fp8_quantize
from text_generation_server.layers.marlin.gptq import _check_valid_shape from text_generation_server.layers.marlin.gptq import _check_valid_shape
from text_generation_server.layers.marlin.util import ( from text_generation_server.layers.marlin.util import (
_check_marlin_kernels, _check_marlin_kernels,
permute_scales, permute_scales,
) )
from text_generation_server.utils.log import log_once
try: try:
import marlin_kernels import marlin_kernels
@ -36,8 +34,6 @@ class GPTQMarlinFP8Linear(nn.Module):
_check_marlin_kernels() _check_marlin_kernels()
assert marlin_kernels is not None assert marlin_kernels is not None
log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
scales = scales.unsqueeze(0) scales = scales.unsqueeze(0)
if scales.shape[1] == 1: if scales.shape[1] == 1:
out_features, in_features = qweight.shape out_features, in_features = qweight.shape

View File

@ -16,6 +16,7 @@ from text_generation_server.layers.moe.gptq_marlin import (
can_use_marlin_moe_gemm, can_use_marlin_moe_gemm,
) )
from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
from text_generation_server.layers.moe.fp8 import FP8SparseMoELayer
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import ( from text_generation_server.utils.weights import (
@ -25,7 +26,7 @@ from text_generation_server.utils.weights import (
) )
if SYSTEM == "ipex": if SYSTEM == "ipex":
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE from .fused_moe_ipex import fused_topk, grouped_topk
else: else:
from moe_kernels.fused_moe import fused_topk, grouped_topk from moe_kernels.fused_moe import fused_topk, grouped_topk
@ -51,6 +52,8 @@ class MoELayer(Protocol):
up_proj_name: str = "up_proj", up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj", down_proj_name: str = "down_proj",
hidden_act: str = "silu", hidden_act: str = "silu",
scoring_func: Optional[str] = None,
e_score_correction_bias: Optional[float] = None,
): ... ): ...
def forward( def forward(
@ -80,9 +83,14 @@ class DenseMoELayer(nn.Module):
up_proj_name: str = "up_proj", up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj", down_proj_name: str = "down_proj",
hidden_act: str = "silu", hidden_act: str = "silu",
scoring_func: Optional[str] = None,
e_score_correction_bias: Optional[float] = None,
): ):
super().__init__() super().__init__()
assert scoring_func is None, "scoring func is not handled"
assert e_score_correction_bias is None, "scoring correction bias is not handled"
log_once( log_once(
logger.info, logger.info,
"No fused layers are available for this model type, using (slower) dense MoE layer", "No fused layers are available for this model type, using (slower) dense MoE layer",
@ -139,10 +147,6 @@ class DenseMoELayer(nn.Module):
) )
for i in range(self.n_experts) for i in range(self.n_experts)
] ]
if SYSTEM == "ipex":
self.ipex_fused_moe = GatedMLPMOE(
W13=self.gate_proj, W2=self.down_proj, W3=self.up_proj, use_prepack=True
)
self.process_group = weights.process_group self.process_group = weights.process_group
@ -155,17 +159,6 @@ class DenseMoELayer(nn.Module):
input_shape = x.shape input_shape = x.shape
x = x.view(-1, input_shape[-1]) x = x.view(-1, input_shape[-1])
if SYSTEM == "ipex":
return self.ipex_fused_moe(
hidden_states=x,
router_logits=gating_output,
top_k=self.topk,
renormalize=self.renormalize,
use_grouped_topk=self.n_expert_group is not None,
num_expert_group=self.n_expert_group,
topk_group=self.topk_group,
)
if self.n_expert_group is not None and self.topk_group is not None: if self.n_expert_group is not None and self.topk_group is not None:
topk_weights, topk_ids = grouped_topk( topk_weights, topk_ids = grouped_topk(
x, x,
@ -213,16 +206,23 @@ class SparseMoELayer(nn.Module):
topk: int, topk: int,
topk_group: Optional[int], topk_group: Optional[int],
weights: Weights, weights: Weights,
scoring_func: Optional[str] = "softmax",
e_score_correction_bias: Optional[float] = None,
gate_proj_name: str = "gate_proj", gate_proj_name: str = "gate_proj",
up_proj_name: str = "up_proj", up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj", down_proj_name: str = "down_proj",
): ):
super().__init__() super().__init__()
if ( if (
isinstance(weights.loader, DefaultWeightsLoader) isinstance(weights.loader, DefaultWeightsLoader)
and isinstance(weights.loader.weight_class, UnquantizedWeight) and isinstance(weights.loader.weight_class, UnquantizedWeight)
) or isinstance(weights.loader, HybridFP8UnquantLoader): ) or isinstance(weights.loader, HybridFP8UnquantLoader):
if (
isinstance(weights.loader, HybridFP8UnquantLoader)
and weights.loader.to_fp8
):
cls = FP8SparseMoELayer
else:
cls = UnquantizedSparseMoELayer cls = UnquantizedSparseMoELayer
elif isinstance( elif isinstance(
weights.loader, GPTQMarlinWeightsLoader weights.loader, GPTQMarlinWeightsLoader
@ -250,6 +250,8 @@ class SparseMoELayer(nn.Module):
topk=topk, topk=topk,
topk_group=topk_group, topk_group=topk_group,
weights=weights, weights=weights,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
gate_proj_name=gate_proj_name, gate_proj_name=gate_proj_name,
up_proj_name=up_proj_name, up_proj_name=up_proj_name,
down_proj_name=down_proj_name, down_proj_name=down_proj_name,

Some files were not shown because too many files have changed in this diff Show More