Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-21 16:40:20 +00:00)

Commit 8ae92e5d70: Merge branch 'huggingface:main' into fix/dockerfile-triton
@@ -1,75 +0,0 @@
-ARG CUDA_ARCH_LIST="75-real;80-real;86-real;89-real;90-real"
-ARG OMPI_VERSION="4.1.7rc1"
-
-# Build dependencies resolver stage
-FROM lukemathwalker/cargo-chef:latest AS chef
-WORKDIR /usr/src/text-generation-inference/backends/trtllm
-
-FROM chef AS planner
-COPY . .
-RUN cargo chef prepare --recipe-path recipe.json
-
-# CUDA dependent dependencies resolver stage
-FROM nvidia/cuda:12.6.3-cudnn-devel-ubuntu24.04 AS cuda-builder
-
-RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
-    --mount=type=cache,target=/var/lib/apt,sharing=locked \
-    apt update && apt install -y \
-    build-essential \
-    cmake \
-    curl \
-    gcc-14 \
-    g++-14 \
-    git \
-    git-lfs \
-    libssl-dev \
-    libucx-dev \
-    ninja-build \
-    pkg-config \
-    pipx \
-    python3 \
-    python3-dev \
-    python3-setuptools \
-    tar \
-    wget && \
-    pipx ensurepath
-
-ENV TGI_INSTALL_PREFIX=/usr/local/tgi
-ENV TENSORRT_INSTALL_PREFIX=/usr/local/tensorrt
-
-# Install OpenMPI
-FROM cuda-builder AS mpi-builder
-ARG OMPI_VERSION
-
-ENV OMPI_TARBALL_FILENAME="openmpi-$OMPI_VERSION.tar.bz2"
-RUN wget "https://download.open-mpi.org/release/open-mpi/v4.1/$OMPI_TARBALL_FILENAME" -P /opt/src && \
-    mkdir /usr/src/mpi && \
-    tar -xf "/opt/src/$OMPI_TARBALL_FILENAME" -C /usr/src/mpi --strip-components=1 && \
-    cd /usr/src/mpi && \
-    ./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --with-slurm && \
-    make -j all && \
-    make install && \
-    rm -rf "/opt/src/$OMPI_TARBALL_FILENAME"
-
-# Install TensorRT
-FROM cuda-builder AS trt-builder
-COPY backends/trtllm/scripts/install_tensorrt.sh /opt/install_tensorrt.sh
-RUN chmod +x /opt/install_tensorrt.sh && \
-    /opt/install_tensorrt.sh
-
-# Build Backend
-FROM cuda-builder AS tgi-builder
-WORKDIR /usr/src/text-generation-inference
-
-# Install Rust
-RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y && \
-    chmod -R a+w /root/.rustup && \
-    chmod -R a+w /root/.cargo
-
-ENV PATH="/root/.cargo/bin:$PATH"
-RUN cargo install cargo-chef
-
-COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
-COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
-
-ENV MPI_HOME=/usr/local/mpi
@@ -1,19 +0,0 @@
-// For format details, see https://aka.ms/devcontainer.json. For config options, see the
-// README at: https://github.com/devcontainers/templates/tree/main/src/cpp
-{
-    "name": "CUDA",
-    "build": {
-        "dockerfile": "Dockerfile_trtllm",
-        "context": ".."
-    },
-    "remoteEnv": {
-        "PATH": "${containerEnv:PATH}:/usr/local/cuda/bin",
-        "LD_LIBRARY_PATH": "$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64",
-        "XLA_FLAGS": "--xla_gpu_cuda_data_dir=/usr/local/cuda"
-    },
-    "customizations" : {
-        "jetbrains" : {
-            "backend" : "CLion"
-        }
-    }
-}
.github/workflows/build.yaml (61 changed lines, vendored)

@@ -31,16 +31,28 @@ jobs:
       group: ${{ github.workflow }}-build-and-push-image-${{ inputs.hardware }}-${{ github.head_ref || github.run_id }}
       cancel-in-progress: true
     runs-on:
-      group: aws-highmemory-32-plus-priv
+      group: aws-highmemory-64-plus-priv
     permissions:
       contents: write
       packages: write
+      id-token: write
     steps:
       - name: Checkout repository
        uses: actions/checkout@v4
       - name: Inject slug/short variables
        uses: rlespinasse/github-slug-action@v4.4.1
-      - name: Construct harware variables
+      - name: Inject required variables for sccache to interact with Github Actions Cache
+        uses: actions/github-script@v7
+        with:
+          script: |
+            core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || '');
+            core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || '');
+
+      - name: Extract TensorRT-LLM version
+        run: |
+          echo "TENSORRT_LLM_VERSION=$(grep -oP '([a-z,0-9]{40})' $GITHUB_WORKSPACE/backends/trtllm/cmake/trtllm.cmake)" >> $GITHUB_ENV
+          echo "TensorRT-LLM version: ${{ env.TENSORRT_LLM_VERSION }}"
+      - name: Construct hardware variables
         shell: bash
         run: |
           case ${{ inputs.hardware }} in
@@ -52,6 +64,7 @@ jobs:
              export runs_on="aws-g6-12xl-plus-priv-cache"
              export platform=""
              export extra_pytest=""
+             export target=""
              ;;
            cuda-trtllm)
              export dockerfile="Dockerfile_trtllm"
@@ -61,15 +74,24 @@ jobs:
              export runs_on="ubuntu-latest"
              export platform=""
              export extra_pytest=""
+             if [[ "${GITHUB_REF}" == refs/tags/* ]]; then
+               export build_type="release";
+               export target="";
+             else
+               export build_type="dev";
+               export target="ci-runtime";
+             fi
              ;;
            rocm)
              export dockerfile="Dockerfile_amd"
              export label_extension="-rocm"
              export docker_devices="/dev/kfd,/dev/dri"
              export docker_volume="/mnt"
-             export runs_on="amd-gpu-runners"
+             # This runner was deactivated.
+             export runs_on="ubuntu-latest"
              export platform=""
              export extra_pytest="-k test_flash_gemma_gptq_load"
+             export target=""
              ;;
            intel-xpu)
              export dockerfile="Dockerfile_intel"
@@ -79,6 +101,7 @@ jobs:
              export runs_on="ubuntu-latest"
              export platform="xpu"
              export extra_pytest=""
+             export target=""
              ;;
            intel-cpu)
              export dockerfile="Dockerfile_intel"
@@ -89,6 +112,7 @@ jobs:
              export runs_on="aws-highmemory-32-plus-priv"
              export platform="cpu"
              export extra_pytest="-k test_flash_gemma_simple"
+             export target=""
              ;;
           esac
           echo $dockerfile
@@ -105,6 +129,8 @@ jobs:
           echo "RUNS_ON=${runs_on}" >> $GITHUB_ENV
           echo "EXTRA_PYTEST=${extra_pytest}" >> $GITHUB_ENV
           echo REGISTRY_MIRROR=$REGISTRY_MIRROR >> $GITHUB_ENV
+          echo "TARGET=${target}" >> $GITHUB_ENV
+          echo "BUILD_TYPE=${build_type}" >> $GITHUB_ENV
       - name: Initialize Docker Buildx
        uses: docker/setup-buildx-action@v3
        with:
@@ -169,10 +195,15 @@ jobs:
           GIT_SHA=${{ env.GITHUB_SHA }}
           DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }}
           PLATFORM=${{ env.PLATFORM }}
+          build_type=${{ env.BUILD_TYPE }}
+          sccache_gha_enabled=on
+          actions_cache_url=${{ env.ACTIONS_CACHE_URL }}
+          actions_runtime_token=${{ env.ACTIONS_RUNTIME_TOKEN }}
+        target: ${{ env.TARGET }}
        tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
        labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
-        cache-from: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min
+        cache-from: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=max
-        cache-to: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min
+        cache-to: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=max
       - name: Final
        id: final
        run: |
@@ -214,3 +245,23 @@ jobs:
           echo $DOCKER_IMAGE
           docker pull $DOCKER_IMAGE
           pytest -s -vv integration-tests ${PYTEST_FLAGS} ${EXTRA_PYTEST}
+
+  backend_trtllm_cxx_tests:
+    needs: build-and-push
+    if: needs.build-and-push.outputs.label == '-trtllm'
+    concurrency:
+      group: ${{ github.workflow }}-${{ github.job }}-trtllm-${{ github.head_ref || github.run_id }}
+      cancel-in-progress: true
+    runs-on:
+      group: aws-g6-12xl-plus-priv-cache
+    container:
+      image: ${{ needs.build-and-push.outputs.docker_image }}
+      credentials:
+        username: ${{ secrets.REGISTRY_USERNAME }}
+        password: ${{ secrets.REGISTRY_PASSWORD }}
+      options: --gpus all --shm-size=8g
+
+    steps:
+      - name: Run C++/CUDA tests
+        if: ${{ env.LABEL == 'ci-runtime' }}
+        run: /usr/local/tgi/bin/tgi_trtllm_backend_tests
.github/workflows/ci_build.yaml (1 changed line, vendored)

@@ -42,6 +42,7 @@ jobs:
     permissions:
       contents: write
       packages: write
+      id-token: write
     with:
       hardware: ${{ matrix.hardware }}
       # https://github.com/actions/runner/issues/2206
.github/workflows/tests.yaml (8 changed lines, vendored)

@@ -31,7 +31,7 @@ jobs:
        with:
          # Released on: 02 May, 2024
          # https://releases.rs/docs/1.78.0/
-         toolchain: 1.80.0
+         toolchain: 1.84.0
          override: true
          components: rustfmt, clippy
      - name: Install Protoc
@@ -44,10 +44,14 @@ jobs:
        run: |
          sudo apt update
          sudo apt install python3.11-dev -y
+         pip install -U pip uv
+         uv venv
+         source ./.venv/bin/activate
          make install-cpu
      - name: Run server tests
        run: |
-         pip install pytest
+         source ./.venv/bin/activate
+         uv pip install pytest
          export HF_TOKEN=${{ secrets.HF_TOKEN }}
          pytest -s -vv server/tests
      - name: Pre-commit checks
Cargo.lock (35 changed lines, generated)

@@ -1,6 +1,6 @@
 # This file is automatically @generated by Cargo.
 # It is not intended for manual editing.
-version = 3
+version = 4

 [[package]]
 name = "addr2line"
@@ -1544,7 +1544,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
 dependencies = [
  "ahash",
- "allocator-api2",
 ]

 [[package]]
@@ -2187,9 +2186,9 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8"

 [[package]]
 name = "libc"
-version = "0.2.164"
+version = "0.2.169"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f"
+checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a"

 [[package]]
 name = "libfuzzer-sys"
@@ -4424,14 +4423,14 @@ dependencies = [

 [[package]]
 name = "text-generation-backends-trtllm"
-version = "3.0.2-dev0"
+version = "3.1.1-dev0"
 dependencies = [
  "async-trait",
  "clap 4.5.21",
  "cmake",
  "cxx",
  "cxx-build",
- "hashbrown 0.14.5",
+ "hashbrown 0.15.1",
  "hf-hub",
  "pkg-config",
  "pyo3",
@@ -4445,7 +4444,7 @@ dependencies = [

 [[package]]
 name = "text-generation-benchmark"
-version = "3.0.2-dev0"
+version = "3.1.1-dev0"
 dependencies = [
  "average",
  "clap 4.5.21",
@@ -4465,7 +4464,7 @@ dependencies = [

 [[package]]
 name = "text-generation-client"
-version = "3.0.2-dev0"
+version = "3.1.1-dev0"
 dependencies = [
  "async-trait",
  "base64 0.22.1",
@@ -4483,7 +4482,7 @@ dependencies = [

 [[package]]
 name = "text-generation-launcher"
-version = "3.0.2-dev0"
+version = "3.1.1-dev0"
 dependencies = [
  "clap 4.5.21",
  "ctrlc",
@@ -4504,7 +4503,7 @@ dependencies = [

 [[package]]
 name = "text-generation-router"
-version = "3.0.2-dev0"
+version = "3.1.1-dev0"
 dependencies = [
  "anyhow",
  "async-stream",
@@ -4555,7 +4554,7 @@ dependencies = [

 [[package]]
 name = "text-generation-router-v2"
-version = "3.0.2-dev0"
+version = "3.1.1-dev0"
 dependencies = [
  "async-stream",
  "async-trait",
@@ -4604,7 +4603,7 @@ dependencies = [

 [[package]]
 name = "text-generation-router-v3"
-version = "3.0.2-dev0"
+version = "3.1.1-dev0"
 dependencies = [
  "async-stream",
  "async-trait",
@@ -4791,9 +4790,9 @@ dependencies = [

 [[package]]
 name = "tokio"
-version = "1.41.1"
+version = "1.43.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "22cfb5bee7a6a52939ca9224d6ac897bb669134078daa8735560897f69de4d33"
+checksum = "3d61fa4ffa3de412bfea335c6ecff681de2b609ba3c77ef3e00e521813a9ed9e"
 dependencies = [
  "backtrace",
  "bytes",
@@ -4819,9 +4818,9 @@ dependencies = [

 [[package]]
 name = "tokio-macros"
-version = "2.4.0"
+version = "2.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752"
+checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -4862,9 +4861,9 @@ dependencies = [

 [[package]]
 name = "tokio-stream"
-version = "0.1.16"
+version = "0.1.17"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4f4e6ce100d0eb49a2734f8c0812bcd324cf357d21810932c5df6b96ef2b86f1"
+checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047"
 dependencies = [
  "futures-core",
  "pin-project-lite",
@@ -20,7 +20,7 @@ default-members = [
 resolver = "2"

 [workspace.package]
-version = "3.0.2-dev0"
+version = "3.1.1-dev0"
 edition = "2021"
 authors = ["Olivier Dehaene"]
 homepage = "https://github.com/huggingface/text-generation-inference"
Dockerfile (16 changed lines)

@@ -1,5 +1,5 @@
 # Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.84.0 AS chef
 WORKDIR /usr/src

 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -47,7 +47,7 @@ RUN cargo build --profile release-opt --frozen
 FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS pytorch-install

 # NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099
-ARG PYTORCH_VERSION=2.4.0
+ARG PYTORCH_VERSION=2.5.1

 ARG PYTHON_VERSION=3.11
 # Keep in sync with `server/pyproject.toml
@@ -58,7 +58,7 @@ ARG INSTALL_CHANNEL=pytorch
 # Automatically set by buildx
 ARG TARGETPLATFORM

-ENV PATH /opt/conda/bin:$PATH
+ENV PATH=/opt/conda/bin:$PATH

 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
     build-essential \
@@ -224,17 +224,19 @@ COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-
 COPY --from=flashinfer-builder /opt/conda/lib/python3.11/site-packages/flashinfer/ /opt/conda/lib/python3.11/site-packages/flashinfer/

 # Install flash-attention dependencies
-RUN pip install einops --no-cache-dir
+# RUN pip install einops --no-cache-dir

 # Install server
 COPY proto proto
 COPY server server
 COPY server/Makefile server/Makefile
+ENV UV_SYSTEM_PYTHON=1
 RUN cd server && \
     make gen-server && \
-    pip install -r requirements_cuda.txt && \
-    pip install ".[attention, bnb, accelerate, compressed-tensors, marlin, moe, quantize, peft, outlines]" --no-cache-dir && \
-    pip install nvidia-nccl-cu12==2.22.3
+    python -c "from text_generation_server.pb import generate_pb2" && \
+    pip install -U pip uv && \
+    uv pip install -e ".[attention, bnb, accelerate, compressed-tensors, marlin, moe, quantize, peft, outlines]" --no-cache-dir # && \
+    # uv pip install nvidia-nccl-cu12==2.22.3

 ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2
 # Required to find libpython within the rust binaries
@@ -1,5 +1,5 @@
 # Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.84.0 AS chef
 WORKDIR /usr/src

 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -104,7 +104,7 @@ RUN case ${TARGETPLATFORM} in \
     /opt/conda/bin/conda clean -ya

 # Install flash-attention, torch dependencies
-RUN python3 -m pip install --upgrade pip && pip install numpy einops ninja joblib msgpack cmake --no-cache-dir && rm -rf /var/lib/apt/lists/*
+RUN python3 -m pip install --upgrade pip uv && pip install numpy einops ninja joblib msgpack cmake --no-cache-dir && rm -rf /var/lib/apt/lists/*

 RUN conda install mkl=2021
 ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/opt/conda/lib/python3.11/site-packages/torch/lib:/opt/conda/lib/
@@ -268,9 +268,18 @@ COPY server/exllamav2_kernels/ .

 RUN python setup.py build

+FROM kernel-builder AS marlin-kernels
+WORKDIR /usr/src
+ENV MARLIN_KERNELS_BRANCH=v0.3.6
+ENV VLLM_TARGET_DEVICE=rocm
+RUN git clone https://github.com/danieldk/marlin-kernels.git && \
+    cd marlin-kernels && \
+    git checkout ${MARLIN_KERNELS_BRANCH} && \
+    python setup.py install
+
 FROM kernel-builder AS moe-kernels
 WORKDIR /usr/src
-ENV MOE_KERNELS_BRANCH=a67b35841774b2056a73806c36661134b5054edd
+ENV MOE_KERNELS_BRANCH=v0.8.2
 ENV VLLM_TARGET_DEVICE=rocm
 RUN git clone https://github.com/danieldk/moe-kernels.git && \
     cd moe-kernels && \
@@ -299,6 +308,9 @@ COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311
 # Copy build artifacts from exllamav2 kernels builder
 COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages

+# Copy build artifacts from marlin kernels
+COPY --from=marlin-kernels /usr/src/marlin-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
+
 # Copy build artifacts from moe kernels
 COPY --from=moe-kernels /usr/src/moe-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages

@@ -306,10 +318,11 @@ COPY --from=moe-kernels /usr/src/moe-kernels/build/lib.linux-x86_64-cpython-311
 COPY proto proto
 COPY server server
 COPY server/Makefile server/Makefile
+ENV UV_SYSTEM_PYTHON=1
 RUN cd server && \
     make gen-server && \
-    pip install -r requirements_rocm.txt && \
-    pip install ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir
+    pip install -U pip uv && \
+    uv pip install -e ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir

 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
@@ -1,6 +1,6 @@
 ARG PLATFORM=xpu

-FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.84.0 AS chef
 WORKDIR /usr/src

 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -108,17 +108,19 @@ RUN pip install triton-xpu==3.0.0b2 --no-cache-dir
 COPY proto proto
 COPY server server
 COPY server/Makefile server/Makefile
+ENV UV_SYSTEM_PYTHON=1
 RUN cd server && \
     make gen-server && \
-    pip install -r requirements_intel.txt && \
-    pip install ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir
+    pip install -U pip uv && \
+    uv pip install -e ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir

 ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/pti/0.9/lib:/opt/conda/lib
 ENV CCL_ZE_IPC_EXCHANGE=sockets
 #ENV TORCH_LLM_ALLREDUCE=1
 #ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
+ENV TORCH_DEVICE_BACKEND_AUTOLOAD=0

-RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout 033af6f63745ac748cccdadee5c6140c7971edf6
+RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout 1ccf72b2d11cd00b47aef6d6cd054c088aa6f083
 RUN cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc,ats-m150' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch

 # Install benchmarker
@@ -211,10 +213,11 @@ ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
 COPY proto proto
 COPY server server
 COPY server/Makefile server/Makefile
+ENV UV_SYSTEM_PYTHON=1
 RUN cd server && \
     make gen-server && \
-    pip install -r requirements_intel.txt && \
-    pip install ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir
+    pip install -U pip uv && \
+    uv pip install -e ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir

 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
@@ -224,9 +227,9 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
 COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher

 FROM ${PLATFORM} AS final
-ENV ATTENTION=paged
-ENV PREFIX_CACHING=0
-ENV PREFILL_CHUNKING=0
+ENV ATTENTION=flashdecoding-ipex
+ENV PREFIX_CACHING=1
+ENV PREFILL_CHUNKING=1
 ENV CUDA_GRAPHS=0
 ENTRYPOINT ["text-generation-launcher"]
 CMD ["--json-output"]
@@ -1,20 +1,14 @@
-ARG CUDA_ARCH_LIST="75-real;80-real;86-real;89-real;90-real"
-ARG OMPI_VERSION="4.1.7rc1"
-
-# Build dependencies resolver stage
-FROM lukemathwalker/cargo-chef:latest AS chef
-WORKDIR /usr/src/text-generation-inference/backends/trtllm
-
-FROM chef AS planner
-COPY . .
-RUN cargo chef prepare --recipe-path recipe.json
+ARG cuda_arch_list="75-real;80-real;86-real;89-real;90-real"
+ARG build_type=release
+ARG ompi_version=4.1.7
+ARG sccache_gha_enabled=off
+ARG actions_cache_url=""
+ARG actions_runtime_token=""

 # CUDA dependent dependencies resolver stage
 FROM nvidia/cuda:12.6.3-cudnn-devel-ubuntu24.04 AS cuda-builder

-RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
-    --mount=type=cache,target=/var/lib/apt,sharing=locked \
-    apt update && apt install -y \
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
     build-essential \
     cmake \
     curl \
@@ -22,8 +16,11 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
     g++-14 \
     git \
     git-lfs \
+    lld \
     libssl-dev \
     libucx-dev \
+    libasan8 \
+    libubsan1 \
     ninja-build \
     pkg-config \
     pipx \
@@ -31,7 +28,7 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
     python3-dev \
     python3-setuptools \
     tar \
-    wget && \
+    wget --no-install-recommends && \
     pipx ensurepath

 ENV TGI_INSTALL_PREFIX=/usr/local/tgi
@@ -39,17 +36,19 @@ ENV TENSORRT_INSTALL_PREFIX=/usr/local/tensorrt

 # Install OpenMPI
 FROM cuda-builder AS mpi-builder
-ARG OMPI_VERSION
+WORKDIR /opt/src/mpi

-ENV OMPI_TARBALL_FILENAME="openmpi-$OMPI_VERSION.tar.bz2"
-RUN wget "https://download.open-mpi.org/release/open-mpi/v4.1/$OMPI_TARBALL_FILENAME" -P /opt/src && \
-    mkdir /usr/src/mpi && \
-    tar -xf "/opt/src/$OMPI_TARBALL_FILENAME" -C /usr/src/mpi --strip-components=1 && \
-    cd /usr/src/mpi && \
+ARG ompi_version
+ENV OMPI_VERSION=${ompi_version}
+ENV OMPI_TARBALL_FILENAME=openmpi-${OMPI_VERSION}.tar.bz2
+ADD --checksum=sha256:54a33cb7ad81ff0976f15a6cc8003c3922f0f3d8ceed14e1813ef3603f22cd34 \
+    https://download.open-mpi.org/release/open-mpi/v4.1/${OMPI_TARBALL_FILENAME} .
+
+RUN tar --strip-components=1 -xf ${OMPI_TARBALL_FILENAME} &&\
     ./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --with-slurm && \
     make -j all && \
     make install && \
-    rm -rf "/opt/src/$OMPI_TARBALL_FILENAME"
+    rm -rf ${OMPI_TARBALL_FILENAME}/..

 # Install TensorRT
 FROM cuda-builder AS trt-builder
@@ -61,30 +60,50 @@ RUN chmod +x /opt/install_tensorrt.sh && \
 FROM cuda-builder AS tgi-builder
 WORKDIR /usr/src/text-generation-inference

+# Scoped global args reuse
+ARG cuda_arch_list
+ARG build_type
+ARG sccache_gha_enabled
+ARG actions_cache_url
+ARG actions_runtime_token
+
 # Install Rust
+ENV PATH="/root/.cargo/bin:$PATH"
 RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y && \
     chmod -R a+w /root/.rustup && \
-    chmod -R a+w /root/.cargo
-
-ENV PATH="/root/.cargo/bin:$PATH"
-RUN cargo install cargo-chef
-
-# Cache dependencies
-COPY --from=planner /usr/src/text-generation-inference/backends/trtllm/recipe.json .
-RUN cargo chef cook --release --recipe-path recipe.json
-
-# Build actual TGI
-ARG CUDA_ARCH_LIST
-ENV CMAKE_PREFIX_PATH="/usr/local/mpi:/usr/local/tensorrt:$CMAKE_PREFIX_PATH"
+    chmod -R a+w /root/.cargo && \
+    cargo install sccache --locked
+
 ENV LD_LIBRARY_PATH="/usr/local/mpi/lib:$LD_LIBRARY_PATH"
-ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig:$PKG_CONFIG_PATH"
+ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig"
+ENV CMAKE_PREFIX_PATH="/usr/local/mpi:/usr/local/tensorrt"

-COPY . .
+ENV USE_LLD_LINKER=ON
+ENV CUDA_ARCH_LIST=${cuda_arch_list}
+
+# SCCACHE Specifics args - before finding a better, more generic, way...
+ENV SCCACHE_GHA_ENABLED=${sccache_gha_enabled}
+ENV ACTIONS_CACHE_URL=${actions_cache_url}
+ENV ACTIONS_RUNTIME_TOKEN=${actions_runtime_token}
+
+COPY Cargo.lock Cargo.lock
+COPY Cargo.toml Cargo.toml
+COPY rust-toolchain.toml rust-toolchain.toml
+COPY router router
+COPY backends backends
+COPY benchmark benchmark
+COPY launcher launcher
 COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
 COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
-RUN mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$TGI_INSTALL_PREFIX/lib" && \
-    cd backends/trtllm && \
-    CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release
+
+ENV RUSTC_WRAPPER=sccache
+ENV CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX
+RUN export CMAKE_C_COMPILER_LAUNCHER=sccache && \
+    export CMAKE_CXX_COMPILER_LAUNCHER=sccache && \
+    export CMAKE_CUDA_COMPILER_LAUNCHER=sccache && \
+    mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$TGI_INSTALL_PREFIX/lib" && \
+    cargo build --profile ${build_type} --package text-generation-backends-trtllm --bin text-generation-backends-trtllm && \
+    sccache --show-stats

 FROM nvidia/cuda:12.6.3-cudnn-runtime-ubuntu24.04 AS runtime
 RUN apt update && apt install -y libucx0 pipx python3-minimal python3-dev python3-pip python3-venv && \
@@ -104,10 +123,33 @@ COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
 COPY --from=tgi-builder /usr/local/tgi /usr/local/tgi
 COPY --from=tgi-builder /usr/src/text-generation-inference/target/release/text-generation-backends-trtllm /usr/local/tgi/bin/text-generation-launcher

+# This is used only for the CI/CD
+FROM nvidia/cuda:12.6.3-cudnn-runtime-ubuntu24.04 AS ci-runtime
+RUN apt update && apt install -y libasan8 libubsan1 libucx0 pipx python3-minimal python3-dev python3-pip python3-venv && \
+    rm -rf /var/lib/{apt,dpkg,cache,log}/ && \
+    pipx ensurepath && \
+    pipx install --include-deps transformers tokenizers
+
+WORKDIR /usr/local/tgi/bin
+
+ENV PATH=/root/.local/share/pipx/venvs/transformers/bin/:$PATH
+ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/mpi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
+ENV TOKENIZERS_PARALLELISM=false
+ENV OMPI_MCA_plm_rsh_agent=""
+
+COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
+COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
+COPY --from=tgi-builder /usr/local/tgi /usr/local/tgi
+
+# Basically we copy from target/debug instead of target/release
+COPY --from=tgi-builder /usr/src/text-generation-inference/target/debug/text-generation-backends-trtllm /usr/local/tgi/bin/text-generation-launcher
+
+# This is the final image
 FROM runtime

 LABEL co.huggingface.vendor="Hugging Face Inc."
 LABEL org.opencontainers.image.authors="hardware@hf.co"
+LABEL org.opencontainers.title="Text-Generation-Inference TensorRT-LLM Backend"

 ENTRYPOINT ["./text-generation-launcher"]
 CMD ["--executor-worker", "/usr/local/tgi/bin/executorWorker"]
@@ -84,7 +84,7 @@ model=HuggingFaceH4/zephyr-7b-beta
 volume=$PWD/data

 docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
-    ghcr.io/huggingface/text-generation-inference:3.0.0 --model-id $model
+    ghcr.io/huggingface/text-generation-inference:3.1.0 --model-id $model
 ```

 And then you can make requests like
@@ -121,7 +121,7 @@ curl localhost:8080/v1/chat/completions \

 **Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.

-**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.0.0-rocm --model-id $model` instead of the command above.
+**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.1.0-rocm --model-id $model` instead of the command above.

 To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
 ```
@@ -151,7 +151,8 @@ model=meta-llama/Meta-Llama-3.1-8B-Instruct
 volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
 token=<your cli READ token>

-docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.0.0 --model-id $model
+docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data \
+    ghcr.io/huggingface/text-generation-inference:3.1.0 --model-id $model
 ```

 ### A note on Shared Memory (shm)
@@ -8,7 +8,7 @@ use tracing_opentelemetry::OpenTelemetrySpanExt;
 /// Inject context in the metadata of a gRPC request.
 struct MetadataInjector<'a>(pub &'a mut tonic::metadata::MetadataMap);

-impl<'a> Injector for MetadataInjector<'a> {
+impl Injector for MetadataInjector<'_> {
     /// Set a key and value in the MetadataMap. Does nothing if the key or value are not valid inputs
     fn set(&mut self, key: &str, value: String) {
         if let Ok(key) = tonic::metadata::MetadataKey::from_bytes(key.as_bytes()) {
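The change above only switches the impl header to the anonymous lifetime `'_`; behaviour is unchanged. A minimal, self-contained sketch of the same pattern, using toy types rather than the TGI source:

```rust
// A toy injector trait, standing in for opentelemetry's `Injector`.
trait Injector {
    fn set(&mut self, key: &str, value: String);
}

// Holds a mutable borrow, like `MetadataInjector<'a>` holds a `&'a mut MetadataMap`.
struct MapInjector<'a>(pub &'a mut std::collections::HashMap<String, String>);

// Before: `impl<'a> Injector for MapInjector<'a> { ... }`
// After: the lifetime is never named in the body, so `'_` is enough and it
// silences the "lifetime can be elided" lint on recent toolchains.
impl Injector for MapInjector<'_> {
    fn set(&mut self, key: &str, value: String) {
        self.0.insert(key.to_owned(), value);
    }
}

fn main() {
    let mut map = std::collections::HashMap::new();
    MapInjector(&mut map).set("traceparent", "00-dummy".to_owned());
    assert_eq!(map.get("traceparent").map(String::as_str), Some("00-dummy"));
}
```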
@@ -1,13 +1,5 @@
 cmake_minimum_required(VERSION 3.20)

-if (NOT DEFINED CMAKE_CXX_COMPILER_LAUNCHER AND CMAKE_BUILD_TYPE STREQUAL "Debug")
-    find_program(CCACHE_EXECUTABLE "ccache")
-    if (CCACHE_EXECUTABLE)
-        message(STATUS "Using ccache")
-        set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_EXECUTABLE}" CACHE PATH "Path to ccache" FORCE)
-    endif ()
-endif ()
-
 if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0")
     cmake_policy(SET CMP0135 NEW)
 endif ()
@@ -21,6 +13,7 @@ include(CheckCXXCompilerFlag)

 option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF)
 option(TGI_TRTLLM_BACKEND_BUILD_EXAMPLES "Enable building the examples suite" OFF)
+option(TGI_TRTLLM_BACKEND_BUILD_USE_LLD "Enable lld linker instead of ld" OFF)
 set(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST "89-real" CACHE STRING "List of CUDA architectures to support")
 set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING "Path where TensorRT libraries and headers are located")
 set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include" CACHE STRING "Path where TensorRT headers are located")
@@ -28,20 +21,22 @@ set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE ST

 # We are using nvidia-ml to query at runtime device information to enable some architecture-specific features
 find_package(CUDAToolkit 12.6 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml)
+find_package(MPI REQUIRED)

 #### External dependencies ####
 include(cmake/json.cmake)
 include(cmake/spdlog.cmake)
 include(cmake/trtllm.cmake)

-if(${CMAKE_BUILD_TYPE} STREQUAL "Debug")
+if (CMAKE_BUILD_TYPE STREQUAL "Debug")
+    set(TGI_TRTLLM_BACKEND_DEBUG ON)
     add_compile_definitions(TGI_TRTLLM_BACKEND_DEBUG=1)
+    add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_TRACE)
 endif ()

-# This attempt to detect if the compiler can emit warning if it can't apply return value optimization from a function
-check_cxx_compiler_flag("-Wnrvo" COMPILER_SUPPORT_WARNING_ON_NVRO)
-if(${COMPILER_SUPPORT_WARNING_ON_NVRO})
-    set(CMAKE_CXX_FLAGS "{CMAKE_CXX_FLAGS} -Wnvro")
-endif ()
+if (${TGI_TRTLLM_BACKEND_BUILD_USE_LLD})
+    message(STATUS "Using lld linker")
+    add_link_options("-fuse-ld=lld")
+endif ()

 # Let's build TRTLLM as part of CMake
@@ -60,46 +55,63 @@ target_include_directories(tgi_trtllm_backend_impl PRIVATE
 target_include_directories(tgi_trtllm_backend_impl PUBLIC "${trtllm_SOURCE_DIR}/cpp/include")
 target_link_libraries(tgi_trtllm_backend_impl PRIVATE CUDA::cudart CUDA::nvml)
 target_link_libraries(tgi_trtllm_backend_impl PUBLIC nlohmann_json::nlohmann_json spdlog::spdlog)
-if(${CMAKE_BUILD_TYPE} STREQUAL "Debug")
-    target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm)
-else()
-    target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapperm)
-endif ()
+target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper)

 # This install all the artifacts in CMAKE_INSTALL_PREFIX under include/ lib/ bin/ to make easy to link / find it back
-install(TARGETS tgi_trtllm_backend_impl tensorrt_llm nvinfer_plugin_tensorrt_llm decoder_attention executorWorker)
-install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB)
+install(TARGETS tgi_trtllm_backend_impl)
+install(TARGETS tensorrt_llm nvinfer_plugin_tensorrt_llm decoder_attention executorWorker)
+install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} TYPE LIB)
+if (NOT ${TGI_TRTLLM_BACKEND_DEBUG})
+    install(FILES ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB)
+endif ()

 #### Unit Tests ####
-if (${TGI_TRTLLM_BACKEND_BUILD_TESTS})
+if (${TGI_TRTLLM_BACKEND_BUILD_TESTS} AND CMAKE_BUILD_TYPE MATCHES "Debug")
     message(STATUS "Building tests")
+    option(TGI_TRTLLM_BACKEND_ENABLE_ASAN "Enable AddressSanitizer")
+    option(TGI_TRTLLM_BACKEND_ENABLE_UBSAN "Enable UndefinedSanitizer")
+
     FetchContent_Declare(
             Catch2
             URL https://github.com/catchorg/Catch2/archive/refs/tags/v3.7.1.tar.gz
     )
     FetchContent_MakeAvailable(Catch2)

+    # This attempt to detect if the compiler can emit warning if it can't apply return value optimization from a function
+    check_cxx_compiler_flag("-Wnrvo" COMPILER_SUPPORT_WARNING_ON_NVRO)
+    if (${COMPILER_SUPPORT_WARNING_ON_NVRO})
+        message(STATUS "Enabling non-NVRO detection")
+        target_compile_options(tgi_trtllm_backend_impl "-Wnvro")
+    endif ()
+
+    cmake_path(GET TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH PARENT_PATH TRTLLM_NVRTC_WRAPPER_PARENT_LIBRARY_PATH)
+    message(STATUS "Adding linking path: ${TRTLLM_NVRTC_WRAPPER_PARENT_LIBRARY_PATH}")
+
     add_executable(tgi_trtllm_backend_tests tests/test_hardware.cpp tests/test_backend.cpp)
+
+    # target_compile_options(tgi_trtllm_backend_tests PRIVATE -Werror)
+    target_link_directories(tgi_trtllm_backend_tests PRIVATE "${TRTLLM_NVRTC_WRAPPER_PARENT_LIBRARY_PATH}")
     target_include_directories(tgi_trtllm_backend_tests PUBLIC "${trtllm_SOURCE_DIR}/cpp/include")
     target_include_directories(tgi_trtllm_backend_tests PUBLIC "csrc/")
     target_link_libraries(tgi_trtllm_backend_tests PRIVATE ${TRTLLM_LIBS} CUDA::cudart CUDA::nvml)
     target_link_libraries(tgi_trtllm_backend_tests PUBLIC Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog tgi_trtllm_backend_impl)
+    target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper)

-    if(${CMAKE_BUILD_TYPE} STREQUAL "Debug")
-        target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm)
-    else()
-        target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapperm)
-    endif ()
+    if (${TGI_TRTLLM_BACKEND_ENABLE_ASAN})
+        message(STATUS "Enabled AddressSanitizer")
+        target_link_options(tgi_trtllm_backend_tests BEFORE PUBLIC -fsanitize=address)
+    endif ()

-    if(CMAKE_BUILD_TYPE MATCHES "Debug")
-        set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Werror -fsanitize=undefined -fsanitize=address")
-        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -fsanitize=undefined -fsanitize=address")
-        target_link_options(tgi_trtllm_backend_tests BEFORE PUBLIC -fsanitize=undefined PUBLIC -fsanitize=address)
+    if (${TGI_TRTLLM_BACKEND_ENABLE_UBSAN})
+        message(STATUS "Enabled UndefinedSanitizer")
+        target_link_options(tgi_trtllm_backend_tests BEFORE PUBLIC -fsanitize=undefined)
     endif ()

-    list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras)
-    include(CTest)
-    include(Catch)
-    catch_discover_tests(tgi_trtllm_backend_tests)
+    install(TARGETS tgi_trtllm_backend_tests)
+
+    # list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras)
+    # include(CTest)
+    # include(Catch)
+    # catch_discover_tests(tgi_trtllm_backend_tests)
 endif ()
@ -7,20 +7,16 @@ homepage.workspace = true
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
async-trait = "0.1"
|
async-trait = "0.1"
|
||||||
#async-stream = "0.3"
|
|
||||||
clap = { version = "4.5", features = ["derive"] }
|
clap = { version = "4.5", features = ["derive"] }
|
||||||
cxx = "1.0"
|
cxx = "1.0"
|
||||||
hashbrown = "0.14"
|
hashbrown = "0.15"
|
||||||
hf-hub = { workspace = true }
|
hf-hub = { workspace = true }
|
||||||
#log = { version = "0.4", features = [] }
|
|
||||||
text-generation-router = { path = "../../router" }
|
text-generation-router = { path = "../../router" }
|
||||||
tokenizers = { workspace = true }
|
tokenizers = { workspace = true }
|
||||||
tokio = { version = "1.39", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
tokio = { version = "1.43.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
||||||
tokio-stream = "0.1.15"
|
tokio-stream = "0.1.17"
|
||||||
thiserror = "1.0.63"
|
thiserror = "1.0.63"
|
||||||
tracing = "0.1"
|
tracing = "0.1"
|
||||||
#tracing-opentelemetry = "0.25"
|
|
||||||
#tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
|
|
||||||
pyo3 = { workspace = true }
|
pyo3 = { workspace = true }
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
|
@@ -3,6 +3,7 @@ use pkg_config;
 use std::env;
 use std::env::consts::ARCH;
 use std::path::{absolute, PathBuf};
+use std::sync::LazyLock;

 const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 1] = ["spdlog"];
 const CUDA_ARCH_LIST: Option<&str> = option_env!("CUDA_ARCH_LIST");
@@ -12,12 +13,20 @@ const INSTALL_PREFIX: Option<&str> = option_env!("CMAKE_INSTALL_PREFIX");
 const TENSORRT_ROOT_DIR: Option<&str> = option_env!("TENSORRT_ROOT_DIR");
 const NCCL_ROOT_DIR: Option<&str> = option_env!("NCCL_ROOT_DIR");

+const IS_GHA_BUILD: LazyLock<bool> = LazyLock::new(|| {
+    option_env!("SCCACHE_GHA_ENABLED").map_or(false, |value| match value.to_lowercase().as_str() {
+        "on" => true,
+        "true" => true,
+        "1" => true,
+        _ => false,
+    })
+});

 // Dependencies
-const BACKEND_DEPS: [&str; 2] = ["tgi_trtllm_backend_impl", "tgi_trtllm_backend"];
+const BACKEND_DEPS: &str = "tgi_trtllm_backend_impl";
 const CUDA_TRANSITIVE_DEPS: [&str; 4] = ["cuda", "cudart", "cublas", "nvidia-ml"];
-const TENSORRT_LLM_TRANSITIVE_DEPS: [(&str, &str); 5] = [
+const TENSORRT_LLM_TRANSITIVE_DEPS: [(&str, &str); 4] = [
     ("dylib", "tensorrt_llm"),
-    ("static", "tensorrt_llm_executor_static"),
     ("dylib", "tensorrt_llm_nvrtc_wrapper"),
     ("dylib", "nvinfer_plugin_tensorrt_llm"),
     ("dylib", "decoder_attention"),
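For reference, a standalone sketch (not part of this commit; the helper name `is_truthy` is hypothetical) of the case-insensitive truthiness rule that `IS_GHA_BUILD` applies to `SCCACHE_GHA_ENABLED` above:

```rust
// Hypothetical helper mirroring the match used for SCCACHE_GHA_ENABLED in build.rs above.
fn is_truthy(value: &str) -> bool {
    matches!(value.to_lowercase().as_str(), "on" | "true" | "1")
}

fn main() {
    assert!(is_truthy("ON"));
    assert!(is_truthy("True"));
    assert!(is_truthy("1"));
    assert!(!is_truthy("off"));
    println!("SCCACHE_GHA_ENABLED=ON -> {}", is_truthy("ON"));
}
```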
@@ -32,6 +41,48 @@ macro_rules! probe {
     };
 }

+fn get_compiler_flag(
+    switch: bool,
+    true_case: &'static str,
+    false_case: &'static str,
+) -> &'static str {
+    match switch {
+        true => true_case,
+        false => false_case,
+    }
+}
+
+fn get_library_architecture() -> &'static str {
+    let os = env::var("CARGO_CFG_TARGET_OS").unwrap();
+    let arch = env::var("CARGO_CFG_TARGET_ARCH").unwrap();
+    let env = env::var("CARGO_CFG_TARGET_ENV").unwrap();
+
+    match os.as_str() {
+        "linux" => {
+            if env != "gnu" {
+                panic!("unsupported linux ABI {env}, only 'gnu' is supported")
+            }
+
+            match arch.as_str() {
+                "x86_64" => "x86_64-linux-gnu",
+                "aarch64" => "aarch64-linux-gnu",
+                _ => panic!("unsupported linux architecture {arch}"),
+            }
+        }
+        "windows" => {
+            if env != "msvc" {
+                panic!("unsupported windows ABI {env}, only 'msvc' is supported")
+            }
+
+            match arch.as_str() {
+                "x86_64" => "x86_64-windows-msvc",
+                _ => panic!("unsupported windows architecture {arch}"),
+            }
+        }
+        _ => panic!("unsupported OS {os}"),
+    }
+}
+
 fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf, PathBuf) {
     // Build the backend implementation through CMake
     let install_path = INSTALL_PREFIX.unwrap_or("/usr/local/tgi");
@@ -54,10 +105,44 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf
         .env("OPT_LEVEL", opt_level)
         .define("CMAKE_INSTALL_PREFIX", &install_path)
         .define("CMAKE_CUDA_COMPILER", "/usr/local/cuda/bin/nvcc")
-        .define("Python3_ROOT_DIR", "../venv")
+        .define("CMAKE_LIBRARY_ARCHITECTURE", get_library_architecture())
         .define("TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST", cuda_arch_list)
+        .define(
+            "TGI_TRTLLM_BACKEND_DEBUG",
+            get_compiler_flag(is_debug, "ON", "OFF"),
+        )
         .define("TGI_TRTLLM_BACKEND_TRT_ROOT", tensorrt_path);
+
+    if is_debug || *IS_GHA_BUILD {
+        config.define("TGI_TRTLLM_BACKEND_BUILD_TESTS", "ON");
+    }
+
+    if option_env!("USE_LLD_LINKER").is_some() {
+        println!("cargo:warning=Using lld linker");
+        config.define("TGI_TRTLLM_BACKEND_BUILD_USE_LLD", "ON");
+    }
+
+    if (is_debug && option_env!("ENABLE_ASAN").is_some()) || *IS_GHA_BUILD {
+        println!("cargo:warning=Enabling Address Sanitizer");
+        config.define("TGI_TRTLLM_BACKEND_ENABLE_ASAN", "ON");
+    }
+
+    if (is_debug && option_env!("ENABLE_UBSAN").is_some()) || *IS_GHA_BUILD {
+        println!("cargo:warning=Enabling Undefined Sanitizer");
+        config.define("TGI_TRTLLM_BACKEND_ENABLE_UBSAN", "ON");
+    }
+
+    if let Some(nvcc_host_compiler) = option_env!("CMAKE_CUDA_HOST_COMPILER") {
+        config.define("CMAKE_CUDA_HOST_COMPILER", nvcc_host_compiler);
+    }
+
+    if let Some(wrapper) = option_env!("RUSTC_WRAPPER") {
+        println!("cargo:warning=Using caching tool: {wrapper}");
+        config.define("CMAKE_C_COMPILER_LAUNCHER", wrapper);
+        config.define("CMAKE_CXX_COMPILER_LAUNCHER", wrapper);
+        config.define("CMAKE_CUDA_COMPILER_LAUNCHER", wrapper);
+    }
+
     // Allow to override which Python to use ...
     if let Some(python3) = option_env!("Python3_EXECUTABLE") {
         config.define("Python3_EXECUTABLE", python3);
@@ -78,23 +163,18 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf
     }

     // Emit linkage information from the artifacts we just built
-    let install_lib_path = install_path.join("lib");
+    for path in ["lib", "lib64"] {
+        let install_lib_path = install_path.join(path);
     println!(
         r"cargo:warning=Adding link search path: {}",
         install_lib_path.display()
     );
     println!(r"cargo:rustc-link-search={}", install_lib_path.display());
+    }
     (PathBuf::from(install_path), deps_folder)
 }

 fn build_ffi_layer(deps_folder: &PathBuf, is_debug: bool) {
-    let ndebug = match is_debug {
-        true => "1",
-        false => "0",
-    };
-
     CFG.include_prefix = "backends/trtllm";
     cxx_build::bridge("src/lib.rs")
         .static_flag(true)
@@ -106,7 +186,10 @@ fn build_ffi_layer(deps_folder: &PathBuf, is_debug: bool) {
         .include("/usr/local/tensorrt/include")
         .include("csrc/")
         .file("csrc/ffi.hpp")
-        .define("TGI_TRTLLM_BACKEND_DEBUG", ndebug)
+        .define(
+            "TGI_TRTLLM_BACKEND_DEBUG",
+            get_compiler_flag(is_debug, "ON", "OFF"),
+        )
         .compile("tgi_trtllm_backend");

     println!("cargo:rerun-if-changed=CMakeLists.txt");
@@ -125,6 +208,7 @@ fn main() {
     let build_profile = env::var("PROFILE").unwrap();
     let (is_debug, opt_level) = match build_profile.as_ref() {
         "debug" => (true, "0"),
+        "dev" => (true, "0"),
         _ => (false, "3"),
     };

@@ -161,7 +245,5 @@ fn main() {
     });

     // Backend
-    BACKEND_DEPS.iter().for_each(|name| {
-        println!("cargo:rustc-link-lib=static={}", name);
-    });
+    println!("cargo:rustc-link-lib=static={}", &BACKEND_DEPS);
 }
@@ -4,14 +4,14 @@ set(SPDLOG_FMT_EXTERNAL OFF)

 # Define the level at which SPDLOG_ compilation level is defined
 if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
-    add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG)
+    add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_TRACE)
 else ()
-    add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_INFO)
+    add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG)
 endif ()

 fetchcontent_declare(
         spdlog
         # DOWNLOAD_EXTRACT_TIMESTAMP
-        URL https://github.com/gabime/spdlog/archive/refs/tags/v1.14.1.tar.gz
+        URL https://github.com/gabime/spdlog/archive/refs/tags/v1.15.0.tar.gz
 )
 fetchcontent_makeavailable(spdlog)

@@ -14,19 +14,21 @@ message(STATUS "Building for CUDA Architectures: ${CMAKE_CUDA_ARCHITECTURES}")
 set(ENABLE_UCX OFF)
 if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
     set(FAST_BUILD ON)
-    set(NVTX_DISABLE OFF)
+    set(NVTX_DISABLE ON)
+    set(INDEX_RANGE_CHECK ON)
 else ()
     set(FAST_BUILD OFF)
     set(FAST_MATH ON)
-    set(NVTX_DISABLE ON)
+    set(NVTX_DISABLE OFF)
+    set(INDEX_RANGE_CHECK OFF)
 endif ()

 find_package(Python3 REQUIRED Interpreter)

 fetchcontent_declare(
         trtllm
-        GIT_REPOSITORY https://github.com/huggingface/TensorRT-LLM.git
-        GIT_TAG 1bb9ca4688805444f203647674bac1d7219d0579
+        GIT_REPOSITORY https://github.com/nvidia/TensorRT-LLM.git
+        GIT_TAG v0.16.0
         GIT_SHALLOW ON
         DOWNLOAD_EXTRACT_TIMESTAMP
 )
@@ -1,7 +1,6 @@
 #include <ranges>

 #include <nlohmann/json.hpp>
-#include <spdlog/spdlog.h>

 #include "backend.hpp"
 #include "hardware.hpp"
@@ -17,7 +16,8 @@ namespace huggingface::tgi::backends::trtllm {
         if (world_size > 1) {
             SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode");
             mode = tle::CommunicationMode::kORCHESTRATOR;
-            orchestratorConfig = std::make_optional<tle::OrchestratorConfig>(true, executor_worker_path_, nullptr, true);
+            orchestratorConfig = std::make_optional<tle::OrchestratorConfig>(true, executor_worker_path_, nullptr,
+                                                                             true);
         } else {
             SPDLOG_INFO("Detected single engine deployment, using leader mode");
         }
@@ -51,13 +51,14 @@ namespace huggingface::tgi::backends::trtllm {
     }

     std::expected<request_id_t, backend_error_t>
-    backend_t::submit(std::span<const token_id_t> token_ids, const generation_params_t generation_params, const sampling_params_t sampling_params) noexcept {
-        SPDLOG_DEBUG("Submitting {:d} tokens to the executor for scheduling ({}, {})", token_ids.size(), generation_params, sampling_params);
+    backend_t::submit(std::span<const token_id_t> token_ids, const generation_params_t g_params,
+                      const sampling_params_t s_params) noexcept {
+        SPDLOG_DEBUG("Submit {:d} tokens for scheduling ({}, {})", token_ids.size(), g_params, s_params);
         return executor_.enqueueRequest(tle::Request{
                 {token_ids.begin(), token_ids.end()}, // Making actual copy of the tokens
-                static_cast<tle::SizeType32>(generation_params.max_new_tokens),
+                static_cast<tle::SizeType32>(g_params.max_new_tokens),
                 true,
-                (tle::SamplingConfig) sampling_params,
+                (tle::SamplingConfig) s_params,
                 tle::OutputConfig{ /* returnLogProbs= */ true},
                 std::nullopt,
                 std::nullopt,
@@ -28,9 +28,53 @@ namespace huggingface::tgi::backends::trtllm {

 #include "backends/trtllm/src/lib.rs.h"

 namespace huggingface::tgi::backends::trtllm {
     std::once_flag backend_initialized_flag;

+    constexpr finish_reason_t as_finish_reason_t(const tle::FinishReason reason) noexcept {
+        switch (reason) {
+            case tle::FinishReason::kNOT_FINISHED:
+                return finish_reason_t::kNOT_FINISHED;
+            case tle::FinishReason::kSTOP_WORDS:
+                return finish_reason_t::kSTOP_WORDS;
+            case tle::FinishReason::kEND_ID:
+                return finish_reason_t::kEND_ID;
+            case tle::FinishReason::kLENGTH:
+                return finish_reason_t::kLENGTH;
+            default:
+                std::unreachable();
+        }
+    }
+
+    static auto as_generation_step = [](const tle::Response &r) {
+        const auto reqId = r.getRequestId();
+        if (!r.hasError()) [[likely]] {
+            const auto result = r.getResult();
+            const auto logits = result.logProbs.value()[0];
+            return generation_step_t{
+                    reqId,
+                    static_cast<uint32_t>(result.outputTokenIds[0][0]),
+                    logits.back(),
+                    result.isFinal,
+                    as_finish_reason_t(result.finishReasons[0]),
+                    false,
+                    std::string()
+            };
+        } else {
+            return generation_step_t{
+                    reqId,
+                    0,
+                    0.0,
+                    true,
+                    finish_reason_t::kNOT_FINISHED,
+                    true,
+                    std::move(r.getErrorMsg())
+            };
+        }
+    };
+
     class tensorrt_llm_backend_t {
     private:
         backend_t inner_;
@@ -39,9 +83,7 @@ namespace huggingface::tgi::backends::trtllm {
         tensorrt_llm_backend_t(std::filesystem::path &&engine_folder, std::filesystem::path &&executor_worker_path)
                 : inner_(engine_folder, executor_worker_path) {}

-        size_t num_tokens_ready() const noexcept {
-            return inner_.num_tokens_ready();
-        }
+        size_t num_tokens_ready() const noexcept { return inner_.num_tokens_ready(); }

         request_id_t submit(
                 rust::Slice<const uint32_t> tokens,
@@ -78,41 +120,25 @@ namespace huggingface::tgi::backends::trtllm {
             const auto responses = inner_.pull_tokens();

             SPDLOG_TRACE("[FFI] Successfully pulled out {:d} responses from executor", responses.size());
-            // Transform tle::Response to GenerationStep
-            auto steps = std::make_unique<std::vector<generation_step_t>>();
-            std::ranges::transform(responses.begin(), responses.end(), std::back_inserter(*steps), [](const tle::Response &r) {
-                const auto reqId = r.getRequestId();
-                if (!r.hasError()) [[likely]] {
-                    const auto result = r.getResult();
-                    return generation_step_t{
-                            reqId,
-                            static_cast<uint32_t>(result.outputTokenIds[0][0]),
-                            result.logProbs.value()[0][0],
-                            result.isFinal,
-                            false,
-                            std::string()
-                    };
-                } else {
-                    return generation_step_t{
-                            reqId,
-                            0,
-                            0.0,
-                            true,
-                            true,
-                            std::move(r.getErrorMsg())
-                    };
-                }
-            });
-            return steps;
+            // Transform tle::Response to generation_step_t
+#ifdef __cpp_lib_ranges_to_container
+            auto steps = responses | std::views::transform(as_generation_step) | std::ranges::to<std::vector>();
+#else
+            auto steps = std::vector<generation_step_t>();
+            steps.reserve(responses.size());
+            std::transform(responses.begin(), responses.end(), std::back_inserter(steps), as_generation_step);
+#endif
+            return std::make_unique<std::vector<generation_step_t>>(steps);

         } else {
             return std::make_unique<std::vector<generation_step_t>>();
         }
     }

-        void cancel(request_id_t requestId) noexcept {
-            SPDLOG_DEBUG("[FFI] cancelling request {:d}", requestId);
-            inner_.cancel(requestId);
+        void cancel(request_id_t request_id) noexcept {
+            SPDLOG_DEBUG("[FFI] cancelling request {:d}", request_id);
+            inner_.cancel(request_id);
         }
     };

@@ -151,11 +177,14 @@ namespace huggingface::tgi::backends::trtllm {
         }
     }

-    std::unique_ptr<tensorrt_llm_backend_t> create_backend_from_engine_folder(const rust::Str engines_folder, const rust::Str executor_worker_path) {
+    std::unique_ptr<tensorrt_llm_backend_t>
+    create_backend_from_engine_folder(const rust::Str engines_folder, const rust::Str executor_worker_path) {
         std::call_once(backend_initialized_flag, initialize_tensorrt_llm_backend);
         return std::make_unique<tensorrt_llm_backend_t>(
-                std::filesystem::path(std::string_view(engines_folder.begin(), engines_folder.end()), std::filesystem::path::format::auto_format),
-                std::filesystem::path(std::string_view(executor_worker_path.begin(), executor_worker_path.end()), std::filesystem::path::format::auto_format)
+                std::filesystem::path(std::string_view(engines_folder.begin(), engines_folder.end()),
+                                      std::filesystem::path::format::auto_format),
+                std::filesystem::path(std::string_view(executor_worker_path.begin(), executor_worker_path.end()),
+                                      std::filesystem::path::format::auto_format)
         );
     }
 }
@@ -2,8 +2,8 @@

 set -ex

-TRT_VER_BASE="10.6.0"
-TRT_VER_FULL="${TRT_VER_BASE}.26"
+TRT_VER_BASE="10.7.0"
+TRT_VER_FULL="${TRT_VER_BASE}.23"
 CUDA_VER="12.6"
 CUDNN_VER="9.5.0.50-1"
 NCCL_VER="2.22.3-1+cuda12.6"

backends/trtllm/scripts/setup_sccache.py (new file, 51 lines)
@@ -0,0 +1,51 @@
+from argparse import ArgumentParser
+
+AWS_S3_CACHING_VARIABLES = {
+    "AWS_ACCESS_KEY_ID": "aws_access_key_id",
+    "AWS_SECRET_ACCESS_KEY": "aws_secret_access_key",
+    "AWS_SESSION_TOKEN": "aws_session_token",
+    "SCCACHE_REGION": "s3_region",
+    "SCCACHE_BUCKET": "s3_bucket_name",
+}
+
+ALL_CACHING_STORAGE_VARIABLES = {"AWS_S3_CACHING_VARIABLES"}
+
+
+def setup_sccache_locally():
+    from os import environ
+
+    print("Setting up Local Caching Layer")
+    for target in ALL_CACHING_STORAGE_VARIABLES:
+        for envvar in globals()[target].keys():
+            if envvar in environ:
+                print(f"Deleted {envvar} from environment variables")
+                del environ[envvar]
+
+
+def setup_sccache_for_s3():
+    from os import environ
+
+    print("Setting up AWS S3 Caching Layer")
+    for envvar in AWS_S3_CACHING_VARIABLES.keys():
+        if envvar not in environ or not environ[envvar] or len(environ[envvar]) == 0:
+            print(f"Missing definition for environment variable {envvar}")
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser("TensorRT-LLM Build Caching Setup")
+
+    parser.add_argument(
+        "--is-gha-build",
+        type=str,
+        default="FALSE",
+        help="Indicate if the build is from Github Actions",
+    )
+
+    # Parse args
+    args = parser.parse_args()
+    args.is_gha_build = args.is_gha_build.lower() in {"on", "true", "1"}
+
+    if args.is_gha_build:
+        setup_sccache_for_s3()
+    else:
+        setup_sccache_locally()
@@ -6,6 +6,26 @@ mod utils;

 #[cxx::bridge(namespace = "huggingface::tgi::backends::trtllm")]
 mod ffi {
+    #[cxx_name = "finish_reason_t"]
+    #[derive(Debug, Clone, Copy)]
+    pub enum FinishReason {
+        /// The request is not finished.
+        #[cxx_name = "kNOT_FINISHED"]
+        NotFinished = 0u8,
+
+        /// The request finished because the end id was generated.
+        #[cxx_name = "kEND_ID"]
+        EndTokenId = 1u8,
+
+        /// The request finished because a stop word was generated.
+        #[cxx_name = "kSTOP_WORDS"]
+        StopWords = 2u8,
+
+        /// The request finished because the maximum number of tokens was reached.
+        #[cxx_name = "kLENGTH"]
+        MaxLength = 3u8,
+    }
+
     /// Struct used as shared type between rust and C++ to represent the result
     /// of a single decoding iteration
     #[cxx_name = "generation_step_t"]
@@ -15,6 +35,7 @@ mod ffi {
         token_id: u32,
         log_prob: f32,
         is_final: bool,
+        finish_reason: FinishReason,
         has_error: bool,
         error_msg: String,
     }
@@ -66,3 +87,17 @@ mod ffi {
         fn cancel(self: Pin<&mut TensorRtLlmBackendImpl>, request_id: u64);
     }
 }
+
+use ffi::FinishReason;
+use text_generation_router::FinishReason as InferFinishReason;
+
+impl From<FinishReason> for InferFinishReason {
+    fn from(reason: FinishReason) -> Self {
+        match reason {
+            FinishReason::StopWords => InferFinishReason::StopSequence,
+            FinishReason::MaxLength => InferFinishReason::Length,
+            FinishReason::EndTokenId => InferFinishReason::EndOfSequenceToken,
+            _ => panic!("Cannot convert {reason:?} to text_generation_router::FinishReason"),
+        }
+    }
+}
@@ -18,10 +18,12 @@ use text_generation_router::validation::ValidationError::{
     EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality,
 };
 use text_generation_router::validation::{Chunk, ValidGenerateRequest};
-use text_generation_router::{FinishReason, Token};
+use text_generation_router::Token;

 use crate::errors::TensorRtLlmBackendError;
-use crate::ffi::{create_backend_from_engine_folder, GenerationStep, TensorRtLlmBackendImpl};
+use crate::ffi::{
+    create_backend_from_engine_folder, FinishReason, GenerationStep, TensorRtLlmBackendImpl,
+};
 use crate::utils::first_line;

 type InferResult<T> = Result<T, InferError>;
@@ -40,6 +42,7 @@ struct DecodedToken {
     id: u32,
     log_prob: f32,
     is_final: bool,
+    finish_reason: FinishReason,
 }

 impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
@@ -51,6 +54,7 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
                 id: step.token_id,
                 log_prob: step.log_prob,
                 is_final: step.is_final,
+                finish_reason: step.finish_reason,
             })
         } else {
             Err(GenerationError(step.error_msg.clone()))
@@ -192,7 +196,7 @@ fn post_process_decoded_token(
         let generated_text = GeneratedText {
             text: text.unwrap(),
             generated_tokens: ctx.tokens.len() as u32,
-            finish_reason: FinishReason::EndOfSequenceToken, // TODO : Map FinishReason
+            finish_reason: decoded_token.finish_reason.into(),
             seed: None,
         };

@@ -336,4 +340,8 @@ impl Backend for TensorRtLlmBackendV2 {
     async fn health(&self, _: bool) -> bool {
         true
     }
+
+    fn name(&self) -> &'static str {
+        "TensorRT-LLM"
+    }
 }
@@ -11,7 +11,7 @@ use text_generation_router::server::{
     get_hub_model_info, legacy_tokenizer_handle, py_resolve_tokenizer,
 };
 use text_generation_router::usage_stats::UsageStatsLevel;
-use text_generation_router::{server, HubTokenizerConfig, Tokenizer};
+use text_generation_router::{server, Tokenizer};

 /// App Configuration
 #[derive(Parser, Debug)]
@@ -67,11 +67,7 @@ struct Args {
     payload_limit: usize,
 }

-async fn get_tokenizer(
-    tokenizer_name: &str,
-    tokenizer_config_path: Option<&str>,
-    revision: Option<&str>,
-) -> Option<Tokenizer> {
+async fn get_tokenizer(tokenizer_name: &str, revision: Option<&str>) -> Option<Tokenizer> {
     // Parse Huggingface hub token
     let authorization_token = std::env::var("HF_TOKEN")
         .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
@@ -182,19 +178,6 @@ async fn get_tokenizer(
         }
     };

-    // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
-    // let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
-    // {
-    //     HubTokenizerConfig::from_file(filename)
-    // } else {
-    //     tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
-    // };
-
-    // let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
-    //     tracing::warn!("Could not find tokenizer config locally and no API specified");
-    //     HubTokenizerConfig::default()
-    // });

     let tokenizer: Tokenizer = {
         use pyo3::prelude::*;
         pyo3::Python::with_gil(|py| -> PyResult<()> {
@@ -292,11 +275,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
     }

     // Create the backend
-    match get_tokenizer(
-        &tokenizer_name,
-        tokenizer_config_path.as_deref(),
-        revision.as_deref(),
-    )
+    match get_tokenizer(&tokenizer_name, revision.as_deref())
         .await
         .expect("Failed to retrieve tokenizer implementation")
     {
@@ -8,13 +8,13 @@

 #include "backend.hpp"

 using namespace huggingface::tgi::backends::trtllm;

 TEST_CASE("parse generation_config.json all set", "[generation_config_t]")
 {
-    const json config_j = {{"temperature", 0.6}, {"top_p", 0.95}, {"eos_token_id", {1,2,3}}};
+    const json config_j = {{"temperature", 0.6},
+                           {"top_p", 0.95},
+                           {"eos_token_id", {1, 2, 3}}};
     const auto generation_config = generation_config_t(config_j);

     REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(0.6, 1e-6));
@@ -24,8 +24,9 @@ TEST_CASE("parse generation_config.json all set", "[generation_config_t]")
     REQUIRE_FALSE(generation_config.stop_words.empty());
     REQUIRE(generation_config.stop_words.size() == config_j["/eos_token_id"_json_pointer].size());

-    for (auto [lhs, rhs] : std::views::zip(generation_config.stop_words, std::list<std::vector<int32_t>>{{1}, {2}, {3}}))
-    {
+    for (auto [lhs, rhs]: std::views::zip(generation_config.stop_words, std::list<std::vector<int32_t>>{{1},
+                                                                                                         {2},
+                                                                                                         {3}})) {
         // Currently we do not support multi-tokens stop words
         REQUIRE(lhs.size() == 1);
         REQUIRE(rhs.size() == 1);
@@ -44,8 +45,9 @@ TEST_CASE("parse generation_config.json default", "[generation_config_t]")
     REQUIRE_FALSE(generation_config.stop_words.empty());
     REQUIRE(generation_config.stop_words.size() == config_j["/eos_token_id"_json_pointer].size());

-    for (auto [lhs, rhs] : std::views::zip(generation_config.stop_words, std::list<std::vector<int32_t>>{{1}, {2}, {3}}))
-    {
+    for (auto [lhs, rhs]: std::views::zip(generation_config.stop_words, std::list<std::vector<int32_t>>{{1},
+                                                                                                         {2},
+                                                                                                         {3}})) {
         // Currently we do not support multi-tokens stop words
         REQUIRE(lhs.size() == 1);
         REQUIRE(rhs.size() == 1);
@@ -108,6 +108,10 @@ impl Backend for BackendV2 {
     fn start_health(&self) -> bool {
         true
     }
+
+    fn name(&self) -> &'static str {
+        "tgi-v2"
+    }
 }

 /// Batching logic
@@ -213,8 +213,7 @@ impl State {
         }

         // Pad prefill_token_budget to be a multiple of block size
-        let prefill_token_budget =
-            ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
+        let prefill_token_budget = prefill_token_budget.div_ceil(self.block_size) * self.block_size;

         // Create span for this batch to add context to inference calls
         let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
@@ -245,9 +244,8 @@ impl State {
                 prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
             } else {
                 // pad to block size
-                prefill_tokens += ((entry.request.input_length + self.block_size - 1)
-                    / self.block_size)
-                    * self.block_size;
+                prefill_tokens +=
+                    entry.request.input_length.div_ceil(self.block_size) * self.block_size;
             }

             if self.requires_padding {
@@ -262,8 +260,7 @@ impl State {
             };

             // pad to block size
-            decode_tokens +=
-                ((max_new_tokens + self.block_size - 1) / self.block_size) * self.block_size;
+            decode_tokens += max_new_tokens.div_ceil(self.block_size) * self.block_size;
         }

         if prefill_tokens > prefill_token_budget
@@ -115,6 +115,10 @@ impl Backend for BackendV3 {
     fn start_health(&self) -> bool {
         true
     }
+
+    fn name(&self) -> &'static str {
+        "tgi-v3"
+    }
 }

 /// Batching logic
@@ -165,13 +165,13 @@ impl Allocator for SimpleAllocator {
                 let (tokens, repeats) = match self.window_size {
                     None => (tokens, 1),
                     Some(window_size) => {
-                        let repeats = (tokens + window_size - 1) / window_size;
+                        let repeats = tokens.div_ceil(window_size);
                         let tokens = core::cmp::min(tokens, window_size);
                         (tokens, repeats as usize)
                     }
                 };
                 // Pad to a multiple of block size
-                let required_blocks = (tokens + self.block_size - 1) / self.block_size;
+                let required_blocks = tokens.div_ceil(self.block_size);
                 (required_blocks, repeats)
             };

@@ -257,8 +257,7 @@ impl State {
         }

         // Pad prefill_token_budget to be a multiple of block size
-        let prefill_token_budget =
-            ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
+        let prefill_token_budget = prefill_token_budget.div_ceil(self.block_size) * self.block_size;

         // Create span for this batch to add context to inference calls
         let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
@@ -103,7 +103,7 @@ impl Allocator for RadixAllocator {
         let prefix_len = blocks.len() * self.block_size as usize;
         let suffix_len = tokens - prefix_len as u32;

-        let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size;
+        let suffix_blocks = suffix_len.div_ceil(self.block_size);

         tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}");
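The padding changes above are mechanical rewrites of the hand-rolled ceiling division; a standalone sketch (not part of the diff, assuming Rust 1.73+ where `u32::div_ceil` is stable) checking that the two forms agree:

```rust
fn main() {
    let block_size: u32 = 16;
    for tokens in [1u32, 15, 16, 17, 250] {
        // Old pattern: round up to the next multiple of block_size by hand.
        let old = ((tokens + block_size - 1) / block_size) * block_size;
        // New pattern: ceiling division from the standard library, then scale back up.
        let new = tokens.div_ceil(block_size) * block_size;
        assert_eq!(old, new);
        println!("{tokens} tokens -> padded to {new}");
    }
}
```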
@@ -10,7 +10,7 @@
             "name": "Apache 2.0",
             "url": "https://www.apache.org/licenses/LICENSE-2.0"
         },
-        "version": "3.0.2-dev0"
+        "version": "3.1.1-dev0"
     },
     "paths": {
         "/": {
@@ -13,6 +13,8 @@
     title: Using TGI with Intel Gaudi
   - local: installation_inferentia
     title: Using TGI with AWS Inferentia
+  - local: installation_tpu
+    title: Using TGI with Google TPUs
   - local: installation_intel
     title: Using TGI with Intel GPUs
   - local: installation
@@ -4,8 +4,13 @@ The NVIDIA TensorRT-LLM (TRTLLM) backend is a high-performance backend for LLMs
 that uses NVIDIA's TensorRT library for inference acceleration.
 It makes use of specific optimizations for NVIDIA GPUs, such as custom kernels.

-To use the TRTLLM backend you need to compile `engines` for the models you want to use.
-Each `engine` must be compiled on the same GPU architecture that you will use for inference.
+To use the TRTLLM backend **you need to compile** `engines` for the models you want to use.
+Each `engine` must be compiled for a given set of:
+- GPU architecture that you will use for inference (e.g. A100, L40, etc.)
+- Maximum batch size
+- Maximum input length
+- Maximum output length
+- Maximum beams width

 ## Supported models
@@ -19,63 +24,159 @@ want to use.

 ```bash
 MODEL_NAME="meta-llama/Llama-3.1-8B-Instruct"
+DESTINATION="/tmp/engines/$MODEL_NAME"
+HF_TOKEN="hf_xxx"
-# Install huggingface_cli
-python -m pip install huggingface-cli[hf_transfer]
-
-# Login to the Hugging Face Hub
-huggingface-cli login
-
-# Create a directory to store the model
-mkdir -p /tmp/models/$MODEL_NAME
-
-# Create a directory to store the compiled engine
-mkdir -p /tmp/engines/$MODEL_NAME
-
-# Download the model
-HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download --local-dir /tmp/models/$MODEL_NAME $MODEL_NAME

 # Compile the engine using Optimum-NVIDIA
+# This will create a compiled engine in the /tmp/engines/meta-llama/Llama-3.1-8B-Instruct
+# directory for 1 GPU
 docker run \
     --rm \
     -it \
     --gpus=1 \
-    -v /tmp/models/$MODEL_NAME:/model \
-    -v /tmp/engines/$MODEL_NAME:/engine \
-    huggingface/optimum-nvidia \
-    optimum-cli export trtllm \
+    --shm-size=1g \
+    -v "$DESTINATION":/engine \
+    -e HF_TOKEN=$HF_TOKEN \
+    -e HF_HUB_ENABLE_HF_TRANSFER=1 \
+    huggingface/optimum-nvidia:v0.1.0b9-py310 \
+    bash -c "optimum-cli export trtllm \
     --tp=1 \
     --pp=1 \
-    --max-batch-size=128 \
+    --max-batch-size=64 \
     --max-input-length 4096 \
     --max-output-length 8192 \
     --max-beams-width=1 \
-    --destination /engine \
-    $MODEL_NAME
+    --destination /tmp/engine \
+    $MODEL_NAME && cp -rL /tmp/engine/* /engine/"
 ```

-Your compiled engine will be saved in the `/tmp/engines/$MODEL_NAME` directory.
+Your compiled engine will be saved in the `/tmp/engines/$MODEL_NAME` directory, in a subfolder named after the GPU used to compile the model.

 ## Using the TRTLLM backend

 Run TGI-TRTLLM Docker image with the compiled engine:

 ```bash
+MODEL_NAME="meta-llama/Llama-3.1-8B-Instruct"
+DESTINATION="/tmp/engines/$MODEL_NAME"
+HF_TOKEN="hf_xxx"
 docker run \
     --gpus 1 \
+    --shm-size=1g \
     -it \
     --rm \
     -p 3000:3000 \
     -e MODEL=$MODEL_NAME \
     -e PORT=3000 \
-    -e HF_TOKEN='hf_XXX' \
-    -v /tmp/engines/$MODEL_NAME:/data \
+    -e HF_TOKEN=$HF_TOKEN \
+    -v "$DESTINATION"/<YOUR_GPU_ARCHITECTURE>/engines:/data \
     ghcr.io/huggingface/text-generation-inference:latest-trtllm \
-    --executor-worker executorWorker \
-    --model-id /data/$MODEL_NAME
+    --model-id /data/ \
+    --tokenizer-name $MODEL_NAME
 ```

 ## Development

-To develop TRTLLM backend, you can use [dev containers](https://containers.dev/) located in
-`.devcontainer` directory.
+To develop TRTLLM backend, you can use [dev containers](https://containers.dev/) with the following `.devcontainer.json` file:
+```json
+{
+    "name": "CUDA",
+    "build": {
+        "dockerfile": "Dockerfile_trtllm",
+        "context": ".."
+    },
+    "remoteEnv": {
+        "PATH": "${containerEnv:PATH}:/usr/local/cuda/bin",
+        "LD_LIBRARY_PATH": "$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64",
+        "XLA_FLAGS": "--xla_gpu_cuda_data_dir=/usr/local/cuda"
+    },
+    "customizations" : {
+        "jetbrains" : {
+            "backend" : "CLion"
+        }
+    }
+}
+```
+
+and `Dockerfile_trtllm`:
+
+```Dockerfile
+ARG cuda_arch_list="75-real;80-real;86-real;89-real;90-real"
+ARG build_type=release
+ARG ompi_version=4.1.7
+
+# CUDA dependent dependencies resolver stage
+FROM nvidia/cuda:12.6.3-cudnn-devel-ubuntu24.04 AS cuda-builder
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
+    build-essential \
+    cmake \
+    curl \
+    gcc-14 \
+    g++-14 \
+    git \
+    git-lfs \
+    lld \
+    libssl-dev \
+    libucx-dev \
+    libasan8 \
+    libubsan1 \
+    ninja-build \
+    pkg-config \
+    pipx \
+    python3 \
+    python3-dev \
+    python3-setuptools \
+    tar \
+    wget --no-install-recommends && \
+    pipx ensurepath
+
+ENV TGI_INSTALL_PREFIX=/usr/local/tgi
+ENV TENSORRT_INSTALL_PREFIX=/usr/local/tensorrt
+
+# Install OpenMPI
+FROM cuda-builder AS mpi-builder
+WORKDIR /opt/src/mpi
+
+ARG ompi_version
+ENV OMPI_VERSION=${ompi_version}
+ENV OMPI_TARBALL_FILENAME=openmpi-${OMPI_VERSION}.tar.bz2
+ADD --checksum=sha256:54a33cb7ad81ff0976f15a6cc8003c3922f0f3d8ceed14e1813ef3603f22cd34 \
+    https://download.open-mpi.org/release/open-mpi/v4.1/${OMPI_TARBALL_FILENAME} .
+
+RUN tar --strip-components=1 -xf ${OMPI_TARBALL_FILENAME} &&\
+    ./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --with-slurm && \
+    make -j all && \
+    make install && \
+    rm -rf ${OMPI_TARBALL_FILENAME}/..
+
+# Install TensorRT
+FROM cuda-builder AS trt-builder
+COPY backends/trtllm/scripts/install_tensorrt.sh /opt/install_tensorrt.sh
+RUN chmod +x /opt/install_tensorrt.sh && \
+    /opt/install_tensorrt.sh
+
+# Build Backend
+FROM cuda-builder AS tgi-builder
+WORKDIR /usr/src/text-generation-inference
+
+# Scoped global args reuse
+ARG cuda_arch_list
+ARG build_type
+ARG sccache_gha_enabled
+ARG actions_cache_url
+ARG actions_runtime_token
+
+# Install Rust
+ENV PATH="/root/.cargo/bin:$PATH"
+RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y && \
+    chmod -R a+w /root/.rustup && \
+    chmod -R a+w /root/.cargo && \
+    cargo install sccache --locked
+
+ENV LD_LIBRARY_PATH="/usr/local/mpi/lib:$LD_LIBRARY_PATH"
+ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig"
+ENV CMAKE_PREFIX_PATH="/usr/local/mpi:/usr/local/tensorrt"
+
+ENV USE_LLD_LINKER=ON
+ENV CUDA_ARCH_LIST=${cuda_arch_list}
+```
@@ -19,6 +19,6 @@ docker run --gpus all \
     --shm-size 1g \
     -e HF_TOKEN=$token \
     -p 8080:80 \
-    -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.0.1 \
+    -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.1.0 \
     --model-id $model
 ```
@@ -19,7 +19,7 @@ bitsandbytes is a library used to apply 8-bit and 4-bit quantization to models.
 In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below 👇

 ```bash
-docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.0.1 --model-id $model --quantize bitsandbytes
+docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.1.0 --model-id $model --quantize bitsandbytes
 ```

 4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load.
@@ -27,7 +27,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingf
 In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below 👇

 ```bash
-docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.0.1 --model-id $model --quantize bitsandbytes-nf4
+docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.1.0 --model-id $model --quantize bitsandbytes-nf4
 ```

 You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes).
@@ -48,7 +48,7 @@ $$({\hat{W}_{l}}^{*} = argmin_{\hat{W_{l}}} ||W_{l}X-\hat{W}_{l}X||^{2}_{2})$$
 TGI allows you to both run an already GPTQ quantized model (see available models [here](https://huggingface.co/models?search=gptq)) or quantize a model of your choice using quantization script. You can run a quantized model by simply passing --quantize like below 👇

 ```bash
-docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.0.1 --model-id $model --quantize gptq
+docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.1.0 --model-id $model --quantize gptq
 ```

 Note that TGI's GPTQ implementation doesn't use [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) under the hood. However, models quantized using AutoGPTQ or Optimum can still be served by TGI.
@@ -27,7 +27,7 @@ You can check a few existing fine-tunes for popular models:
 - [text-generation-inference/Mistral-7B-Instruct-v0.2-medusa](https://huggingface.co/text-generation-inference/Mistral-7B-Instruct-v0.2-medusa)

-In order to create your own medusa heads for your own finetune, you should check own the original medusa repo. [../basic_tutorials/train_medusa.md](../basic_tutorials/train_medusa.md)
+In order to create your own medusa heads for your own finetune, you should check own the original medusa repo. Read for more in [Train Medusa](../basic_tutorials/train_medusa#training).

 In order to use medusa models in TGI, simply point to a medusa enabled model, and everything will load automatically.
@@ -11,7 +11,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
 docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
     --device=/dev/kfd --device=/dev/dri --group-add video \
     --ipc=host --shm-size 256g --net host -v $volume:/data \
-    ghcr.io/huggingface/text-generation-inference:3.0.1-rocm \
+    ghcr.io/huggingface/text-generation-inference:3.1.0-rocm \
     --model-id $model
 ```

@@ -12,7 +12,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
 docker run --rm --privileged --cap-add=sys_nice \
     --device=/dev/dri \
     --ipc=host --shm-size 1g --net host -v $volume:/data \
-    ghcr.io/huggingface/text-generation-inference:3.0.1-intel-xpu \
+    ghcr.io/huggingface/text-generation-inference:3.1.0-intel-xpu \
     --model-id $model --cuda-graphs 0
 ```

@@ -29,7 +29,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
 docker run --rm --privileged --cap-add=sys_nice \
     --device=/dev/dri \
     --ipc=host --shm-size 1g --net host -v $volume:/data \
-    ghcr.io/huggingface/text-generation-inference:3.0.1-intel-cpu \
+    ghcr.io/huggingface/text-generation-inference:3.1.0-intel-cpu \
     --model-id $model --cuda-graphs 0
 ```

@@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
 volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
 
 docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \
-ghcr.io/huggingface/text-generation-inference:3.0.1 \
+ghcr.io/huggingface/text-generation-inference:3.1.0 \
 --model-id $model
 ```
docs/source/installation_tpu.md (new file, 3 lines)
@@ -0,0 +1,3 @@
+# Using TGI with Google TPUs
+
+Check out this [guide](https://huggingface.co/docs/optimum-tpu) on how to serve models with TGI on TPUs.
@@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
 volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
 
 docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
-ghcr.io/huggingface/text-generation-inference:3.0.1 \
+ghcr.io/huggingface/text-generation-inference:3.1.0 \
 --model-id $model
 ```
@@ -96,7 +96,7 @@ curl 127.0.0.1:8080/generate \
 To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.
 
 ```bash
-docker run ghcr.io/huggingface/text-generation-inference:3.0.1 --help
+docker run ghcr.io/huggingface/text-generation-inference:3.1.0 --help
 ```
 
 </Tip>
@@ -163,7 +163,7 @@ hub = {
 
 # create Hugging Face Model Class
 huggingface_model = HuggingFaceModel(
-    image_uri=get_huggingface_llm_image_uri("huggingface",version="3.0.1"),
+    image_uri=get_huggingface_llm_image_uri("huggingface",version="3.1.0"),
     env=hub,
     role=role,
 )
@@ -4,6 +4,7 @@
 Text Generation Inference enables serving optimized models. The following sections list which models (VLMs & LLMs) are supported.
 
 - [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2)
+- [Deepseek V3](https://huggingface.co/deepseek-ai/DeepSeek-V3)
 - [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal)
 - [Idefics 3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3) (Multimodal)
 - [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)
@@ -3,7 +3,7 @@
 
 Text Generation Inference collects anonymous usage statistics to help us improve the service. The collected data is used to improve TGI and to understand what causes failures. The data is collected transparently and any sensitive information is omitted.
 
-Data is sent twice, once on server startup and once when server stops. Also, usage statistics are only enabled when TGI is running in docker to avoid collecting data then TGI runs directly on the host machine.
+Usage statistics are collected only when TGI is running in a Docker container. This prevents data collection when TGI is run directly on the host machine. The collected data includes startup and shutdown events, as well as a heartbeat signal sent every 15 minutes.
 
 ## What data is collected
flake.lock (30 lines changed)
@@ -108,11 +108,11 @@
       "pre-commit-hooks": "pre-commit-hooks_3"
     },
     "locked": {
-      "lastModified": 1732039290,
-      "narHash": "sha256-LQKY7bShf2H9kJouxa9ZspfdrulnZF9o4kLTqGqCDYM=",
+      "lastModified": 1734429562,
+      "narHash": "sha256-V2XNs3Ir8WXNHdocfzkR/fu0FzkZ9uTDJkVecxJrGmQ=",
       "owner": "nix-community",
       "repo": "crate2nix",
-      "rev": "9ff208ce7f5a482272b1bcefbe363c772d7ff914",
+      "rev": "8537c2d7cb623679aaeff62c4c4c43a91566ab09",
       "type": "github"
     },
     "original": {
@@ -305,11 +305,11 @@
     },
     "flake-compat_4": {
       "locked": {
-        "lastModified": 1696426674,
-        "narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=",
+        "lastModified": 1733328505,
+        "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
         "owner": "edolstra",
         "repo": "flake-compat",
-        "rev": "0f9255e01c2351cc7d116c072cb317785dd33b33",
+        "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
         "type": "github"
       },
       "original": {
@@ -718,11 +718,11 @@
     },
     "nixpkgs_6": {
       "locked": {
-        "lastModified": 1732034459,
-        "narHash": "sha256-Zais/zMRuJdlALidkUgEuasXOd37ZZLqkPkF9bIYSrY=",
+        "lastModified": 1737453259,
+        "narHash": "sha256-5LaFI9SQwCZmJDasMoYMdzNouWXNk3BvjKcO19tq1Rs=",
        "owner": "danieldk",
        "repo": "nixpkgs",
-        "rev": "40280e7bf9743cdf563494db4ece2a43aa674fa8",
+        "rev": "e0372dbcfd19ddd783b7c3b3868f19322f83318e",
        "type": "github"
      },
      "original": {
@@ -853,11 +853,11 @@
       ]
     },
     "locked": {
-      "lastModified": 1732242723,
-      "narHash": "sha256-NWI8csIK0ujFlFuEXKnoc+7hWoCiEtINK9r48LUUMeU=",
+      "lastModified": 1737685583,
+      "narHash": "sha256-p+NVABRpGi+pT+xxf9HcLcFVxG6L+vEEy+NwzB9T0f8=",
      "owner": "oxalica",
      "repo": "rust-overlay",
-      "rev": "a229311fcb45b88a95fdfa5cecd8349c809a272a",
+      "rev": "eb64cbcc8eee0fa87ebded92805280d2ec97415a",
      "type": "github"
     },
     "original": {
@@ -978,11 +978,11 @@
       "nixpkgs": "nixpkgs_6"
     },
     "locked": {
-      "lastModified": 1736436388,
-      "narHash": "sha256-CIyxVPpM9RrSwthNT/4DQ10YPk/uwzP7AeE83kBNsrE=",
+      "lastModified": 1738323634,
+      "narHash": "sha256-lKPzgEm7pEuQJVhacsxFHqg1MOtrUMZvr+9IuJzC5J4=",
       "owner": "huggingface",
       "repo": "text-generation-inference-nix",
-      "rev": "5103c3fb1f9ad1fd33b6e09ff05e957884b112d5",
+      "rev": "eb5fede2756f544f75e01f55a4097f9c9a8c5005",
       "type": "github"
     },
     "original": {
@@ -562,6 +562,7 @@ def launcher(event_loop):
                 docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]])
             ]
 
+        client.api.timeout = 1000
         container = client.containers.run(
             DOCKER_IMAGE,
             command=args,
@@ -573,7 +574,7 @@ def launcher(event_loop):
             devices=devices,
             volumes=volumes,
             ports={"80/tcp": port},
-            healthcheck={"timeout": int(60 * 1e9), "retries": 2},  # 60s
+            healthcheck={"timeout": int(180 * 1e9), "retries": 2},  # 60s
            shm_size="1G",
         )
@@ -0,0 +1,73 @@
+{
+  "details": {
+    "best_of_sequences": null,
+    "finish_reason": "length",
+    "generated_tokens": 10,
+    "prefill": [],
+    "seed": null,
+    "tokens": [
+      {"id": 2284, "logprob": -0.9355469, "special": false, "text": "():"},
+      {"id": 303, "logprob": -0.40795898, "special": false, "text": "\n "},
+      {"id": 1489, "logprob": -0.27954102, "special": false, "text": " print"},
+      {"id": 459, "logprob": -0.6142578, "special": false, "text": "(\""},
+      {"id": 8302, "logprob": -0.68310547, "special": false, "text": "Hello"},
+      {"id": 10914, "logprob": -1.4599609, "special": false, "text": " World"},
+      {"id": 16013, "logprob": -0.80126953, "special": false, "text": "!\")"},
+      {"id": 222, "logprob": -0.625, "special": false, "text": "\n"},
+      {"id": 222, "logprob": -0.23242188, "special": false, "text": "\n"},
+      {"id": 610, "logprob": -1.2294922, "special": false, "text": "def"}
+    ],
+    "top_tokens": null
+  },
+  "generated_text": "():\n print(\"Hello World!\")\n\ndef"
+}
@@ -0,0 +1,373 @@
+{
+  "details": {
+    "best_of_sequences": null,
+    "finish_reason": "length",
+    "generated_tokens": 60,
+    "prefill": [],
+    "seed": 0,
+    "tokens": [ ... 60 token objects, each with "id", "logprob", "special" and "text" fields ... ],
+    "top_tokens": null
+  },
+  "generated_text": "\n\n# + [markdown] slideshow={\"slide_type\": \"slide\"}\n# # 2. What is the difference between a list and a tuple?\n#\n# - A list is a mutable sequence of elements, while a tuple is an immutable sequence of elements.\n# -"
+}
@@ -0,0 +1,294 @@
+[
+  {
+    "details": {
+      "best_of_sequences": null,
+      "finish_reason": "length",
+      "generated_tokens": 10,
+      "prefill": [],
+      "seed": null,
+      "tokens": [
+        {"id": 222, "logprob": -1.9091797, "special": false, "text": "\n"},
+        {"id": 222, "logprob": -1.0478516, "special": false, "text": "\n"},
+        {"id": 40, "logprob": -3.015625, "special": false, "text": "#"},
+        {"id": 494, "logprob": -1.4228516, "special": false, "text": " +"},
+        {"id": 447, "logprob": -1.1025391, "special": false, "text": " ["},
+        {"id": 9009, "logprob": -0.0008444786, "special": false, "text": "markdown"},
+        {"id": 98, "logprob": -8.8095665e-05, "special": false, "text": "]"},
+        {"id": 37402, "logprob": -0.5810547, "special": false, "text": " slideshow"},
+        {"id": 8492, "logprob": -0.00022864342, "special": false, "text": "={\""},
+        {"id": 7277, "logprob": -0.00030994415, "special": false, "text": "slide"}
+      ],
+      "top_tokens": null
+    },
+    "generated_text": "\n\n# + [markdown] slideshow={\"slide"
+  },
+  ... (three more entries identical to the first) ...
+]
@@ -0,0 +1,71 @@
+{
+  "details": {
+    "finish_reason": "length",
+    "generated_tokens": 10,
+    "prefill": [],
+    "seed": null,
+    "tokens": [
+      {"id": 100, "logprob": -0.9824219, "special": false, "text": "_"},
+      {"id": 5879, "logprob": -0.3017578, "special": false, "text": "world"},
+      {"id": 2284, "logprob": -0.68652344, "special": false, "text": "():"},
+      {"id": 303, "logprob": -0.27734375, "special": false, "text": "\n "},
+      {"id": 1489, "logprob": -0.4482422, "special": false, "text": " print"},
+      {"id": 459, "logprob": -0.54248047, "special": false, "text": "(\""},
+      {"id": 8302, "logprob": -0.4296875, "special": false, "text": "Hello"},
+      {"id": 10914, "logprob": -0.8544922, "special": false, "text": " World"},
+      {"id": 16013, "logprob": -0.7573242, "special": false, "text": "!\")"},
+      {"id": 222, "logprob": -0.81347656, "special": false, "text": "\n"}
+    ]
+  },
+  "generated_text": "_world():\n print(\"Hello World!\")\n"
+}
integration-tests/models/test_flash_starcoder2_lora.py (new file, 79 lines)
@@ -0,0 +1,79 @@
+import pytest
+import requests
+
+
+@pytest.fixture(scope="module")
+def flash_starcoder2_handle(launcher):
+    with launcher(
+        "bigcode/starcoder2-3b", lora_adapters=["smangrul/starcoder-3b-hugcoder"]
+    ) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def flash_starcoder2(flash_starcoder2_handle):
+    await flash_starcoder2_handle.health(300)
+    return flash_starcoder2_handle.client
+
+
+@pytest.mark.asyncio
+async def test_flash_starcoder2(flash_starcoder2, response_snapshot):
+    response = await flash_starcoder2.generate(
+        "def print_hello", max_new_tokens=10, decoder_input_details=True
+    )
+
+    assert response.details.generated_tokens == 10
+    assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapshot):
+    response = await flash_starcoder2.generate(
+        "who are you?",
+        max_new_tokens=60,
+        temperature=0.2,
+        top_p=0.95,
+        decoder_input_details=True,
+        seed=0,
+    )
+
+    assert response.details.generated_tokens == 60
+    assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+async def test_flash_starcoder2_load(
+    flash_starcoder2, generate_load, response_snapshot
+):
+    responses = await generate_load(
+        flash_starcoder2, "who are you?", max_new_tokens=10, n=4
+    )
+
+    assert len(responses) == 4
+    assert all([r.generated_text == responses[0].generated_text for r in responses])
+
+    assert responses == response_snapshot
+
+
+@pytest.mark.asyncio
+async def test_flash_starcoder2_with_hugcode_adapter(
+    flash_starcoder2, response_snapshot
+):
+    response = requests.post(
+        f"{flash_starcoder2.base_url}/generate",
+        headers=flash_starcoder2.headers,
+        json={
+            "inputs": "def print_hello",
+            "parameters": {
+                "max_new_tokens": 10,
+                "adapter_id": "smangrul/starcoder-3b-hugcoder",
+                "details": True,
+            },
+        },
+    )
+
+    assert response.status_code == 200
+    data = response.json()
+    assert data["generated_text"] == '_world():\n print("Hello World!")\n'
+
+    assert data == response_snapshot
@@ -25,21 +25,23 @@ async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snap
     assert response == generous_response_snapshot
 
 
-@pytest.mark.release
-@pytest.mark.asyncio
-async def test_flash_starcoder_gptq_default_params(
-    flash_starcoder_gptq, generous_response_snapshot
-):
-    response = await flash_starcoder_gptq.generate(
-        "def geometric_mean(L: List[float]):",
-        max_new_tokens=20,
-        temperature=0.2,
-        top_p=0.95,
-        decoder_input_details=True,
-        seed=0,
-    )
-    assert response.details.generated_tokens == 2
-    assert response == generous_response_snapshot
+# Deactivated because it's flaky
+# Only this model seems affected and it's only a logprob precision issue.
+# @pytest.mark.release
+# @pytest.mark.asyncio
+# async def test_flash_starcoder_gptq_default_params(
+#     flash_starcoder_gptq, generous_response_snapshot
+# ):
+#     response = await flash_starcoder_gptq.generate(
+#         "def geometric_mean(L: List[float]):",
+#         max_new_tokens=20,
+#         temperature=0.2,
+#         top_p=0.95,
+#         decoder_input_details=True,
+#         seed=0,
+#     )
+#     assert response.details.generated_tokens == 2
+#     assert response == generous_response_snapshot
 
 
 @pytest.mark.release
@@ -5,7 +5,6 @@ use hf_hub::{
 };
 use nix::sys::signal::{self, Signal};
 use nix::unistd::Pid;
-use regex::Regex;
 use serde::Deserialize;
 use std::env;
 use std::ffi::OsString;
@@ -144,7 +143,9 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
         }
     }
 
-    let fallback_attention = if matches!(compute_capability, Some((major, _)) if major < 8) {
+    let fallback_attention = if compute_capability.is_none()
+        || matches!(compute_capability, Some((major, _)) if major < 8)
+    {
         "paged"
     } else {
         "flashdecoding"
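
The effect of the hunk above is that an undetectable compute capability now also falls back to paged attention instead of flash-decoding. A minimal, standalone sketch of that selection rule (the function name and test values below are illustrative, not the launcher's own API):

```rust
/// Pick an attention implementation from an optional (major, minor) CUDA
/// compute capability: unknown hardware or pre-Ampere (major < 8) cards
/// get "paged", everything else gets "flashdecoding".
fn fallback_attention(compute_capability: Option<(usize, usize)>) -> &'static str {
    if compute_capability.is_none()
        || matches!(compute_capability, Some((major, _)) if major < 8)
    {
        "paged"
    } else {
        "flashdecoding"
    }
}

fn main() {
    assert_eq!(fallback_attention(None), "paged");
    assert_eq!(fallback_attention(Some((7, 5))), "paged"); // e.g. T4
    assert_eq!(fallback_attention(Some((8, 0))), "flashdecoding"); // e.g. A100
    assert_eq!(fallback_attention(Some((9, 0))), "flashdecoding"); // e.g. H100
    println!("ok");
}
```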
@@ -1631,8 +1632,10 @@ enum Gpu {
     L40,
     L40S,
     A10G,
+    A40,
     H100,
     A100,
+    H200,
     Unknown(String),
 }
@@ -1651,6 +1654,7 @@ impl From<&str> for Gpu {
             "nvidia-l40" => Gpu::L40,
             "nvidia-l40s" => Gpu::L40S,
             "nvidia-a10g" => Gpu::A10G,
+            "nvidia-a40" => Gpu::A40,
             "nvidia-h100-80gb-hbm3" => Gpu::H100,
             "nvidia-h100-nvl" => Gpu::H100,
             "nvidia-h100" => Gpu::H100,
@@ -1658,6 +1662,7 @@ impl From<&str> for Gpu {
             "nvidia-a100-sxm4-40gb" => Gpu::A100,
             "nvidia-a100-80gb-pcie" => Gpu::A100,
             "nvidia-a100" => Gpu::A100,
+            "nvidia-h200" => Gpu::H200,
             card => Gpu::Unknown(card.to_string()),
         }
     }
@@ -1672,8 +1677,10 @@ impl std::fmt::Display for Gpu {
             Gpu::L40 => write!(f, "nvida-l40"),
             Gpu::L40S => write!(f, "nvida-l40s"),
             Gpu::A10G => write!(f, "nvidia-a10g"),
+            Gpu::A40 => write!(f, "nvidia-a40"),
             Gpu::H100 => write!(f, "nvidia-h100-80fb-hbm3"),
             Gpu::A100 => write!(f, "nvida-a100-sxm4-80gb"),
+            Gpu::H200 => write!(f, "nvida-h200"),
             Gpu::Unknown(card) => write!(f, "{}", card),
         }
     }
@@ -1695,11 +1702,16 @@ impl ComputeType {
             Gpu::L40S => Some(363 * 10u64.pow(12)),
             // https://www.nvidia.com/en-us/data-center/products/a10-gpu/
             Gpu::A10G => Some(125 * 10u64.pow(12)),
+            // https://www.nvidia.com/en-us/data-center/a40/
+            // https://images.nvidia.com/content/Solutions/data-center/a40/nvidia-a40-datasheet.pdf
+            Gpu::A40 => Some(149 * 10u64.pow(12)),
+            // https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
+            Gpu::A100 => Some(312 * 10u64.pow(12)),
             // https://www.nvidia.com/en-us/data-center/h100/
             // https://www.techpowerup.com/gpu-specs/docs/nvidia-gh100-architecture.pdf
             Gpu::H100 => Some(900 * 10u64.pow(12)),
-            // https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
-            Gpu::A100 => Some(312 * 10u64.pow(12)),
+            // https://www.nvidia.com/en-us/data-center/h200/
+            Gpu::H200 => Some(989 * 10u64.pow(12)),
             Gpu::Unknown(card) => {
                 tracing::warn!("Unkown compute for card {card}");
                 None
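
The hunks above add A40 and H200 to the GPU table together with their datasheet peak-throughput numbers. The sketch below condenses that mapping into a self-contained program; the names mirror the diff, but the code is a simplified stand-in for the launcher's actual types and only includes a subset of the recognised card-name strings:

```rust
#[derive(Debug, PartialEq)]
enum Gpu {
    A40,
    A100,
    H100,
    H200,
    Unknown(String),
}

impl From<&str> for Gpu {
    fn from(value: &str) -> Self {
        match value {
            "nvidia-a40" => Gpu::A40,
            "nvidia-a100-sxm4-80gb" | "nvidia-a100" => Gpu::A100,
            "nvidia-h100-80gb-hbm3" | "nvidia-h100" => Gpu::H100,
            "nvidia-h200" => Gpu::H200,
            card => Gpu::Unknown(card.to_string()),
        }
    }
}

/// Peak throughput used for sizing heuristics, matching the datasheet
/// numbers referenced in the diff (TFLOPS * 10^12).
fn peak_flops(gpu: &Gpu) -> Option<u64> {
    match gpu {
        Gpu::A40 => Some(149 * 10u64.pow(12)),
        Gpu::A100 => Some(312 * 10u64.pow(12)),
        Gpu::H100 => Some(900 * 10u64.pow(12)),
        Gpu::H200 => Some(989 * 10u64.pow(12)),
        Gpu::Unknown(_) => None,
    }
}

fn main() {
    let gpu = Gpu::from("nvidia-h200");
    println!("{:?}: {:?} FLOP/s", gpu, peak_flops(&gpu));
}
```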
@@ -2079,14 +2091,7 @@ fn main() -> Result<(), LauncherError> {
     let cuda_graphs = match (&args.cuda_graphs, &quantize) {
         (Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(),
         #[allow(deprecated)]
-        (
-            None,
-            Some(
-                Quantization::Bitsandbytes
-                | Quantization::BitsandbytesNf4
-                | Quantization::BitsandbytesFp4,
-            ),
-        ) => {
+        (None, Some(Quantization::Bitsandbytes)) => {
             tracing::warn!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
             vec![]
         }
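
This hunk narrows the CUDA-graph opt-out so that only plain `bitsandbytes` quantization disables graph capture. A runnable sketch of that decision follows; the default ladder of batch sizes used here is a made-up placeholder (the real default lives elsewhere in the launcher and is not shown in this diff), and the enum is trimmed to two variants:

```rust
#[derive(Debug, PartialEq)]
enum Quantization {
    Bitsandbytes,
    Gptq,
}

/// Decide which CUDA graph batch sizes to capture: explicit sizes win
/// (zeros removed), plain bitsandbytes disables graphs, everything else
/// falls back to a default ladder.
fn cuda_graphs(user: &Option<Vec<usize>>, quantize: &Option<Quantization>) -> Vec<usize> {
    match (user, quantize) {
        (Some(sizes), _) => sizes.iter().cloned().filter(|&c| c > 0).collect(),
        (None, Some(Quantization::Bitsandbytes)) => {
            eprintln!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
            vec![]
        }
        _ => vec![1, 2, 4, 8, 16, 32], // placeholder default, not the launcher's real list
    }
}

fn main() {
    assert_eq!(cuda_graphs(&Some(vec![0, 8, 16]), &None), vec![8, 16]);
    assert!(cuda_graphs(&None, &Some(Quantization::Bitsandbytes)).is_empty());
    assert_eq!(cuda_graphs(&None, &Some(Quantization::Gptq)), vec![1, 2, 4, 8, 16, 32]);
}
```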
@@ -2176,11 +2181,12 @@ fn main() -> Result<(), LauncherError> {
             }
 
             // capture adapter_id, path, revision in format of adapter_id=path@revision
-            let re = Regex::new(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$").unwrap();
-            if let Some(caps) = re.captures(adapter) {
-                let adapter_id = caps.get(1).map_or("", |m| m.as_str());
-                let revision = caps.get(3).map(|m| m.as_str());
+            // path is disabled beforehand.
+            let mut splits = adapter.split("@");
+            let adapter_id = splits.next().ok_or_else(|| {
+                LauncherError::ArgumentValidation("Missing adapter id".to_string())
+            })?;
+            let revision = splits.next();
             download_convert_model(
                 adapter_id,
                 revision,
@@ -2190,12 +2196,6 @@ fn main() -> Result<(), LauncherError> {
                 running.clone(),
                 false, // avoid merging lora adapters if using multi-lora
             )?;
-            } else {
-                return Err(LauncherError::ArgumentValidation(format!(
-                    "Invalid LoRA adapter format: {}",
-                    adapter
-                )));
-            }
         }
     }
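
The regex-based parsing of `--lora-adapters` entries is replaced by a plain split on `@`. A small sketch of the same idea; the function name is hypothetical and the empty-id guard is an extra check added here for illustration (the launcher only errors when no item is produced at all):

```rust
/// Parse a LoRA adapter spec of the form `adapter_id[@revision]`,
/// mirroring the split-based parsing that replaced the regex.
fn parse_adapter(adapter: &str) -> Result<(&str, Option<&str>), String> {
    let mut splits = adapter.trim().split('@');
    let adapter_id = splits
        .next()
        .filter(|s| !s.is_empty()) // extra guard, not in the original
        .ok_or_else(|| "Missing adapter id".to_string())?;
    let revision = splits.next();
    Ok((adapter_id, revision))
}

fn main() {
    assert_eq!(
        parse_adapter("smangrul/starcoder-3b-hugcoder"),
        Ok(("smangrul/starcoder-3b-hugcoder", None))
    );
    assert_eq!(
        parse_adapter("smangrul/starcoder-3b-hugcoder@main"),
        Ok(("smangrul/starcoder-3b-hugcoder", Some("main")))
    );
    println!("ok");
}
```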
@@ -224,6 +224,8 @@ pub enum Config {
     Qwen2,
     Opt,
     T5,
+    DeepseekV2,
+    DeepseekV3,
 }
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
@@ -40,6 +40,8 @@ pub trait Backend {
     fn start_health(&self) -> bool {
         false
     }
+
+    fn name(&self) -> &'static str;
 }
 
 /// Inference struct
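
Backends now have to expose a static name, which the router forwards into usage statistics. A stripped-down sketch of the trait shape, with a mock implementation standing in for a real backend:

```rust
/// Trimmed-down backend abstraction: implementations inherit a default
/// `start_health`, but must report a static `name`.
trait Backend {
    fn start_health(&self) -> bool {
        false
    }
    fn name(&self) -> &'static str;
}

struct MockBackend;

impl Backend for MockBackend {
    fn name(&self) -> &'static str {
        "mock"
    }
}

fn main() {
    let backend = MockBackend;
    println!("backend = {}, health = {}", backend.name(), backend.start_health());
}
```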
@@ -79,7 +79,7 @@ impl TokenizerTrait for tokenizers::Tokenizer {
     }
 }
 
-impl<'a> TokenizerTrait for PyTokenizer<'a> {
+impl TokenizerTrait for PyTokenizer<'_> {
     fn encode_trait(
         &self,
         query: String,
@@ -460,7 +460,7 @@ pub struct CompletionRequest {
     pub prompt: Prompt,
 
     /// The maximum number of tokens that can be generated in the chat completion.
-    #[serde(default, alias = "max_completion_tokens")]
+    #[serde(default)]
     #[schema(default = "1024", example = "32")]
     pub max_tokens: Option<u32>,
 
@@ -840,7 +840,7 @@ pub(crate) struct ChatRequest {
     pub top_logprobs: Option<u32>,
 
     /// The maximum number of tokens that can be generated in the chat completion.
-    #[serde(default)]
+    #[serde(default, alias = "max_completion_tokens")]
     #[schema(default = "1024", example = "32")]
     pub max_tokens: Option<u32>,
 
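
Together with the previous hunk, this moves the `max_completion_tokens` alias from the legacy completion request to the chat request. A minimal sketch of how the serde alias behaves, assuming `serde` (with the `derive` feature) and `serde_json` as dependencies; the structs below contain only the field relevant to the alias:

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct ChatRequest {
    // Accepts both "max_tokens" and "max_completion_tokens".
    #[serde(default, alias = "max_completion_tokens")]
    max_tokens: Option<u32>,
}

#[derive(Debug, Deserialize)]
struct CompletionRequest {
    // No alias here anymore.
    #[serde(default)]
    max_tokens: Option<u32>,
}

fn main() {
    let chat: ChatRequest =
        serde_json::from_str(r#"{"max_completion_tokens": 32}"#).unwrap();
    assert_eq!(chat.max_tokens, Some(32));

    let completion: CompletionRequest =
        serde_json::from_str(r#"{"max_completion_tokens": 32}"#).unwrap();
    // Unknown fields are ignored by default, so the alias no longer applies here.
    assert_eq!(completion.max_tokens, None);
    println!("ok");
}
```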
@@ -54,6 +54,9 @@ use std::fs::File;
 use std::io::BufReader;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::{Path, PathBuf};
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;
+use std::time::Duration;
 use thiserror::Error;
 use tokio::select;
 use tokio::signal;
@@ -1819,9 +1822,9 @@ pub async fn run(
         HubTokenizerConfig::default()
     });
 
-    let tokenizer: Tokenizer = {
+    let tokenizer: Result<Tokenizer, WebServerError> = {
         use pyo3::prelude::*;
-        pyo3::Python::with_gil(|py| -> PyResult<()> {
+        Python::with_gil(|py| -> PyResult<()> {
             py_resolve_tokenizer(py, &tokenizer_name, revision.as_deref(), trust_remote_code)?;
             Ok(())
         })
@@ -1832,16 +1835,16 @@ pub async fn run(
             let out = legacy_tokenizer_handle(config_filename.as_ref());
             out.ok_or(err)
         })
-        .expect("We cannot load a tokenizer");
+        .map_err(|_| WebServerError::Tokenizer("Unable to load tokenizer.".to_string()))?;
         let filename = "out/tokenizer.json";
         if let Ok(tok) = tokenizers::Tokenizer::from_file(filename) {
-            Tokenizer::Rust(tok)
+            Ok(Tokenizer::Rust(tok))
         } else {
-            Tokenizer::Python {
+            Ok(Tokenizer::Python {
                 tokenizer_name: tokenizer_name.clone(),
                 revision: revision.clone(),
                 trust_remote_code,
-            }
+            })
         }
     };
 
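
Tokenizer loading no longer panics through `expect`; every failure path is mapped into the new `WebServerError::Tokenizer` variant and bubbled up with `?`. A simplified, dependency-free sketch of that shape, where string-backed `Tokenizer` variants stand in for the real `tokenizers::Tokenizer` and Python tokenizer types:

```rust
use std::fmt;

#[derive(Debug)]
enum WebServerError {
    Tokenizer(String),
}

impl fmt::Display for WebServerError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            WebServerError::Tokenizer(msg) => write!(f, "Tokenizer error: {msg}"),
        }
    }
}

#[derive(Debug)]
enum Tokenizer {
    Rust(String), // stand-in for a parsed tokenizers::Tokenizer
    Python { tokenizer_name: String },
}

/// Fallible loading: instead of `.expect("We cannot load a tokenizer")`,
/// failures become a `WebServerError::Tokenizer` the caller can handle.
fn load_tokenizer(tokenizer_name: &str, rust_file_exists: bool) -> Result<Tokenizer, WebServerError> {
    if tokenizer_name.is_empty() {
        return Err(WebServerError::Tokenizer("Unable to load tokenizer.".to_string()));
    }
    if rust_file_exists {
        Ok(Tokenizer::Rust("out/tokenizer.json".to_string()))
    } else {
        Ok(Tokenizer::Python { tokenizer_name: tokenizer_name.to_string() })
    }
}

fn main() {
    match load_tokenizer("bigcode/starcoder2-3b", false) {
        Ok(tok) => println!("loaded {tok:?}"),
        Err(e) => eprintln!("{e}"),
    }
}
```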
@@ -1895,17 +1898,34 @@ pub async fn run(
                 disable_grammar_support,
                 max_client_batch_size,
                 usage_stats_level,
+                backend.name(),
             );
             Some(usage_stats::UserAgent::new(reduced_args))
         }
         _ => None,
     };
 
-    if let Some(ref ua) = user_agent {
+    let stop_usage_thread = Arc::new(AtomicBool::new(false));
+    let stop_usage_thread_clone = stop_usage_thread.clone();
+    if let Some(ua) = user_agent.clone() {
         let start_event =
             usage_stats::UsageStatsEvent::new(ua.clone(), usage_stats::EventType::Start, None);
         tokio::spawn(async move {
+            // send start event
             start_event.send().await;
+            let mut last_report = Instant::now();
+            while !stop_usage_thread_clone.load(Ordering::Relaxed) {
+                if last_report.elapsed() > Duration::from_secs(900) {
+                    let report_event = usage_stats::UsageStatsEvent::new(
+                        ua.clone(),
+                        usage_stats::EventType::Ping,
+                        None,
+                    );
+                    report_event.send().await;
+                    last_report = Instant::now();
+                }
+                tokio::time::sleep(Duration::from_secs(1)).await;
+            }
         });
     };
     let compat_return_full_text = match &model_info.pipeline_tag {
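
The usage-stats task now loops until an `AtomicBool` is flipped at shutdown, emitting a `Ping` event every 900 seconds. The sketch below reproduces that pattern with a shortened interval so it can be run directly; it assumes `tokio` (with the `rt-multi-thread`, `macros` and `time` features) as a dependency, and `send_ping` is a stand-in for the real event sender:

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};

/// Stand-in for sending one usage-stats event over the network.
async fn send_ping() {
    println!("ping sent");
}

#[tokio::main]
async fn main() {
    let stop = Arc::new(AtomicBool::new(false));
    let stop_clone = stop.clone();

    // Background task: wake up every second and emit a ping once the
    // reporting interval has elapsed (900 s in the router, 3 s here).
    let report_interval = Duration::from_secs(3);
    let heartbeat = tokio::spawn(async move {
        let mut last_report = Instant::now();
        while !stop_clone.load(Ordering::Relaxed) {
            if last_report.elapsed() > report_interval {
                send_ping().await;
                last_report = Instant::now();
            }
            tokio::time::sleep(Duration::from_secs(1)).await;
        }
    });

    // ... serve requests here ...
    tokio::time::sleep(Duration::from_secs(7)).await;

    // On shutdown, flip the flag so the loop exits cleanly.
    stop.store(true, Ordering::Relaxed);
    heartbeat.await.unwrap();
}
```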
@@ -1926,7 +1946,7 @@ pub async fn run(
         validation_workers,
         api_key,
         config,
-        (tokenizer, tokenizer_config),
+        (tokenizer?, tokenizer_config),
         (preprocessor_config, processor_config),
         hostname,
         port,
@@ -1943,6 +1963,7 @@ pub async fn run(
     .await;
 
     if let Some(ua) = user_agent {
+        stop_usage_thread.store(true, Ordering::Relaxed);
         match result {
             Ok(_) => {
                 let stop_event = usage_stats::UsageStatsEvent::new(
@@ -2419,8 +2440,13 @@ async fn start(
         }
     } else {
         // Run server
-
-        let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
+        let listener = match tokio::net::TcpListener::bind(&addr).await {
+            Ok(listener) => listener,
+            Err(e) => {
+                tracing::error!("Failed to bind to {addr}: {e}");
+                return Err(WebServerError::Axum(Box::new(e)));
+            }
+        };
         axum::serve(listener, app)
             .with_graceful_shutdown(shutdown_signal())
             .await
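
Binding the listener no longer calls `unwrap()`; a failure is logged and returned as an error so startup fails cleanly. A small sketch of the same pattern, assuming `tokio` with the `net`, `rt-multi-thread` and `macros` features; the `Bind` error variant is a simplification of the router's `Axum(BoxError)` wrapping:

```rust
use std::net::SocketAddr;

#[derive(Debug)]
enum WebServerError {
    Bind(std::io::Error),
}

/// Bind the server socket without panicking: a failure (port already in
/// use, missing permissions, ...) is logged and returned as an error
/// instead of taking the whole process down with `unwrap()`.
async fn bind(addr: SocketAddr) -> Result<tokio::net::TcpListener, WebServerError> {
    match tokio::net::TcpListener::bind(&addr).await {
        Ok(listener) => Ok(listener),
        Err(e) => {
            eprintln!("Failed to bind to {addr}: {e}");
            Err(WebServerError::Bind(e))
        }
    }
}

#[tokio::main]
async fn main() {
    let addr: SocketAddr = "127.0.0.1:3000".parse().unwrap();
    match bind(addr).await {
        Ok(_listener) => println!("listening on {addr}"),
        Err(e) => eprintln!("startup failed: {e:?}"),
    }
}
```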
@@ -2535,4 +2561,6 @@ impl From<InferError> for Event {
 pub enum WebServerError {
     #[error("Axum error: {0}")]
     Axum(#[from] axum::BoxError),
+    #[error("Tokenizer error: {0}")]
+    Tokenizer(String),
 }
@@ -43,6 +43,7 @@ pub enum EventType {
     Start,
     Stop,
     Error,
+    Ping,
 }
 
 #[derive(Debug, Serialize)]
@@ -70,7 +71,7 @@ impl UsageStatsEvent {
             .post(TELEMETRY_URL)
             .headers(headers)
             .body(body)
-            .timeout(Duration::from_secs(5))
+            .timeout(Duration::from_secs(10))
             .send()
             .await;
     }
@@ -96,6 +97,7 @@ pub struct Args {
     disable_grammar_support: bool,
     max_client_batch_size: usize,
     usage_stats_level: UsageStatsLevel,
+    backend_name: &'static str,
 }
 
 impl Args {
@@ -119,6 +121,7 @@ impl Args {
         disable_grammar_support: bool,
         max_client_batch_size: usize,
         usage_stats_level: UsageStatsLevel,
+        backend_name: &'static str,
     ) -> Self {
         Self {
             model_config,
@@ -139,6 +142,7 @@ impl Args {
             disable_grammar_support,
             max_client_batch_size,
             usage_stats_level,
+            backend_name,
         }
     }
 }
@@ -1229,12 +1229,11 @@ mod tests {
         assert!(
             chunks
                 == vec![
-                    Chunk::Text("test".to_string()).into(),
+                    Chunk::Text("test".to_string()),
                     Chunk::Image(Image {
                         data: pixel_data.clone(),
                         mimetype: "image/gif".to_string()
                     })
-                    .into()
                 ],
             "Failed to process images",
         );
@@ -1289,17 +1288,15 @@ mod tests {
         assert!(
             chunks
                 == vec![
-                    Chunk::Text("test".to_string()).into(),
+                    Chunk::Text("test".to_string()),
+                    Chunk::Image(Image {
+                        data: pixel_data.clone(),
+                        mimetype: "image/gif".to_string()
+                    }),
                     Chunk::Image(Image {
                         data: pixel_data.clone(),
                         mimetype: "image/gif".to_string()
                     })
-                    .into(),
-                    Chunk::Image(Image {
-                        data: pixel_data.clone(),
-                        mimetype: "image/gif".to_string()
-                    })
-                    .into()
                 ],
             "Failed to process images",
         );
@@ -1,5 +1,5 @@
 [toolchain]
 # Released on: June 13, 2024
 # https://releases.rs/docs/1.79.0/
-channel = "1.80.1"
+channel = "1.84.0"
 components = ["rustfmt", "clippy"]
@@ -9,11 +9,14 @@ include Makefile-exllamav2
 include Makefile-flashinfer
 
 unit-tests:
+	pip install -U pip uv
+	uv pip install -e ".[dev]"
 	pytest -s -vv -m "not private" tests
 
 gen-server:
 	# Compile protos
-	pip install grpcio-tools==1.62.2 mypy-protobuf==3.6.0 'types-protobuf' --no-cache-dir
+	pip install -U pip uv
+	uv pip install ".[gen]"
 	mkdir text_generation_server/pb || true
 	python -m grpc_tools.protoc -I../proto/v3 --python_out=text_generation_server/pb \
 		--grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb ../proto/v3/generate.proto
@@ -21,24 +24,14 @@ gen-server:
 	touch text_generation_server/pb/__init__.py
 
 install-server: gen-server
-	pip install pip --upgrade
-	pip install -r requirements_cuda.txt
-	pip install -e ".[accelerate, compressed-tensors, quantize, peft, outlines]"
+	uv pip install -e ".[accelerate, compressed-tensors, quantize, peft, outlines]"
 
 
 install: install-cuda
 	echo "Installed server"
 
 install-cuda: install-server install-flash-attention-v2-cuda install-flash-attention
-	pip install -e ".[attention,bnb,marlin,moe]"
-	pip install nvidia-nccl-cu12==2.22.3
+	uv pip install -e ".[attention,bnb,marlin,moe]"
+	uv pip install nvidia-nccl-cu12==2.22.3
 
 install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm
-
-run-dev:
-	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
-
-export-requirements:
-	poetry export -o requirements_cuda.txt --without-hashes
-	poetry export -o requirements_rocm.txt --without-hashes
-	poetry export -o requirements_intel.txt --without-hashes
@@ -1,5 +1,6 @@
 install-flashinfer:
 	# We need fsspec as an additional dependency, but
 	# `pip install flashinfer` cannot resolve it.
-	pip install fsspec
-	pip install flashinfer==0.2.0.post1 -i https://flashinfer.ai/whl/cu124/torch2.4
+	pip install fsspec sympy==1.13.1 numpy
+	pip install -U setuptools
+	TORCH_CUDA_ARCH_LIST="8.0;8.6;8.9;9.0+PTX" FLASHINFER_ENABLE_AOT=1 pip install git+https://github.com/flashinfer-ai/flashinfer.git@v0.2.0.post1#egg=flashinfer --no-build-isolation
server/poetry.lock (generated, 4100 lines changed; diff suppressed because it is too large)
@@ -1,96 +1,87 @@
-[tool.poetry]
+[project]
 name = "text-generation-server"
 version = "2.0.5-dev0"
 description = "Text Generation Inference Python gRPC Server"
-authors = ["Olivier Dehaene <olivier@huggingface.co>"]
-
-[tool.poetry.scripts]
-text-generation-server = 'text_generation_server.cli:app'
-
-[tool.poetry.dependencies]
-python = ">=3.9,<3.13"
-protobuf = ">=4.25.3,<6"
-grpcio = "^1.51.1"
-grpcio-status = "^1.51.1"
-grpcio-reflection = "^1.51.1"
-grpc-interceptor = "^0.15.4"
-typer = "^0.12.5"
-accelerate = {version = "^1.1.0", optional = true}
-bitsandbytes = { version = "^0.43.0", optional = true }
-safetensors = "^0.4.5"
-loguru = "^0.7.2"
-opentelemetry-api = "^1.27.0"
-opentelemetry-exporter-otlp = "^1.27.0"
-opentelemetry-instrumentation-grpc = "^0.48b0"
-hf-transfer = "^0.1.2"
-sentencepiece = "^0.2.0"
-tokenizers = "^0.20.3"
-huggingface-hub = "^0.23"
-transformers = "^4.46.2"
-einops = "^0.8.0"
-texttable = { version = "^1.6.7", optional = true }
-datasets = {version = "^2.21.0", optional = true}
-peft = {version = "^0.13.2", optional = true}
-torch = {version = "^2.4.1", optional = true}
-scipy = "^1.13.1"
-pillow = "^11.0.0"
-outlines= {version = "^0.1.3", optional = true}
-prometheus-client = ">=0.20.0,<0.22"
-py-cpuinfo = "^9.0.0"
-compressed-tensors = {version = "^0.7.1", optional = true}
-# Remove later, temporary workaround for outlines.
-numpy = "^1.26.4"
-
-attention-kernels = [
-  { url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
-  { url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
-  { url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
-  { url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
-]
-marlin-kernels = [
-  { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
-  { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
-  { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
-  { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
-]
-moe-kernels = [
-  { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.7.0/moe_kernels-0.7.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
-  { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.7.0/moe_kernels-0.7.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
-  { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.7.0/moe_kernels-0.7.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
-  { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.7.0/moe_kernels-0.7.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
-]
-rich = "^13.8.1"
-
-[tool.poetry.extras]
-torch = ["torch"]
-accelerate = ["accelerate"]
-attention = ["attention-kernels"]
-bnb = ["bitsandbytes"]
-compressed-tensors = ["compressed-tensors"]
-marlin = ["marlin-kernels"]
-peft = ["peft"]
-quantize = ["texttable", "datasets", "accelerate"]
-outlines = ["outlines"]
-
-[tool.poetry.group.dev.dependencies]
-grpcio-tools = "^1.51.1"
-pytest = "^7.3.0"
+readme = "README.md"
+requires-python = ">=3.9"
+authors = [
+  {name = "Olivier Dehaene", email = "olivier@huggingface.co"},
+  {name = "Nicolas Patry", email = "nicolas@huggingface.co"},
+]
+dependencies = [
+  "einops>=0.8.0",
+  "grpc-interceptor>=0.15.4",
+  "grpcio>=1.67.0",
+  "grpcio-reflection>=1.67.0",
+  "grpcio-status>=1.67.0",
+  "hf-transfer>=0.1.8",
+  "loguru>=0.7.3",
+  "numpy>=1.26,<3",
+  "opentelemetry-api>=1.27.0",
+  "opentelemetry-exporter-otlp>=1.27.0",
+  "opentelemetry-instrumentation-grpc>=0.50b0",
+  "pillow>=11.1.0",
+  "prometheus-client>=0.21.0",
+  "protobuf>=5.28.3",
+  "py-cpuinfo>=9.0.0",
+  "rich>=13.8.1",
+  "safetensors>=0.4.5",
+  "scipy>=1.13.1",
+  "sentencepiece>=0.2.0",
+  "tokenizers>=0.20.3",
+  "typer>=0.15.1",
+]
+
+[project.scripts]
+text-generation-server = "text_generation_server.cli:app"
+
+[project.optional-dependencies]
+accelerate = [
+  "accelerate>=1.2.1,<2",
+]
+bnb = [
+  "bitsandbytes>=0.45.0",
+]
+compressed-tensors = [
+  "compressed-tensors>=0.9.0",
+]
+peft = [
+  "peft>=0.14.0",
+]
+outlines = [
+  "outlines>=0.1.13",
+]
+dev = [
+  "grpcio-tools>=1.51.1,<2.0",
+  "pytest>=7.3.0,<8"
+]
+quantize = [
+  "texttable>=1.6.7,<2",
+  "datasets>=2.21,<3",
+]
 moe = [ "moe-kernels" ]
+attention = [ "attention-kernels" ]
+marlin = [ "marlin-kernels" ]
+gen = [
+  "grpcio-tools>=1.69.0",
+  "mypy-protobuf>=3.6.0",
+]
+
+[tool.uv.sources]
+attention-kernels.url = "https://github.com/danieldk/attention-kernels/releases/download/v0.2.0.post2/attention_kernels-0.2.0.post2+cu123torch2.5-cp39-abi3-linux_x86_64.whl"
+marlin-kernels = [
+  { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp39-cp39-linux_x86_64.whl", marker = "python_version == '3.9'" },
+  { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp310-cp310-linux_x86_64.whl", marker = "python_version == '3.10'" },
|
||||||
[[tool.poetry.source]]
|
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp311-cp311-linux_x86_64.whl", marker = "python_version == '3.11'" },
|
||||||
name = "pytorch-gpu-src"
|
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp312-cp312-linux_x86_64.whl", marker = "python_version == '3.12'" },
|
||||||
url = "https://download.pytorch.org/whl/cu121"
|
]
|
||||||
priority = "explicit"
|
moe-kernels.url = "https://github.com/danieldk/moe-kernels/releases/download/v0.8.2/moe_kernels-0.8.2+cu123torch2.5-cp39-abi3-linux_x86_64.whl"
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]
|
markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]
|
||||||
|
|
||||||
[build-system]
|
|
||||||
requires = [
|
|
||||||
"poetry-core>=1.0.0",
|
|
||||||
]
|
|
||||||
build-backend = "poetry.core.masonry.api"
|
|
||||||
|
|
||||||
[tool.isort]
|
[tool.isort]
|
||||||
profile = "black"
|
profile = "black"
|
||||||
|
|
||||||
|
[tool.setuptools.packages.find]
|
||||||
|
include = ["text_generation_server*"]
|
||||||
|
@@ -94,6 +94,8 @@ def test_get_mlp_weights_with_gate_up_proj():

    # assert the result
    expected = {
+        (3, "c_fc"): ("model.layers.3.mlp.c_fc", mock_layer.mlp.c_fc),
+        (3, "c_proj"): ("model.layers.3.mlp.c_proj", mock_layer.mlp.c_proj),
        (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
        (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
        (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),

@@ -188,6 +190,8 @@ def test_get_mlp_weights_llama_compatibility():
    result = get_mlp_weights(3, mock_layer)

    expected = {
+        (3, "c_fc"): ("model.layers.3.mlp.c_fc", mock_layer.mlp.c_fc),
+        (3, "c_proj"): ("model.layers.3.mlp.c_proj", mock_layer.mlp.c_proj),
        (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
        (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
        (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),

@@ -240,6 +244,8 @@ def test_get_mlp_weights_gemma_compatibility():
    result = get_mlp_weights(3, mock_layer)

    expected = {
+        (3, "c_fc"): ("model.layers.3.mlp.c_fc", mock_layer.mlp.c_fc),
+        (3, "c_proj"): ("model.layers.3.mlp.c_proj", mock_layer.mlp.c_proj),
        (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_proj),
        (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.up_proj),
        (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
@@ -6,9 +6,11 @@ from collections import defaultdict
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Set, Tuple, Type, Union

+from loguru import logger
 import torch
 from peft import LoraConfig as _LoraConfig
 from torch.distributed import ProcessGroup
+from text_generation_server.utils.log import log_master

 from text_generation_server.adapters.config import AdapterConfig, ModuleMap

@@ -203,8 +205,17 @@ class LoraWeights(AdapterWeights):
        lora_a_list = [None] * nlayers
        lora_b_list = [None] * nlayers

+        # import ipdb; ipdb.set_trace()
        for layer_id in range(nlayers):
            key = (layer_id, layer_type)
+            if key not in target_to_layer:
+                # There is no layer of this type in the model
+                log_master(
+                    logger.warning,
+                    f"Key specified in lora weights but not found in base model: {key}",
+                )
+                return None
+
            weight_name, layer = target_to_layer[key]
            base_weight = layer.base_layer.linear.weight
            base_device = base_weight.device
@@ -9,6 +9,8 @@ from enum import Enum
 from huggingface_hub import hf_hub_download
 from text_generation_server.utils.adapter import parse_lora_adapters

+# Dummy change should cache hit.
+

 app = typer.Typer()

@@ -111,6 +111,8 @@ def paged_attention(

    out = torch.empty_like(query)

+    kv_cache_dtype = "fp8" if kv_cache.dtype == torch.float8_e4m3fn else "auto"
+
    use_v1 = max_s <= 8192 and (
        max_num_partitions == 1 or num_seqs * num_heads > 512
    )

@@ -120,15 +122,16 @@ def paged_attention(
            query,
            kv_cache.key,
            kv_cache.value,
-            kv_head_mapping,
+            kv_cache.key.shape[1],
            softmax_scale,
            block_tables,
            input_lengths,
            block_size,
            max_s,
            None,
-            "auto",
-            1.0,
+            kv_cache_dtype,
+            kv_scales.key_scale_cpu,
+            kv_scales.value_scale_cpu,
        )
    else:
        # Run PagedAttention V2.

@@ -153,15 +156,16 @@ def paged_attention(
            query,
            kv_cache.key,
            kv_cache.value,
-            kv_head_mapping,
+            kv_cache.key.shape[1],
            softmax_scale,
            block_tables,
            input_lengths,
            block_size,
            max_s,
            None,
-            "auto",
-            1.0,
+            kv_cache_dtype,
+            kv_scales.key_scale_cpu,
+            kv_scales.value_scale_cpu,
        )
    return out

|
|||||||
paged_kv_cache=(kv_cache.key, kv_cache.value),
|
paged_kv_cache=(kv_cache.key, kv_cache.value),
|
||||||
logits_soft_cap=softcap,
|
logits_soft_cap=softcap,
|
||||||
sm_scale=softmax_scale,
|
sm_scale=softmax_scale,
|
||||||
window_left=window_size_left,
|
|
||||||
k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
|
k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
|
||||||
v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
|
v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
|
||||||
)
|
)
|
||||||
|
@@ -84,7 +84,7 @@ def use_prefill_with_paged_kv_state(

    token = prefill_with_paged_kv_state.set(state)
    try:
-        state.begin_forward(
+        state.plan(
            qo_indptr=cu_seqlens,
            paged_kv_indptr=indptr,
            paged_kv_indices=block_tables,

@@ -99,7 +99,6 @@ def use_prefill_with_paged_kv_state(
        )
        yield
    finally:
-        state.end_forward()
        if token is not None:
            prefill_with_paged_kv_state.reset(token)

@@ -200,7 +199,7 @@ def use_decode_state(
    token = decode_state.set(state)

    try:
-        state.begin_forward(
+        state.plan(
            indptr=indptr,
            indices=block_tables,
            last_page_len=last_page_len,

@@ -214,6 +213,5 @@ def use_decode_state(
        )
        yield
    finally:
-        state.end_forward()
        if token is not None:
            decode_state.reset(token)
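Note: a hedged sketch of the context-manager pattern these hunks modify: the flashinfer wrapper is planned on entry and no explicit end_forward() call is made on exit anymore. `state_var` (a contextvars.ContextVar), `wrapper`, and the plan() keyword arguments are placeholders, not the exact TGI/flashinfer signatures.

    from contextlib import contextmanager

    @contextmanager
    def use_state(state_var, wrapper, **plan_kwargs):
        token = state_var.set(wrapper)
        try:
            wrapper.plan(**plan_kwargs)   # previously: wrapper.begin_forward(...)
            yield
        finally:
            if token is not None:          # previously preceded by wrapper.end_forward()
                state_var.reset(token)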
@@ -1,9 +1,12 @@
 import intel_extension_for_pytorch as ipex
 import torch
 from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
-from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
 from text_generation_server.layers.attention import Seqlen
 from typing import Optional
+from text_generation_server.models.globals import (
+    ATTENTION,
+    BLOCK_SIZE,
+)

 SUPPORTS_WINDOWING = False

@@ -28,6 +31,22 @@ def attention(
    out = torch.empty_like(query)

    # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
+    if ATTENTION == "flashdecoding-ipex":
+        ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
+            out,
+            query.contiguous() if query.device.type == "xpu" else query,
+            kv_cache.key,
+            kv_cache.value,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_k,
+            seqlen.max_q,
+            seqlen.max_k,
+            softmax_scale,
+            causal,
+            block_tables,
+            None,
+        )
+    else:
        ipex.llm.functional.varlen_attention(
            query.contiguous() if query.device.type == "xpu" else query,
            key.contiguous() if key.device.type == "xpu" else key,

@@ -64,6 +83,23 @@ def paged_attention(
        raise NotImplementedError("softcap is not available in IPEX")

    out = torch.empty_like(query)
+
+    if ATTENTION == "flashdecoding-ipex":
+        ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
+            out,
+            query.contiguous() if query.device.type == "xpu" else query,
+            kv_cache.key,
+            kv_cache.value,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_k,
+            seqlen.max_q,
+            seqlen.max_k,
+            softmax_scale,
+            True,
+            block_tables,
+            None,
+        )
+    else:
        input_lengths = seqlen.input_lengths + seqlen.cache_lengths
        ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
            out,
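Note: a small sketch summarizing the kernel dispatch the IPEX hunks above add; the function and the returned strings are illustrative names only, not real APIs.

    def pick_ipex_kernel(attention_mode: str, is_prefill: bool) -> str:
        # "flashdecoding-ipex" routes both prefill and decode to the paged
        # flash-attention kernel; other modes keep the previous kernels.
        if attention_mode == "flashdecoding-ipex":
            return "PagedAttention.flash_attn_varlen_func"
        return (
            "ipex.llm.functional.varlen_attention"
            if is_prefill
            else "PagedAttention.single_query_cached_kv_attention"
        )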
@@ -52,12 +52,17 @@ class KVCache:
        device: torch.device,
    ):
        """Construct the key-value cache for a layer."""
-        if dtype in {torch.float8_e5m2, torch.float8_e4m3fn} and (
-            ATTENTION != "flashinfer" or SYSTEM != "cuda"
-        ):
-            raise ValueError(
-                "FP8 KV cache is currently only supported for flashinfer on CUDA"
-            )
+        if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
+            if not (
+                (ATTENTION == "flashinfer" and SYSTEM == "cuda")
+                or (ATTENTION == "paged" and SYSTEM in ("cuda", "rocm"))
+            ):
+                raise ValueError(
+                    "FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on CUDA and ROCm. "
+                )
+            if SYSTEM == "rocm" and dtype == torch.float8_e5m2:
+                raise ValueError(
+                    "float8_e5m2 FP8 KV cache is not supported on AMD ROCm"
+                )

        element_size = torch.tensor([], dtype=dtype).element_size()

@@ -66,7 +71,9 @@ class KVCache:
        else:
            x = BLOCK_SIZE // element_size

-        if ATTENTION in {"flashdecoding", "flashinfer"}:
+        if ATTENTION in {"flashdecoding", "flashinfer"} or (
+            ATTENTION == "flashdecoding-ipex" and device.type == "xpu"
+        ):
            self.kv_cache = (
                torch.empty(
                    (num_blocks, BLOCK_SIZE, num_heads, head_size),

@@ -80,6 +87,7 @@ class KVCache:
                ),
            )
        elif SYSTEM == "ipex" and device == torch.device("cpu"):
+            # ipex cpu flashdecoding kernel and paged attention kernel share same layout
            self.kv_cache = (
                torch.empty(
                    (num_blocks, num_heads, BLOCK_SIZE, head_size),

@@ -110,21 +118,17 @@ class KVCache:
        """Check if the cache can be scaled by the given scales."""
        if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0:
            return False
-        elif (
-            self.dtype == torch.float8_e4m3fn
-            and ATTENTION == "flashinfer"
-            and SYSTEM == "cuda"
-        ):
-            log_once(
-                logger.info,
-                "Using FP8 KV cache scales",
-            )
+        elif self.dtype == torch.float8_e4m3fn and (
+            (ATTENTION == "flashinfer" and SYSTEM == "cuda")
+            or (ATTENTION == "paged" and SYSTEM == "rocm")
+        ):
+            log_once(logger.info, "Using FP8 KV cache scales")
            return True
        else:
            # We have scales, but not the correct FP8 cache type, so warn once.
            log_once(
                logger.info,
-                "Ignoring FP8 KV cache scales, only float8_e4m3fn KV cache on flashinfer is supported",
+                "Ignoring FP8 KV cache scales, supported only for float8_e4m3fn KV cache with flashinfer on CUDA and paged attention on ROCm",
            )
            return False

@@ -158,7 +162,7 @@ class KVCache:
        key_cache = self.kv_cache[0]
        value_cache = self.kv_cache[1]

-        if self.can_scale(kv_scales):
+        if self.can_scale(kv_scales) and SYSTEM == "cuda":
            if kv_scales.key_scale_cpu != 1.0:
                key = fp8_quantize(
                    key.float(),

@@ -187,8 +191,22 @@ class KVCache:
            shape = key_cache.shape
            key_cache.view(-1, shape[-2], shape[-1])[slots] = key
            value_cache.view(-1, shape[-2], shape[-1])[slots] = value
+        elif ATTENTION == "flashdecoding-ipex" and key.device.type == "xpu":
+            import intel_extension_for_pytorch as ipex
+
+            ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
+                key, value, key_cache, value_cache, slots
+            )
        else:
-            paged_reshape_and_cache(key, value, key_cache, value_cache, slots)
+            paged_reshape_and_cache(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                slots,
+                kv_scales.key_scale_cpu,
+                kv_scales.value_scale_cpu,
+            )


 def paged_reshape_and_cache(

@@ -197,7 +215,10 @@ def paged_reshape_and_cache(
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slots: torch.Tensor,
+    k_scale: float = 1.0,
+    v_scale: float = 1.0,
 ):
+
    if SYSTEM == "cuda":
        try:
            import attention_kernels

@@ -205,8 +226,13 @@ def paged_reshape_and_cache(
            raise ImportError(
                f"Could not import attention_kernels. Make sure your installation is correct. Complete error: {e}"
            )
+
+        kv_cache_dtype = "auto"
+        if key_cache.dtype == torch.float8_e4m3fn:
+            kv_cache_dtype = "fp8"
+
        attention_kernels.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0
+            key, value, key_cache, value_cache, slots, kv_cache_dtype, k_scale, v_scale
        )
    elif SYSTEM == "rocm":
        try:

@@ -215,8 +241,15 @@ def paged_reshape_and_cache(
            raise ImportError(
                f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
            )
+
+        kv_cache_dtype = "auto"
+        if key_cache.dtype == torch.float8_e4m3fn:
+            key_cache = key_cache.view(torch.uint8)
+            value_cache = value_cache.view(torch.uint8)
+            kv_cache_dtype = "fp8"
+
        ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0, 1.0
+            key, value, key_cache, value_cache, slots, kv_cache_dtype, k_scale, v_scale
        )
    elif SYSTEM == "ipex":
        import intel_extension_for_pytorch as ipex
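Note: a standalone sketch of the scale-gating rule encoded in can_scale above (not the real KVCache class): scales are only honoured for a float8_e4m3fn cache with flashinfer on CUDA or paged attention on ROCm, and a (1.0, 1.0) pair means "no scales". All parameter names are placeholders.

    import torch

    def scales_usable(cache_dtype, key_scale, value_scale, attention, system) -> bool:
        if key_scale == 1.0 and value_scale == 1.0:
            return False
        return cache_dtype == torch.float8_e4m3fn and (
            (attention == "flashinfer" and system == "cuda")
            or (attention == "paged" and system == "rocm")
        )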
@@ -133,6 +133,15 @@ def paged_attention(

    out = torch.empty_like(query)

+    if kv_cache.dtype == torch.float8_e4m3fn:
+        key = kv_cache.key.view(torch.uint8)
+        value = kv_cache.value.view(torch.uint8)
+        kv_cache_dtype = "fp8"
+    else:
+        key = kv_cache.key
+        value = kv_cache.value
+        kv_cache_dtype = "auto"
+
    # NOTE(woosuk): We use a simple heuristic to decide whether to use
    # PagedAttention V1 or V2. If the number of partitions is 1, we use
    # V1 to avoid the overhead of reduction. Also, if the number of

@@ -147,8 +156,8 @@ def paged_attention(
        ops.paged_attention_v1(
            out,
            query,
-            kv_cache.key,
-            kv_cache.value,
+            key,
+            value,
            num_kv_heads,
            softmax_scale,
            block_tables,

@@ -156,24 +165,24 @@ def paged_attention(
            block_size,
            max_s,
            None,
-            "auto",
-            1.0,
-            1.0,
+            kv_cache_dtype,
+            kv_scales.key_scale_cpu,
+            kv_scales.value_scale_cpu,
        )
    else:
        # Run PagedAttention V2.
        assert _PARTITION_SIZE % block_size == 0
-        tmp_output = torch.empty(
+        tmp_output = torch.zeros(
            size=(num_seqs, num_heads, max_num_partitions, head_size),
            dtype=out.dtype,
            device=out.device,
        )
-        exp_sums = torch.empty(
+        exp_sums = torch.zeros(
            size=(num_seqs, num_heads, max_num_partitions),
            dtype=torch.float32,
            device=out.device,
        )
-        max_logits = torch.empty_like(exp_sums)
+        max_logits = torch.zeros_like(exp_sums)

        if not use_custom:
            ops.paged_attention_v2(

@@ -182,8 +191,8 @@ def paged_attention(
                max_logits,
                tmp_output,
                query,
-                kv_cache.key,
-                kv_cache.value,
+                key,
+                value,
                num_kv_heads,
                softmax_scale,
                block_tables,

@@ -191,9 +200,9 @@ def paged_attention(
                block_size,
                max_s,
                None,
-                "auto",
-                1.0,
-                1.0,
+                kv_cache_dtype,
+                kv_scales.key_scale_cpu,
+                kv_scales.value_scale_cpu,
            )
        else:
            ops.paged_attention_rocm(

@@ -202,8 +211,8 @@ def paged_attention(
                max_logits,
                tmp_output,
                query,
-                kv_cache.key,
-                kv_cache.value,
+                key,
+                value,
                num_kv_heads,
                softmax_scale,
                block_tables,

@@ -211,9 +220,9 @@ def paged_attention(
                block_size,
                max_s,
                None,
-                "auto",
-                1.0,
-                1.0,
+                kv_cache_dtype,
+                kv_scales.key_scale_cpu,
+                kv_scales.value_scale_cpu,
                None,
                _PARTITION_SIZE,
            )
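Note: a short sketch of the reinterpretation trick used in the ROCm hunks above. The ROCm kernels consume FP8 caches as uint8 buffers plus a "fp8" dtype tag, so the float8_e4m3fn tensor is reinterpreted bit-for-bit rather than converted; the helper name here is illustrative only.

    import torch

    def as_kernel_cache(cache: torch.Tensor):
        # Same bytes, different dtype: view() does not copy or convert.
        if cache.dtype == torch.float8_e4m3fn:
            return cache.view(torch.uint8), "fp8"
        return cache, "auto"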
@@ -3,8 +3,14 @@ from typing import List, Optional, Union
 import torch
 from compressed_tensors.quantization import QuantizationArgs, QuantizationType

-from text_generation_server.layers.fp8 import Fp8Weight, _load_scalar_or_matrix_scale
+from text_generation_server.layers.fp8 import (
+    Fp8Weight,
+    _load_scalar_or_matrix_scale,
+    requantize_with_max_scale,
+    normalize_e4m3fn_to_native_float8,
+)
 from text_generation_server.utils.weights import Weights, WeightsLoader
+from text_generation_server.utils.import_utils import SYSTEM


 class W8ANFpLoader(WeightsLoader):

@@ -47,11 +53,10 @@ class W8ANFpLoader(WeightsLoader):

        weight_scale = None
        if self.load_weight_scale:
-            weight_scale = (
-                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
-                .reshape(-1)
-                .expand(w.shape[0])
-            )
+            weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            if SYSTEM == "cuda":
+                weight_scale = weight_scale.reshape(-1).expand(w.shape[0])

        input_scale = None
        if self.load_input_scale:

@@ -87,6 +92,7 @@ class W8ANFpLoader(WeightsLoader):
                block_sizes=block_sizes,
                to_dtype=False,
            )
+            if SYSTEM == "cuda":
                weight_scale = weight_scale.reshape(-1).expand(w.shape[0])

        input_scale = None

@@ -141,6 +147,17 @@ class W8ANFpLoader(WeightsLoader):
            else None
        )

+        if self.load_weight_scale and SYSTEM == "rocm":
+            w, weight_scale, input_scale = normalize_e4m3fn_to_native_float8(
+                w, weight_scale, input_scale
+            )
+
+        if weight_scale.numel() == len(prefixes):
+            logical_widths = [x[0] for x in shapes]
+            w, weight_scale = requantize_with_max_scale(
+                w, weight_scale.to(weights.device), logical_widths, weights.dtype
+            )
+
        return Fp8Weight(
            weight=w,
            weight_scale=weight_scale,

@@ -153,11 +170,10 @@ class W8ANFpLoader(WeightsLoader):
        w = weights.get_sharded(f"{prefix}.weight", dim=1)
        weight_scale = None
        if self.load_weight_scale:
-            weight_scale = (
-                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
-                .reshape(-1)
-                .expand(w.shape[0])
-            )
+            weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            if SYSTEM == "cuda":
+                weight_scale = weight_scale.reshape(-1).expand(w.shape[0])

        input_scale = None
        if self.load_input_scale:
@@ -19,6 +19,15 @@ try:
 except ImportError:
    marlin_kernels = None

+try:
+    from moe_kernels.fp8_utils import w8a8_block_fp8_matmul, per_token_group_quant_fp8
+except ImportError:
+    w8a8_block_fp8_matmul = None
+    per_token_group_quant_fp8 = None
+
+quant_dtype: torch.dtype = (
+    torch.float8_e4m3fnuz if SYSTEM == "rocm" else torch.float8_e4m3fn
+)

 if SYSTEM == "cuda" and marlin_kernels is not None:
    major, minor = torch.cuda.get_device_capability()

@@ -35,7 +44,6 @@ def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
    """

    if SYSTEM == "cuda":
-
        major, _ = torch.cuda.get_device_capability()
        # Marlin is W8A16, use it when:
        #

@@ -49,18 +57,28 @@ def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
            # gives better decoding throughput on L4 and L40.
            from text_generation_server.layers.marlin import GPTQMarlinFP8Linear

+            if major == 8 and minor == 9:
+                log_once(
+                    logger.info,
+                    "GPU supports FP8, but using Marlin FP8 kernel for better performance",
+                )
+            else:
+                log_once(
+                    logger.info, "GPU does not support FP8, using Marlin FP8 kernel"
+                )
+
            return GPTQMarlinFP8Linear

    # On other systems let Torch decide if the hardware supports FP8.
    return Fp8Linear


-def normalize_e4m3fn_to_e4m3fnuz(
+def normalize_e4m3fn_to_native_float8(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-    assert weight.dtype == torch.float8_e4m3fn
+    if weight.dtype == torch.float8_e4m3fn and SYSTEM == "rocm":
        # The bits pattern 10000000(-128) represents zero in e4m3fn
        # but NaN in e4m3fnuz. So here we set it to 0.
        # https://onnx.ai/onnx/technical/float8.html

@@ -79,6 +97,39 @@ def normalize_e4m3fn_to_native_float8(
    return weight, weight_scale, input_scale


+def per_tensor_dequantize(
+    tensor: torch.Tensor,
+    inv_scale: Union[float, torch.Tensor],
+    dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+    fake_qweight = tensor.to(dtype)
+    dq_weight = fake_qweight * inv_scale
+    return dq_weight
+
+
+def requantize_with_max_scale(
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    logical_widths: int,
+    dtype: torch.dtype,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Max scale to be used for requanitzation.
+    max_w_scale = weight_scale.max().float()
+
+    start = 0
+    for idx, logical_width in enumerate(logical_widths):
+        end = start + logical_width
+        weight_dq = per_tensor_dequantize(
+            weight[start:end, :], weight_scale[idx], dtype
+        )
+        weight[start:end, :], max_w_scale_normalized = fp8_quantize(
+            weight_dq, max_w_scale
+        )
+        start = end
+
+    return weight, max_w_scale_normalized
+
+
 def fp8_quantize(
    weight: torch.Tensor,
    scale: Optional[torch.Tensor] = None,

@@ -96,7 +147,7 @@ def fp8_quantize(
        shape = weight.shape
        qweight, scale = marlin_kernels.scaled_fp8_quant(
            weight.reshape(-1, shape[-1]),
-            dtype=qdtype,
+            dtype=quant_dtype,
            scale=scale,
            scale_ub=scale_upper_bound,
            # TODO: don't do this when we have to use the Torch kernel.

@@ -116,6 +167,8 @@ def fp8_quantize(
        qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
        scale = scale.float().reciprocal()
    else:
+        if SYSTEM == "rocm":
+            scale = scale / 2.0
        # Use reciprocal to avoid more expensive division.
        qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)

@@ -124,7 +177,7 @@ def fp8_quantize(
    qweight = qweight.to(qdtype)

    if SYSTEM == "rocm":
-        qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale)
+        qweight, scale, _ = normalize_e4m3fn_to_native_float8(qweight, scale)

    return qweight, scale

@@ -132,26 +185,42 @@ def fp8_quantize(
 class HybridFP8UnquantLoader(WeightsLoader):
    """Weight loader that loads FP8 and unquantized Torch tensors."""

-    def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool):
+    def __init__(
+        self,
+        activation_scale_ub: Optional[float],
+        to_fp8: bool,
+        weight_block_size: Optional[List[int]] = None,
+    ):
        self.activation_scale_ub = activation_scale_ub
        self.to_fp8 = to_fp8
+        self.weight_block_size = weight_block_size

    def get_weights(self, weights: "Weights", prefix: str):
        w = weights.get_tensor(f"{prefix}.weight")

        if w.dtype == torch.float8_e4m3fn:
-            # FP8 branch
-            scale = (
-                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
-                .reshape(-1)
-                .expand(w.shape[0])
-            )
+            if self.weight_block_size is not None:
+                scale = weights.get_tensor(f"{prefix}.weight_scale_inv")
+                return Fp8Weight(
+                    weight=w,
+                    weight_scale=scale,
+                    activation_scale_ub=self.activation_scale_ub,
+                    dtype=weights.dtype,
+                    weight_block_size=self.weight_block_size,
+                )
+            # FP8 branch
+            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+
+            if SYSTEM == "cuda":
+                scale.reshape(-1).expand(w.shape[0])
+
            input_scale = None
            if weights.has_tensor(f"{prefix}.input_scale"):
-                input_scale = weights.get_tensor(
-                    f"{prefix}.input_scale", to_dtype=False
-                ).reshape(-1)
+                input_scale = (
+                    weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
+                    .reshape(-1)
+                    .max()
+                )

            return Fp8Weight(
                weight=w,

@@ -178,6 +247,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
        if w.dtype == torch.float8_e4m3fn:
            # FP8 branch
            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+
            if scale.numel() > 1:
                scale = weights.get_packed_sharded(
                    f"{prefix}.weight_scale",

@@ -185,6 +255,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
                    block_sizes=block_sizes,
                    to_dtype=False,
                )
+            if SYSTEM == "cuda":
                scale = scale.reshape(-1).expand(w.shape[0])

            input_scale = None

@@ -225,6 +296,21 @@ class HybridFP8UnquantLoader(WeightsLoader):

        # FP8 branch
        if w.dtype == torch.float8_e4m3fn:
+            if self.weight_block_size is not None:
+                scale = [
+                    weights.get_sharded(f"{p}.weight_scale_inv", dim=0, to_device=False)
+                    for p in prefixes
+                ]
+                scale = torch.cat(scale, dim=dim)
+                scale = scale.to(weights.device)
+                return Fp8Weight(
+                    weight=w,
+                    weight_scale=scale,
+                    activation_scale_ub=self.activation_scale_ub,
+                    dtype=weights.dtype,
+                    weight_block_size=self.weight_block_size,
+                )
+
            scale = [
                _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
                for p, shape in zip(prefixes, shapes)

@@ -243,6 +329,17 @@ class HybridFP8UnquantLoader(WeightsLoader):
            else None
        )

+        if SYSTEM == "rocm":
+            w, scale, input_scale = normalize_e4m3fn_to_native_float8(
+                w, scale, input_scale
+            )
+
+        if scale.numel() == len(prefixes):
+            logical_widths = [x[0] for x in shapes]
+            w, scale = requantize_with_max_scale(
+                w, scale.to(weights.device), logical_widths, weights.dtype
+            )
+
        return Fp8Weight(
            weight=w,
            weight_scale=scale,

@@ -259,16 +356,30 @@ class HybridFP8UnquantLoader(WeightsLoader):
        w = weights.get_sharded(f"{prefix}.weight", dim=1)
        # FP8 branch
        if w.dtype == torch.float8_e4m3fn:
-            scale = (
-                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
-                .reshape(-1)
-                .expand(w.shape[0])
-            )
+            if self.weight_block_size is not None:
+                # XXX: Yes the weights is named scale_inv, but corresponds to scale it seems.
+                scale = weights.get_sharded(f"{prefix}.weight_scale_inv", dim=1)
+
+                return Fp8Weight(
+                    weight=w,
+                    weight_scale=scale,
+                    activation_scale_ub=self.activation_scale_ub,
+                    dtype=weights.dtype,
+                    weight_block_size=self.weight_block_size,
+                )
+
+            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+
+            if SYSTEM == "cuda":
+                scale = scale.reshape(-1).expand(w.shape[0])

            input_scale = None
            if weights.has_tensor(f"{prefix}.input_scale"):
-                input_scale = weights.get_tensor(
-                    f"{prefix}.input_scale", to_dtype=False
-                ).reshape(-1)
+                input_scale = (
+                    weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
+                    .reshape(-1)
+                    .max()
+                )

            return Fp8Weight(
                weight=w,

@@ -291,6 +402,7 @@ class Fp8Weight(Weight):
    input_scale: Optional[torch.Tensor] = None
    activation_scale_ub: Optional[float] = None
    force_w8a16: bool = False
+    weight_block_size: Optional[List[int]] = None

    def get_linear(self, bias: torch.Tensor):
        if self.weight_scale is None:

@@ -307,6 +419,7 @@ class Fp8Weight(Weight):
            bias=bias,
            input_scale=self.input_scale,
            scale_upper_bound=self.activation_scale_ub,
+            weight_block_size=self.weight_block_size,
        )


@@ -321,19 +434,21 @@ class Fp8Linear(torch.nn.Module):
        bias: Optional[torch.Tensor] = None,
        input_scale: Optional[torch.Tensor] = None,
        scale_upper_bound: Optional[float] = None,
+        weight_block_size: Optional[List[int]] = None,
    ) -> None:
        super().__init__()
        if CUTLASS_FP8_AVAILABLE:
            log_once(logger.info, "Using cutlass w8a8 kernels")
        if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
-            qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                weight=qweight, weight_scale=scale
+            qweight, scale, input_scale = normalize_e4m3fn_to_native_float8(
+                weight=qweight, weight_scale=scale, input_scale=input_scale
            )

        self.dtype = dtype
        self.qweight = qweight
        self.scale = scale.float()
        self.input_scale = input_scale.float() if input_scale is not None else None
+        self.weight_block_size = weight_block_size

        if CUTLASS_FP8_AVAILABLE and scale_upper_bound is not None:
            self.scale_upper_bound = torch.tensor(

@@ -367,6 +482,7 @@ class Fp8Linear(torch.nn.Module):
    ) -> "Fp8Linear":
        input_scale = kwargs.get("input_scale", None)
        scale_upper_bound = kwargs.get("scale_upper_bound", None)
+        weight_block_size = kwargs.get("weight_block_size", None)

        return cls(
            qweight=weight,

@@ -375,6 +491,7 @@ class Fp8Linear(torch.nn.Module):
            scale_upper_bound=scale_upper_bound,
            bias=bias,
            dtype=dtype,
+            weight_block_size=weight_block_size,
        )

    @classmethod

@@ -386,6 +503,25 @@ class Fp8Linear(torch.nn.Module):
        return cls._device_identity_cache[device]

    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if self.weight_block_size is not None:
+            # https://arxiv.org/pdf/2412.19437
+            # At a more granular level. As illustrated in Figure 7 (a), (1) for activations, we group and
+            # scale elements on a 1x128 tile basis (i.e., per token per 128 channels); and (2) for weights, we
+            # group and scale elements on a 128x128 block basis (i.e., per 128 input channels per 128 output
+            # channels).
+            qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1])
+            output = w8a8_block_fp8_matmul(
+                qinput,
+                self.qweight,
+                scale,
+                self.scale,
+                self.weight_block_size,
+                output_dtype=input.dtype,
+            )
+
+            if self.bias is not None:
+                output = output + self.bias
+            return output.to(dtype=input.dtype)
        if CUTLASS_FP8_AVAILABLE:
            # cutlass FP8 supports per-token scales, so get non-scalar scales.
            qinput, scale = fp8_quantize(

@@ -443,6 +579,9 @@ class Fp8Linear(torch.nn.Module):

 def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
    scale = weights.get_tensor(prefix, to_dtype=False)
+
    if scale.numel() > 1:
        scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
+    elif SYSTEM == "rocm":
+        return scale.reshape(-1)
    return scale.reshape(-1).expand(shape[0])
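Note: a rough, assumption-labelled reference sketch of the block-wise scheme the new forward() branch relies on (per the DeepSeek-V3 description quoted in the comment above): activations are quantized per 1x128 group along the hidden dimension, weights per 128x128 block. This is plain PyTorch for illustration only; the real path calls per_token_group_quant_fp8 and w8a8_block_fp8_matmul from moe_kernels, and the helper below assumes the hidden size is divisible by the group size.

    import torch

    def group_quant_activations(x: torch.Tensor, group: int = 128):
        # x: [tokens, hidden]; one scale per (token, 128-channel group).
        t, h = x.shape
        xg = x.view(t, h // group, group)
        # 448 is the largest value representable in float8_e4m3fn.
        scale = xg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 448.0
        q = (xg / scale).to(torch.float8_e4m3fn)
        return q.view(t, h), scale.squeeze(-1)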
@@ -956,15 +956,24 @@ def quantize(

    pack(model, quantizers, bits, groupsize)
    from safetensors.torch import save_file
-    from transformers.modeling_utils import shard_checkpoint
+    from huggingface_hub import split_torch_state_dict_into_shards

    state_dict = model.state_dict()
    state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}

    max_shard_size = "10GB"
-    shards, index = shard_checkpoint(
-        state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors"
+    state_dict_split = split_torch_state_dict_into_shards(
+        state_dict,
+        filename_pattern="model.safetensors",
+        max_shard_size=max_shard_size,
    )
+    index = None
+    if state_dict_split.is_sharded:
+        index = {
+            "metadata": state_dict_split.metadata,
+            "weight_map": state_dict_split.tensor_to_filename,
+        }
+    shards = state_dict_split.filename_to_tensors
    os.makedirs(output_dir, exist_ok=True)
    for shard_file, shard in shards.items():
        save_file(
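Note: a minimal usage sketch of the huggingface_hub helper that replaces the removed transformers.modeling_utils.shard_checkpoint call, mirroring the hunk above; the default filename pattern is used here and the tiny state dict is a placeholder.

    import torch
    from safetensors.torch import save_file
    from huggingface_hub import split_torch_state_dict_into_shards

    state_dict = {"w": torch.zeros(4, 4)}
    split = split_torch_state_dict_into_shards(state_dict, max_shard_size="10GB")
    # filename_to_tensors maps each output file to the tensor names it should hold.
    for filename, tensor_names in split.filename_to_tensors.items():
        save_file({name: state_dict[name] for name in tensor_names}, filename)
    if split.is_sharded:
        index = {"metadata": split.metadata, "weight_map": split.tensor_to_filename}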
@@ -2,14 +2,12 @@ from typing import Optional

 import torch
 import torch.nn as nn
-from loguru import logger
 from text_generation_server.layers.fp8 import fp8_quantize
 from text_generation_server.layers.marlin.gptq import _check_valid_shape
 from text_generation_server.layers.marlin.util import (
    _check_marlin_kernels,
    permute_scales,
 )
-from text_generation_server.utils.log import log_once

 try:
    import marlin_kernels

@@ -36,8 +34,6 @@ class GPTQMarlinFP8Linear(nn.Module):
        _check_marlin_kernels()
        assert marlin_kernels is not None

-        log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
-
        scales = scales.unsqueeze(0)
        if scales.shape[1] == 1:
            out_features, in_features = qweight.shape
@ -16,6 +16,7 @@ from text_generation_server.layers.moe.gptq_marlin import (
|
|||||||
can_use_marlin_moe_gemm,
|
can_use_marlin_moe_gemm,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
|
from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
|
||||||
|
from text_generation_server.layers.moe.fp8 import FP8SparseMoELayer
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
from text_generation_server.utils.log import log_once
|
from text_generation_server.utils.log import log_once
|
||||||
from text_generation_server.utils.weights import (
|
from text_generation_server.utils.weights import (
|
||||||
@ -25,7 +26,7 @@ from text_generation_server.utils.weights import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
if SYSTEM == "ipex":
|
if SYSTEM == "ipex":
|
||||||
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
|
from .fused_moe_ipex import fused_topk, grouped_topk
|
||||||
else:
|
else:
|
||||||
from moe_kernels.fused_moe import fused_topk, grouped_topk
|
from moe_kernels.fused_moe import fused_topk, grouped_topk
|
||||||
|
|
||||||
@ -51,6 +52,8 @@ class MoELayer(Protocol):
|
|||||||
up_proj_name: str = "up_proj",
|
up_proj_name: str = "up_proj",
|
||||||
down_proj_name: str = "down_proj",
|
down_proj_name: str = "down_proj",
|
||||||
hidden_act: str = "silu",
|
+        scoring_func: Optional[str] = None,
+        e_score_correction_bias: Optional[float] = None,
     ): ...

     def forward(
@@ -80,9 +83,14 @@ class DenseMoELayer(nn.Module):
         up_proj_name: str = "up_proj",
         down_proj_name: str = "down_proj",
         hidden_act: str = "silu",
+        scoring_func: Optional[str] = None,
+        e_score_correction_bias: Optional[float] = None,
     ):
         super().__init__()

+        assert scoring_func is None, "scoring func is not handled"
+        assert e_score_correction_bias is None, "scoring correction bias is not handled"
+
         log_once(
             logger.info,
             "No fused layers are available for this model type, using (slower) dense MoE layer",
@@ -139,10 +147,6 @@ class DenseMoELayer(nn.Module):
             )
             for i in range(self.n_experts)
         ]
-        if SYSTEM == "ipex":
-            self.ipex_fused_moe = GatedMLPMOE(
-                W13=self.gate_proj, W2=self.down_proj, W3=self.up_proj, use_prepack=True
-            )

         self.process_group = weights.process_group

@@ -155,17 +159,6 @@ class DenseMoELayer(nn.Module):
         input_shape = x.shape
         x = x.view(-1, input_shape[-1])

-        if SYSTEM == "ipex":
-            return self.ipex_fused_moe(
-                hidden_states=x,
-                router_logits=gating_output,
-                top_k=self.topk,
-                renormalize=self.renormalize,
-                use_grouped_topk=self.n_expert_group is not None,
-                num_expert_group=self.n_expert_group,
-                topk_group=self.topk_group,
-            )
-
         if self.n_expert_group is not None and self.topk_group is not None:
             topk_weights, topk_ids = grouped_topk(
                 x,
@@ -213,16 +206,23 @@ class SparseMoELayer(nn.Module):
         topk: int,
         topk_group: Optional[int],
         weights: Weights,
+        scoring_func: Optional[str] = "softmax",
+        e_score_correction_bias: Optional[float] = None,
         gate_proj_name: str = "gate_proj",
         up_proj_name: str = "up_proj",
         down_proj_name: str = "down_proj",
     ):
         super().__init__()

         if (
             isinstance(weights.loader, DefaultWeightsLoader)
             and isinstance(weights.loader.weight_class, UnquantizedWeight)
         ) or isinstance(weights.loader, HybridFP8UnquantLoader):
+            if (
+                isinstance(weights.loader, HybridFP8UnquantLoader)
+                and weights.loader.to_fp8
+            ):
+                cls = FP8SparseMoELayer
+            else:
                 cls = UnquantizedSparseMoELayer
         elif isinstance(
             weights.loader, GPTQMarlinWeightsLoader
@@ -250,6 +250,8 @@ class SparseMoELayer(nn.Module):
             topk=topk,
             topk_group=topk_group,
             weights=weights,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
             gate_proj_name=gate_proj_name,
             up_proj_name=up_proj_name,
             down_proj_name=down_proj_name,
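The new `scoring_func` and `e_score_correction_bias` arguments are threaded from the model config down to the fused MoE kernel. A minimal sketch of how a model block could construct and call the layer with them; `config`, `weights`, `hidden_states` and `router_logits` are assumed to come from the surrounding modeling code, and the prefix is illustrative only:

# Sketch only: mirrors how DeepseekV3MoE (later in this diff) wires the new arguments.
moe = SparseMoELayer(
    prefix="model.layers.3.mlp.experts",   # illustrative prefix
    n_experts=config.n_routed_experts,
    n_expert_group=config.n_group,
    renormalize=config.norm_topk_prob,
    topk=config.num_experts_per_tok,
    topk_group=config.topk_group,
    weights=weights,
    scoring_func=config.scoring_func,      # e.g. "softmax"
    e_score_correction_bias=None,          # bias tensor when topk_method == "noaux_tc"
)
out = moe(hidden_states, gating_output=router_logits)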
server/text_generation_server/layers/moe/fp8.py (new file, 173 lines)
@@ -0,0 +1,173 @@
from typing import Optional

import torch
import torch.nn as nn

from text_generation_server.utils.weights import Weights
from text_generation_server.layers.fp8 import (
    Fp8Weight,
    fp8_quantize,
    quant_dtype,
    normalize_e4m3fn_to_native_float8,
)

try:
    from moe_kernels.fused_moe import fused_moe
except Exception:
    fused_moe = None


class FP8SparseMoELayer(nn.Module):
    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        scoring_func: Optional[str] = "softmax",
        e_score_correction_bias: Optional[float] = None,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
    ):
        super().__init__()

        assert (n_expert_group is None) == (
            topk_group is None
        ), "n_expert_group and topk_group must both be None or have some value"

        self.n_expert_group = n_expert_group
        self.topk = topk
        self.topk_group = topk_group
        self.renormalize = renormalize
        self.weight_block_size = weights.weights_loader.weight_block_size
        self.scoring_func = scoring_func
        self.e_score_correction_bias = e_score_correction_bias

        (
            self.gate_up_proj,
            self.gate_up_proj_weight_scale,
            self.gate_up_proj_input_scale,
        ) = _load_expert_multi_weights_col(
            prefix=prefix,
            n_experts=n_experts,
            gate_proj_name=gate_proj_name,
            up_proj_name=up_proj_name,
            weights=weights,
        )

        self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = (
            _load_expert_weights_row(
                prefix=prefix,
                n_experts=n_experts,
                name=down_proj_name,
                weights=weights,
            )
        )

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        return fused_moe(
            x,
            w1=self.gate_up_proj,
            w2=self.down_proj,
            gating_output=gating_output,
            topk=self.topk,
            renormalize=self.renormalize,
            inplace=True,
            use_grouped_topk=self.n_expert_group is not None,
            num_expert_group=self.n_expert_group,
            topk_group=self.topk_group,
            scoring_func=self.scoring_func,
            e_score_correction_bias=self.e_score_correction_bias,
            use_fp8_w8a8=True,
            w1_scale=self.gate_up_proj_weight_scale,
            w2_scale=self.down_proj_weight_scale,
            a1_scale=self.gate_up_proj_input_scale,
            a2_scale=self.down_proj_input_scale,
        )


def _load_expert_weights(
    get_weight_fn,
    *,
    prefix: str,
    n_experts: int,
    name: str,
    weights: Weights,
) -> torch.Tensor:
    all_weight = None
    all_weight_scales = None
    max_input_scale = None

    for i in range(n_experts):
        weight = get_weight_fn(prefix, i, name, weights)

        assert isinstance(weight, Fp8Weight)

        if all_weight is None:
            all_weight = torch.empty(
                (n_experts,) + weight.weight.shape,
                dtype=quant_dtype,
                device=weight.weight.device,
            )
        if all_weight_scales is None:
            all_weight_scales = torch.empty(
                (n_experts,) + weight.weight_scale.shape,
                dtype=torch.float32,
                device=weight.weight.device,
            )

        if weight.weight.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz}:
            all_weight[i], all_weight_scales[i], current_input_scale = (
                normalize_e4m3fn_to_native_float8(
                    weight.weight, weight.weight_scale, weight.input_scale
                )
            )
            if current_input_scale is not None:
                if max_input_scale is None or current_input_scale > max_input_scale:
                    max_input_scale = current_input_scale
        else:
            all_weight[i], all_weight_scales[i] = fp8_quantize(
                weight.weight, scalar=True
            )

    assert all_weight is not None

    return all_weight, all_weight_scales, max_input_scale


def _load_expert_multi_weights_col(
    *,
    prefix: str,
    n_experts: int,
    gate_proj_name: str,
    up_proj_name: str,
    weights: Weights,
) -> torch.Tensor:
    def get_weight_fn(prefix, i, name, weights):
        return weights.get_multi_weights_col(
            [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
        )

    return _load_expert_weights(
        get_weight_fn, prefix=prefix, n_experts=n_experts, name=None, weights=weights
    )


def _load_expert_weights_row(
    *,
    prefix: str,
    n_experts: int,
    name: str,
    weights: Weights,
) -> torch.Tensor:
    def get_weight_fn(prefix, i, name, weights):
        return weights.get_weights_row(f"{prefix}.{i}.{name}")

    return _load_expert_weights(
        get_weight_fn, prefix=prefix, n_experts=n_experts, name=name, weights=weights
    )
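For experts whose checkpoints are not already stored in float8, the loader above falls back to `fp8_quantize(weight, scalar=True)`. Conceptually, the scalar path picks one scale per tensor from its absolute maximum; the following is a rough, self-contained sketch of that idea, not the exact helper from `layers/fp8.py` (which also handles e4m3fnuz and block scales):

import torch

def quantize_fp8_scalar(weight: torch.Tensor):
    # One scale for the whole tensor: map max(|w|) onto the fp8 e4m3 range,
    # so that weight ~= qweight * scale after dequantization.
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = weight.abs().max().clamp(min=1e-12) / finfo.max
    qweight = (weight / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return qweight, scale.to(torch.float32)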
server/text_generation_server/layers/moe/fused_moe_ipex.py (new file, 65 lines)
@@ -0,0 +1,65 @@
# coding=utf-8
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import torch


def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    scores = torch.softmax(gating_output, dim=-1)
    num_token = scores.shape[0]
    group_scores = (
        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
    )  # [n, n_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
        1
    ]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
        .reshape(num_token, -1)
    )  # [n, e]
    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights, topk_ids


def fused_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    topk_weights = torch.nn.functional.softmax(
        gating_output, dim=1, dtype=torch.float32
    )
    topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
    if renormalize:
        topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
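A quick usage example for `grouped_topk` as defined above (it assumes the imports of this file): route four tokens over 16 experts arranged in 4 groups, keep the best 2 groups per token, then pick the top-2 experts overall.

import torch

hidden_states = torch.randn(4, 512)   # unused by the scoring itself, kept for the signature
gating_output = torch.randn(4, 16)    # router logits: [tokens, experts]
weights, ids = grouped_topk(
    hidden_states,
    gating_output,
    topk=2,
    renormalize=True,
    num_expert_group=4,
    topk_group=2,
)
# weights: [4, 2] normalized routing weights; ids: [4, 2] expert indices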
server/text_generation_server/layers/moe/gptq_marlin.py
@@ -69,7 +69,11 @@ class GPTQMarlinSparseMoELayer(nn.Module):
         gate_proj_name: str = "gate_proj",
         up_proj_name: str = "up_proj",
         down_proj_name: str = "down_proj",
+        scoring_func: Optional[str] = None,
+        e_score_correction_bias: Optional[float] = None,
     ):
+        assert scoring_func == "softmax", f"scoring func {scoring_func} is not handled"
+        assert e_score_correction_bias is None, "scoring correction bias is not handled"
         super().__init__()

         if not (
server/text_generation_server/layers/moe/unquantized.py
@@ -23,6 +23,8 @@ class UnquantizedSparseMoELayer(nn.Module):
         topk: int,
         topk_group: Optional[int],
         weights: Weights,
+        scoring_func: Optional[str] = "softmax",
+        e_score_correction_bias: Optional[float] = None,
         gate_proj_name: str = "gate_proj",
         up_proj_name: str = "up_proj",
         down_proj_name: str = "down_proj",
@@ -37,6 +39,9 @@ class UnquantizedSparseMoELayer(nn.Module):
         self.topk = topk
         self.topk_group = topk_group
         self.renormalize = renormalize
+        self.weight_block_size = weights.weights_loader.weight_block_size
+        self.scoring_func = scoring_func
+        self.e_score_correction_bias = e_score_correction_bias

         self.gate_up_proj = _load_expert_multi_weights_col(
             prefix=prefix,
@@ -58,17 +63,7 @@ class UnquantizedSparseMoELayer(nn.Module):
         )

     def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
-        if SYSTEM == "rocm":
-            return fused_moe(
-                x,
-                self.gate_up_proj,
-                self.down_proj,
-                gating_output,
-                self.topk,
-                renormalize=self.renormalize,
-                inplace=True,
-            )
-        elif SYSTEM == "ipex":
+        if SYSTEM == "ipex":
             return self.ipex_fused_moe(
                 hidden_states=x,
                 router_logits=gating_output,
@@ -78,7 +73,6 @@ class UnquantizedSparseMoELayer(nn.Module):
                 num_expert_group=self.n_expert_group,
                 topk_group=self.topk_group,
             )
-
         return fused_moe(
             x,
             w1=self.gate_up_proj,
@@ -90,6 +84,8 @@ class UnquantizedSparseMoELayer(nn.Module):
             use_grouped_topk=self.n_expert_group is not None,
             num_expert_group=self.n_expert_group,
             topk_group=self.topk_group,
+            scoring_func=self.scoring_func,
+            e_score_correction_bias=self.e_score_correction_bias,
         )

server/text_generation_server/models/__init__.py
@@ -16,10 +16,12 @@ from transformers.models.auto import modeling_auto
 from huggingface_hub import hf_hub_download, HfApi
 from typing import Optional, List, Dict
 from pathlib import Path
+import transformers
+
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast

 from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
 from text_generation_server.models.custom_modeling.mpt_modeling import (
     MPTForCausalLM,
@@ -87,6 +89,10 @@ try:
         FlashDeepseekV2ForCausalLM,
         DeepseekV2Config,
     )
+    from text_generation_server.models.custom_modeling.flash_deepseek_v3_modeling import (
+        FlashDeepseekV3ForCausalLM,
+        DeepseekV3Config,
+    )
     from text_generation_server.models.custom_modeling.flash_llama_modeling import (
         FlashLlamaForCausalLM,
     )
@@ -178,6 +184,14 @@ except ImportError as e:
 if MAMBA_AVAILABLE:
     __all__.append(Mamba)

+FLASH_TRANSFORMERS_BACKEND = torch.cuda.is_available()
+try:
+    from text_generation_server.models.transformers_flash_causal_lm import (
+        TransformersFlashCausalLM,
+    )
+except ImportError:
+    FLASH_TRANSFORMERS_BACKEND = False
+

 class ModelType(enum.Enum):
     DEEPSEEK_V2 = {
@@ -185,6 +199,11 @@ class ModelType(enum.Enum):
         "name": "Deepseek V2",
         "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
     }
+    DEEPSEEK_V3 = {
+        "type": "deepseek_v3",
+        "name": "Deepseek V3",
+        "url": "https://huggingface.co/deepseek-ai/DeepSeek-V3",
+    }
     IDEFICS2 = {
         "type": "idefics2",
         "name": "Idefics 2",
@@ -632,6 +651,40 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
+    elif model_type == DEEPSEEK_V3:
+        if FLASH_ATTENTION:
+            head_size = max(
+                config_dict.get("qk_nope_dim", 128)
+                + config_dict.get("qk_rope_dim", 64),
+                config_dict.get("v_head_dim", 128),
+            )
+            return FlashCausalLM(
+                model_id=model_id,
+                model_class=FlashDeepseekV3ForCausalLM,
+                revision=revision,
+                quantize=quantize,
+                speculator=speculator,
+                default_dtype=torch.bfloat16,
+                dtype=dtype,
+                kv_cache_dtype=kv_cache_dtype,
+                trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
+                config_class=DeepseekV3Config,
+                head_size=head_size,
+            )
+        elif sharded:
+            raise NotImplementedError(
+                FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V3")
+            )
+        else:
+            return CausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
     elif model_type == MAMBA:
         return Mamba(
             model_id,
@@ -683,7 +736,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id=model_id,
                 revision=revision,
                 quantize=quantize,
@@ -838,7 +891,7 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
             )
         else:
-            return CausalLM.fallback(
+            return TransformersFlashCausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -888,12 +941,43 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )

-    elif (
-        model_type == LLAMA
-        or model_type == BAICHUAN
-        or model_type == PHI3
-        or model_type == GRANITE
-    ):
+    elif model_type == LLAMA or model_type == PHI3 or model_type == GRANITE:
+        if FLASH_ATTENTION:
+            return FlashCausalLM(
+                model_id=model_id,
+                model_class=FlashLlamaForCausalLM,
+                revision=revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                kv_cache_dtype=kv_cache_dtype,
+                trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
+            )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+        elif sharded:
+            raise NotImplementedError(
+                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
+            )
+        else:
+            return CausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+
+    elif model_type == BAICHUAN:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -919,6 +1003,7 @@ def get_model(
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
+
     if model_type == GEMMA:
         if FLASH_ATTENTION:
             return FlashCausalLM(
@@ -934,6 +1019,15 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
         else:
@@ -985,6 +1079,15 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
         else:
@@ -1088,6 +1191,15 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
         else:
@@ -1113,6 +1225,15 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
         else:
@@ -1138,6 +1259,15 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             raise NotImplementedError(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
@@ -1165,6 +1295,15 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
         else:
@@ -1314,8 +1453,6 @@ def get_model(
     else:
         raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))

-    if sharded:
-        raise NotImplementedError("sharded is not supported for AutoModel")
     if quantize == "gptq":
         raise NotImplementedError(
             "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
@@ -1328,8 +1465,19 @@ def get_model(
         raise NotImplementedError("Eetq quantization is not supported for AutoModel")
     elif quantize == "exl2":
         raise NotImplementedError("exl2 quantization is not supported for AutoModel")
-    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-        return CausalLM.fallback(
+
+    # Fast transformers if available
+    transformers_model_class = getattr(
+        transformers,
+        modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type, ""),
+        None,
+    )
+    if (
+        FLASH_TRANSFORMERS_BACKEND
+        and transformers_model_class is not None
+        and transformers_model_class._supports_flex_attn
+    ):
+        return TransformersFlashCausalLM.fallback(
             model_id,
             revision,
             quantize=quantize,
@@ -1337,6 +1485,10 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
+
+    if sharded:
+        raise NotImplementedError("sharded is not supported for AutoModel")
+
     if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
         return Seq2SeqLM.fallback(
             model_id,
@@ -1449,6 +1601,9 @@ def get_model_with_lora_adapters(
         "up_proj",
         "down_proj",
         "qkv_proj",
+        # add c_* layers used in starcoder2
+        "c_proj",
+        "c_fc",
     ]

     for layer_name in adapter_layers:
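Taken together, these `get_model` changes give each architecture the same fallback ladder. A compact sketch of the ordering they implement; the helper name and the returned strings here are illustrative only, since `get_model()` itself returns model instances:

def pick_backend(flash_attention: bool, flash_transformers_backend: bool, sharded: bool) -> str:
    # Illustrative sketch of the selection order only.
    if flash_attention:
        return "FlashCausalLM with the custom flash kernels"
    if flash_transformers_backend:
        return "TransformersFlashCausalLM (a transformers model driven by the flash path)"
    if sharded:
        raise NotImplementedError("sharding is only supported with flash attention")
    return "CausalLM.fallback (plain transformers AutoModel)"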
server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py (new file, 676 lines)
@@ -0,0 +1,676 @@
# coding=utf-8
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple, Type

import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig

from text_generation_server.layers import (
    FastLinear,
    SpeculativeHead,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
    get_linear,
)
from text_generation_server.layers.attention import (
    Seqlen,
    attention,
    paged_attention,
)
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import Weights

if SYSTEM == "rocm":
    try:
        import vllm._custom_ops as ops
    except Exception as e:
        raise ImportError(f"Could not load `vllm._custom_ops`. Full error: {e}")


class DeepseekV3Config(PretrainedConfig):
    def __init__(
        self,
        vocab_size=102400,
        hidden_size=4096,
        intermediate_size=11008,
        moe_intermediate_size=1407,
        num_hidden_layers=30,
        num_attention_heads=32,
        num_key_value_heads=32,
        n_shared_experts=2,
        n_routed_experts=160,
        ep_size=1,
        routed_scaling_factor=1.0,
        kv_lora_rank=512,
        q_lora_rank=1536,
        qk_rope_head_dim=64,
        v_head_dim=128,
        qk_nope_head_dim=128,
        topk_method="gready",
        n_group=8,
        topk_group=3,
        num_experts_per_tok=6,
        moe_layer_freq=1,
        first_k_dense_replace=0,
        norm_topk_prob=False,
        scoring_func="softmax",
        aux_loss_alpha=0.001,
        seq_aux=True,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=100000,
        eos_token_id=100001,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.n_shared_experts = n_shared_experts
        self.n_routed_experts = n_routed_experts
        self.ep_size = ep_size
        self.routed_scaling_factor = routed_scaling_factor
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.topk_method = topk_method
        self.n_group = n_group
        self.topk_group = topk_group
        self.num_experts_per_tok = num_experts_per_tok
        self.moe_layer_freq = moe_layer_freq
        self.first_k_dense_replace = first_k_dense_replace
        self.norm_topk_prob = norm_topk_prob
        self.scoring_func = scoring_func
        self.aux_loss_alpha = aux_loss_alpha
        self.seq_aux = seq_aux
        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
        if tie_word_embeddings:
            raise ValueError(
                "tie_word_embeddings is not supported for Deepseek V2 models."
            )

        if ep_size != 1:
            raise ValueError(
                f"Currently only ep_size == 1 is supported for Deepseek V2 models, was {ep_size}"
            )

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


class DeepseekV3Attention(torch.nn.Module):
    def __init__(
        self,
        prefix: str,
        config,
        weights: Weights,
    ):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.kv_lora_rank = config.kv_lora_rank
        self.q_lora_rank = config.q_lora_rank
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim
        self.value_head_size = config.v_head_dim
        self.head_pad_size = max(self.head_size, self.value_head_size)

        self.rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=self.qk_rope_head_dim,
            base=config.rope_theta,
            device=weights.device,
        )

        mscale = get_mscale(
            self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim
        )
        self.softmax_scale = self.head_size**-0.5 * mscale * mscale

        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()
        )

        if self.q_lora_rank is None:
            self.q_proj = TensorParallelColumnLinear.load(
                config,
                prefix=f"{prefix}.q_proj",
                weights=weights,
                bias=config.attention_bias,
            )
        else:
            self.q_a_proj = get_linear(
                weight=weights.get_weights(f"{prefix}.q_a_proj"),
                bias=(
                    weights.get_tensor(f"{prefix}.q_a_proj.bias")
                    if config.attention_bias
                    else None
                ),
            )
            self.q_a_layernorm = FastRMSNorm.load(
                prefix=f"{prefix}.q_a_layernorm",
                weights=weights,
                eps=config.rms_norm_eps,
            )
            self.q_b_proj = TensorParallelColumnLinear.load(
                config,
                prefix=f"{prefix}.q_b_proj",
                weights=weights,
                bias=config.attention_bias,
            )

        self.kv_a_proj_with_mqa = get_linear(
            weight=weights.get_weights(f"{prefix}.kv_a_proj_with_mqa"),
            bias=(
                weights.get_tensor(f"{prefix}.kv_a_proj_with_mqa.bias")
                if config.attention_bias
                else None
            ),
        )

        self.kv_scales = get_kv_scales(weights, f"{prefix}")

        self.kv_a_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps
        )

        self.kv_b_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.kv_b_proj",
            weights=weights,
            bias=config.attention_bias,
        )

        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )
        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        cu_seqlen_prefill: torch.Tensor,
        kv_cache: KVCache,
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
    ):
        if self.q_lora_rank is None:
            query = self.q_proj(hidden_states)
        else:
            query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0])
        query = query.view(-1, self.num_heads, self.head_size)

        _, query_pe = torch.split(
            query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, key_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )

        key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
        kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view(
            -1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size
        )

        key_nope, value = torch.split(
            kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
        )

        batch_size, heads, head_dim = query_pe.shape
        query_pe = (
            query_pe.view(batch_size, heads, head_dim // 2, 2)
            .transpose(2, 3)
            .reshape(batch_size, heads, head_dim)
        )
        batch_size, heads, head_dim = key_pe.shape
        key_pe = (
            key_pe.view(batch_size, heads, head_dim // 2, 2)
            .transpose(2, 3)
            .reshape(batch_size, heads, head_dim)
        )
        self.rotary_emb(query_pe, key_pe, cos, sin)

        query[..., self.qk_nope_head_dim :] = query_pe
        key = torch.empty_like(query)
        key[..., : self.qk_nope_head_dim] = key_nope
        key[..., self.qk_nope_head_dim :] = key_pe

        # We need to pad the heads because Flash Attention does not support
        # qk and v with different head sizes.
        query = torch.nn.functional.pad(
            query, (0, self.head_pad_size - self.head_size), value=0
        )
        key = torch.nn.functional.pad(
            key, (0, self.head_pad_size - self.head_size), value=0
        )
        value = torch.nn.functional.pad(
            value, (0, self.head_pad_size - self.value_head_size), value=0
        )

        kv_cache.store(
            key=key,
            value=value,
            slots=slots,
            kv_scales=self.kv_scales,
        )

        # Prefill
        if cu_seqlen_prefill is not None:
            # flash attention
            attn_output = attention(
                query=query,
                key=key,
                value=value,
                kv_cache=kv_cache,
                kv_scales=self.kv_scales,
                seqlen=seqlen,
                block_tables=block_tables,
                softmax_scale=self.softmax_scale,
            )
        # Decode
        else:
            attn_output = paged_attention(
                query,
                kv_cache,
                self.kv_head_mapping,
                self.softmax_scale,
                block_tables,
                seqlen,
                max_s,
                kv_scales=self.kv_scales,
            )

        # Remove padding.
        attn_output = attn_output[..., : self.value_head_size]

        return self.o_proj(
            attn_output.reshape(-1, self.num_heads * self.value_head_size)
        )


class DeepseekV3MLP(nn.Module):
    def __init__(self, prefix: str, config, weights, intermediate_size: int):
        super().__init__()
        self.hidden_act = config.hidden_act
        if self.hidden_act != "silu":
            # Bail out because MoE only supports silu.
            raise NotImplementedError(
                "Currently only `silu` is supported as an activation for Deepseek V2."
            )
        self.act = ACT2FN[self.hidden_act]

        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )

        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )

        self.intermediate_size = intermediate_size // weights.process_group.size()

        # TODO: This is a hotfix to be removed & properly refactored.
        self.quantize = config.quantize

    def forward(self, hidden_states: torch.Tensor, reduce: bool = True):
        if (
            SYSTEM == "rocm"
            and self.hidden_act == "silu"
            and hidden_states.dtype == torch.float16
            and hidden_states.shape[0] == 1
            and not self.quantize
        ):
            out = torch.empty(
                hidden_states.shape[0],
                self.intermediate_size,
                dtype=hidden_states.dtype,
                device="cuda",
            )
            ops.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
            return self.down_proj(out, reduce=reduce)
        else:
            gate_up_states = self.gate_up_proj(hidden_states)
            gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
            return self.down_proj(
                self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce
            )


class DeepseekV3MoE(nn.Module):
    def __init__(
        self,
        prefix,
        config: DeepseekV3Config,
        moe_layer_cls: Type[MoELayer],
        weights,
    ):
        super().__init__()

        self.hidden_dim = config.hidden_size
        self.moe_intermediate_size = (
            config.moe_intermediate_size // weights.process_group.size()
        )
        self.routed_scaling_factor = config.routed_scaling_factor

        # Gating
        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)

        if config.topk_method == "noaux_tc":
            self.gate.e_score_correction_bias = torch.zeros(
                config.n_routed_experts, device=weights.device
            )
        else:
            self.gate.e_score_correction_bias = None

        self.moe_layer = moe_layer_cls(
            prefix=f"{prefix}.experts",
            n_experts=config.n_routed_experts,
            n_expert_group=config.n_group,
            renormalize=config.norm_topk_prob,
            topk=config.num_experts_per_tok,
            topk_group=config.topk_group,
            weights=weights,
            scoring_func=config.scoring_func,
            e_score_correction_bias=self.gate.e_score_correction_bias,
        )
        assert isinstance(self.moe_layer, MoELayer)

        if config.n_shared_experts is not None:
            self.shared_experts = DeepseekV3MLP(
                prefix=f"{prefix}.shared_experts",
                config=config,
                weights=weights,
                intermediate_size=config.moe_intermediate_size
                * config.n_shared_experts,
            )
        else:
            self.shared_experts = None

        self.process_group = weights.process_group

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.shared_experts is not None:
            shared_output = self.shared_experts(x, reduce=False)
        else:
            shared_output = None

        router_logits = self.gate(x)

        out = self.moe_layer(x, gating_output=router_logits)

        if shared_output is not None:
            out = out + shared_output

        # Reduce sum
        if self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)

        return out.view(*x.shape)


class DeepseekV3Layer(nn.Module):
    def __init__(self, prefix, layer_id, config, weights):
        super().__init__()
        prefix = f"{prefix}.layers.{layer_id}"

        self.self_attn = DeepseekV3Attention(
            prefix=f"{prefix}.self_attn",
            config=config,
            weights=weights,
        )

        if (
            config.n_routed_experts is not None
            and layer_id >= config.first_k_dense_replace
            and layer_id % config.moe_layer_freq == 0
        ):
            moe_layer_cls = (
                SparseMoELayer
                if SparseMoELayer.is_supported(weights)
                else DenseMoELayer
            )
            self.mlp = DeepseekV3MoE(f"{prefix}.mlp", config, moe_layer_cls, weights)
        else:
            self.mlp = DeepseekV3MLP(
                prefix=f"{prefix}.mlp",
                config=config,
                weights=weights,
                intermediate_size=config.intermediate_size,
            )

        self.input_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        cu_seqlen_prefill: torch.Tensor,
        kv_cache,
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
    ):
        normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)

        # Self Attention
        attn_output = self.self_attn(
            normed_hidden_states,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            block_tables,
            slots,
            seqlen,
            max_s,
        )

        # faster post attention rms norm
        normed_attn_res_output, residual = self.post_attention_layernorm(
            attn_output, residual
        )

        output = self.mlp(normed_attn_res_output)

        return output, residual


class DeepseekV3Model(torch.nn.Module):
    def __init__(self, prefix: str, config, weights: Weights):
        super().__init__()

        self.embed_tokens = TensorParallelEmbedding(
            prefix=f"{prefix}.embed_tokens", weights=weights
        )

        self.layers = nn.ModuleList(
            [
                DeepseekV3Layer(
                    prefix,
                    layer_id,
                    config,
                    weights,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = FastRMSNorm.load(
            prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
        )

        self.head_size = self.layers[0].self_attn.head_size
        self.num_heads = self.layers[0].self_attn.num_heads
        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)

        # Get rotary cos and sin for this forward
        # Avoid to index in each layer
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
            position_ids, max_s, hidden_states.dtype
        )

        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[i],
                block_tables,
                slots,
                seqlen,
                max_s,
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states


class FlashDeepseekV3ForCausalLM(torch.nn.Module):
    def __init__(self, prefix: str, config, weights: Weights):
        super().__init__()

        self.model = DeepseekV3Model(
            "model" if not prefix else f"{prefix}.model", config, weights
        )
        self.lm_head = SpeculativeHead.load(
            config,
            prefix="lm_head" if not prefix else f"{prefix}.lm_head",
            weights=weights,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        hidden_states = self.model(
            input_ids,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            block_tables,
            slots,
            seqlen,
            max_s,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.lm_head(hidden_states)
        return logits, speculative_logits
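One detail worth noting in `DeepseekV3Attention.forward` above is the `view(..., head_dim // 2, 2).transpose(2, 3).reshape(...)` step applied to `query_pe` and `key_pe`: it reorders the rotary dimensions from interleaved pairs into a half-split layout before `rotary_emb` is applied, which appears to be the layout the rotary kernel expects. A standalone check of that reshape on a toy tensor:

import torch

pe = torch.arange(8).view(1, 1, 8)                          # interleaved: x0 y0 x1 y1 ...
out = pe.view(1, 1, 4, 2).transpose(2, 3).reshape(1, 1, 8)
print(out)                                                  # tensor([[[0, 2, 4, 6, 1, 3, 5, 7]]]) -> all x's, then all y's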
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -642,9 +642,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         embedding_multiplier = getattr(config, "embedding_multiplier", None)
         if embedding_multiplier is not None:
             self.embed_tokens.weight.data *= embedding_multiplier
-        prefix = "lm_head" if not prefix or name != "model" else f"{prefix}.{suffix}"
+        prefix = suffix if not prefix or name != "model" else f"{prefix}.{suffix}"

         with no_fp8(weights):
             self.lm_head = SpeculativeHead.load(
                 config,
server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
@@ -32,6 +32,8 @@ from text_generation_server.layers.attention import (
     Seqlen,
 )
 from text_generation_server.layers import (
+    TensorParallelMultiAdapterLinear,
+    TensorParallelAdapterRowLinear,
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
@@ -109,17 +111,31 @@ class Starcoder2Config(PretrainedConfig):
         )


-def load_attention(config, prefix, weights):
+def load_attention(config, prefix, weights, layer_id):
+    prefixes = [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
+    head_size = config.hidden_size // config.num_attention_heads
+    sizes = [
+        head_size * config.num_attention_heads,
+        head_size * config.num_key_value_heads,
+        head_size * config.num_key_value_heads,
+    ]
     if config.num_attention_heads != config.num_key_value_heads:
-        return _load_gqa(config, prefix, weights)
+        base_layer = _load_gqa(config, prefix, weights)
     else:
-        return TensorParallelColumnLinear.load_multi(
+        base_layer = TensorParallelColumnLinear.load_multi(
             config,
-            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+            prefixes=prefixes,
             dim=0,
             weights=weights,
             bias=config.use_bias,
         )
+    return TensorParallelMultiAdapterLinear.load(
+        base_layer=base_layer,
+        layer_id=layer_id,
+        layer_names=prefixes,
+        sizes=sizes,
+        process_group=weights.process_group,
+    )


 def _load_gqa(config, prefix: str, weights):
@@ -157,6 +173,7 @@ def _load_gqa(config, prefix: str, weights):
 class Starcoder2Attention(torch.nn.Module):
     def __init__(
         self,
+        index: int,
         prefix: str,
         config,
         weights,
@@ -188,15 +205,23 @@ class Starcoder2Attention(torch.nn.Module):
             config.num_key_value_heads // weights.process_group.size()
         )

-        self.query_key_value = load_attention(config, prefix, weights)
+        self.query_key_value = load_attention(config, prefix, weights, index)
         self.kv_scales = get_kv_scales(weights, f"{prefix}")

-        self.o_proj = TensorParallelRowLinear.load(
+        o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
             weights=weights,
-            bias=config.use_bias,
+            bias=getattr(config, "use_bias", False),
         )
+
+        self.o_proj = TensorParallelAdapterRowLinear.load(
+            o_proj,
+            index,
+            "o_proj",
+            process_group=weights.process_group,
+        )
+
         self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@@ -214,8 +239,9 @@ class Starcoder2Attention(torch.nn.Module):
         seqlen,
         max_s,
         prefill_cache_indices,
+        adapter_data,
     ):
-        qkv = self.query_key_value(hidden_states)
+        qkv = self.query_key_value(hidden_states, adapter_data)
         query, kv = qkv.split(
             [
                 self.head_size * self.num_heads,
@@ -267,11 +293,13 @@ class Starcoder2Attention(torch.nn.Module):
             kv_scales=self.kv_scales,
         )

-        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+        return self.o_proj(
+            attn_output.view(-1, self.num_heads * self.head_size), adapter_data
+        )


 class Starcoder2MLP(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, weights, index):
         super().__init__()
         act = config.hidden_act
         self.act = (
@@ -285,27 +313,42 @@ class Starcoder2MLP(nn.Module):
             )
         )
         # Fuse gate and up proj
-        self.c_fc = TensorParallelColumnLinear.load(
+        c_fc = TensorParallelColumnLinear.load(
             config,
             prefix=f"{prefix}.c_fc",
             weights=weights,
             bias=config.use_bias,
         )
-        self.c_proj = TensorParallelRowLinear.load(
+        c_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.c_proj",
             weights=weights,
             bias=config.use_bias,
         )

-    def forward(self, hidden_states):
-        hidden_states = self.c_fc(hidden_states)
+        self.c_fc = TensorParallelMultiAdapterLinear.load(
+            c_fc,
+            layer_id=index,
+            layer_names=[f"{prefix}.c_fc"],
+            sizes=[config.intermediate_size, config.intermediate_size],
+            process_group=weights.process_group,
+        )
+
+        self.c_proj = TensorParallelAdapterRowLinear.load(
+            c_proj,
+            index,
+            "c_proj",
+            process_group=weights.process_group,
+        )
+
+    def forward(self, hidden_states, adapter_data):
+        hidden_states = self.c_fc(hidden_states, adapter_data)
         hidden_states = self.act(hidden_states)
-        return self.c_proj(hidden_states)
+        return self.c_proj(hidden_states, adapter_data)


 class Starcoder2GatedMLP(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, index, prefix, config, weights):
         super().__init__()
         act = config.hidden_act
         self.act = (
@@ -319,27 +362,47 @@ class Starcoder2GatedMLP(nn.Module):
             )
         )
         # Fuse gate and up proj
-        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+        prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"]
+        sizes = [
+            config.intermediate_size,
+            config.intermediate_size,
+        ]
+        gate_up_proj = TensorParallelColumnLinear.load_multi(
             config,
-            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+            prefixes=prefixes,
             weights=weights,
             dim=0,
bias=config.use_bias,
|
bias=config.use_bias,
|
||||||
)
|
)
|
||||||
self.down_proj = TensorParallelRowLinear.load(
|
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
|
||||||
|
gate_up_proj,
|
||||||
|
index,
|
||||||
|
layer_names=prefixes,
|
||||||
|
sizes=sizes,
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
down_proj = TensorParallelRowLinear.load(
|
||||||
config,
|
config,
|
||||||
prefix=f"{prefix}.down_proj",
|
prefix=f"{prefix}.down_proj",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
bias=config.use_bias,
|
bias=config.use_bias,
|
||||||
)
|
)
|
||||||
|
self.down_proj = TensorParallelAdapterRowLinear.load(
|
||||||
|
down_proj,
|
||||||
|
index,
|
||||||
|
"down_proj",
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
self.intermediate_size = (
|
self.intermediate_size = (
|
||||||
config.intermediate_size // weights.process_group.size()
|
config.intermediate_size // weights.process_group.size()
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states, adapter_data):
|
||||||
gate_up_states = self.gate_up_proj(hidden_states)
|
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
|
||||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||||
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
|
return self.down_proj(
|
||||||
|
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
STARCODER2_NORMALIZATION_CLASSES = {
|
STARCODER2_NORMALIZATION_CLASSES = {
|
||||||
@ -358,11 +421,11 @@ class Starcoder2Layer(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
prefix = f"model.layers.{layer_id}"
|
prefix = f"model.layers.{layer_id}"
|
||||||
self.self_attn = Starcoder2Attention(
|
self.self_attn = Starcoder2Attention(
|
||||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
prefix=f"{prefix}.self_attn", config=config, weights=weights, index=layer_id
|
||||||
)
|
)
|
||||||
|
|
||||||
self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](
|
self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](
|
||||||
prefix=f"{prefix}.mlp", config=config, weights=weights
|
prefix=f"{prefix}.mlp", config=config, weights=weights, index=layer_id
|
||||||
)
|
)
|
||||||
|
|
||||||
self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
|
self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
|
||||||
@ -389,6 +452,7 @@ class Starcoder2Layer(nn.Module):
|
|||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
max_s,
|
||||||
prefill_cache_indices,
|
prefill_cache_indices,
|
||||||
|
adapter_data,
|
||||||
):
|
):
|
||||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
@ -404,6 +468,7 @@ class Starcoder2Layer(nn.Module):
|
|||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
max_s,
|
||||||
prefill_cache_indices,
|
prefill_cache_indices,
|
||||||
|
adapter_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
# faster post attention rms norm
|
# faster post attention rms norm
|
||||||
@ -411,7 +476,7 @@ class Starcoder2Layer(nn.Module):
|
|||||||
attn_output, res
|
attn_output, res
|
||||||
)
|
)
|
||||||
|
|
||||||
mlp_output = self.mlp(normed_attn_res_output)
|
mlp_output = self.mlp(normed_attn_res_output, adapter_data)
|
||||||
|
|
||||||
return mlp_output, attn_res
|
return mlp_output, attn_res
|
||||||
|
|
||||||
@ -458,6 +523,7 @@ class Starcoder2Model(torch.nn.Module):
|
|||||||
max_s: int,
|
max_s: int,
|
||||||
true_max_s: int,
|
true_max_s: int,
|
||||||
prefill_cache_indices: Optional[torch.Tensor],
|
prefill_cache_indices: Optional[torch.Tensor],
|
||||||
|
adapter_data,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
@ -481,6 +547,7 @@ class Starcoder2Model(torch.nn.Module):
|
|||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
max_s,
|
||||||
prefill_cache_indices,
|
prefill_cache_indices,
|
||||||
|
adapter_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
@ -552,6 +619,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
|
|||||||
max_s,
|
max_s,
|
||||||
true_max_s,
|
true_max_s,
|
||||||
prefill_cache_indices,
|
prefill_cache_indices,
|
||||||
|
adapter_data,
|
||||||
)
|
)
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
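Note on the hunks above: every Starcoder2 projection is now wrapped in an adapter-aware layer so the forward pass can apply per-request LoRA deltas via `adapter_data`. The snippet below is a minimal, self-contained sketch of that wrapping pattern; `AdapterAwareLinear` is a hypothetical stand-in, not the actual `TensorParallelMultiAdapterLinear` implementation.

import torch
import torch.nn as nn


class AdapterAwareLinear(nn.Module):
    """Hypothetical stand-in: a base projection plus an optional low-rank (LoRA) delta."""

    def __init__(self, base: nn.Linear, rank: int = 8):
        super().__init__()
        self.base = base
        self.lora_a = nn.Linear(base.in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, base.out_features, bias=False)

    def forward(self, x: torch.Tensor, adapter_data=None) -> torch.Tensor:
        out = self.base(x)
        if adapter_data is not None:
            # Apply the low-rank update only when the batch carries adapter state.
            out = out + self.lora_b(self.lora_a(x))
        return out


# Usage mirrors the diff: every projection now takes (hidden_states, adapter_data).
layer = AdapterAwareLinear(nn.Linear(64, 64))
hidden = torch.randn(4, 64)
print(layer(hidden, adapter_data=None).shape)  # torch.Size([4, 64])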
@@ -1595,7 +1595,9 @@ class FlashCausalLM(Model):
         if max_total_tokens is None:
             if get_support_chunking():
                 model_max_length = self.tokenizer.model_max_length
-                max_position_embeddings = self.config.max_position_embeddings
+                max_position_embeddings = getattr(
+                    self.config, "max_position_embeddings", model_max_length
+                )
                 max_total_tokens = min(
                     num_blocks * BLOCK_SIZE, model_max_length, max_position_embeddings
                 )
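The `getattr` fallback above guards against model configs that do not define `max_position_embeddings`. A minimal illustration, using hypothetical config objects:

from types import SimpleNamespace

# Hypothetical configs: one with max_position_embeddings, one without.
cfg_with = SimpleNamespace(max_position_embeddings=4096)
cfg_without = SimpleNamespace()

model_max_length = 8192
print(getattr(cfg_with, "max_position_embeddings", model_max_length))     # 4096
print(getattr(cfg_without, "max_position_embeddings", model_max_length))  # 8192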
@@ -14,26 +14,33 @@ PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {
 }
 PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "1").lower() in {"1", "true"}
 log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
-_expected = {"paged", "flashdecoding", "flashinfer"}
+_expected = {"paged", "flashdecoding", "flashdecoding-ipex", "flashinfer"}
 assert (
     ATTENTION in _expected
 ), f"Attention is not valid {ATTENTION}, expected {_expected}"
 log_master(logger.info, f"Using Attention = {ATTENTION}")

-if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
+if PREFIX_CACHING and ATTENTION not in {
+    "flashinfer",
+    "flashdecoding",
+    "flashdecoding-ipex",
+}:
     raise RuntimeError("Prefix caching is only supported with flashinfer")

 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
-TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.95"))
+TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.93"))
 assert TGI_WIGGLE_ROOM > 0
 assert TGI_WIGGLE_ROOM < 1


 # This is overridden by the cli
 BLOCK_SIZE: int
 if ATTENTION == "flashdecoding":
     BLOCK_SIZE = 256
 elif ATTENTION == "flashinfer":
     BLOCK_SIZE = 1
+elif ATTENTION == "flashdecoding-ipex":
+    BLOCK_SIZE = 64
 else:
     BLOCK_SIZE = 16

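For context, the block size chosen above determines how many tokens each KV-cache block holds, so the same token budget maps to a different block count per attention backend. Illustrative sketch only; the helper below is hypothetical:

# Block sizes mirror the hunk above; "paged" stands in for the default branch.
BLOCK_SIZES = {"flashdecoding": 256, "flashinfer": 1, "flashdecoding-ipex": 64, "paged": 16}


def blocks_needed(total_tokens: int, attention: str) -> int:
    block_size = BLOCK_SIZES.get(attention, 16)
    return -(-total_tokens // block_size)  # ceiling division


for backend in BLOCK_SIZES:
    print(backend, blocks_needed(4096, backend))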
@@ -79,10 +79,13 @@ class Model(ABC):
                 "Prefill chunking will be turned off",
             )
             support_chunking = False
-        if ATTENTION not in ["flashinfer", "flashdecoding"] and support_chunking:
+        if (
+            ATTENTION not in ["flashinfer", "flashdecoding", "flashdecoding-ipex"]
+            and support_chunking
+        ):
             log_master(
                 logger.warning,
-                "Prefill chunking is only supported with `flashinfer` or `flashdecoding` attention types.",
+                "Prefill chunking is only supported with `flashinfer` or `flashdecoding` or `flashdecoding-ipex` attention types.",
             )
             support_chunking = False

@@ -0,0 +1,270 @@
+import math
+from typing import List, Optional
+
+import torch
+from opentelemetry import trace
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import transformers.modeling_utils
+
+from text_generation_server.models.flash_causal_lm import FlashCausalLM
+from text_generation_server.utils import initialize_torch_distributed
+
+from text_generation_server.layers.attention import paged_attention, attention, Seqlen
+from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
+from text_generation_server.models.globals import ATTENTION
+
+
+tracer = trace.get_tracer(__name__)
+
+
+def tgi_flash_attention_forward(
+    module,
+    query_states: torch.Tensor,
+    key_states: torch.Tensor,
+    value_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],  # This is a positional arg in Transformers
+    kv_cache: List[KVCache],
+    kv_head_mapping: torch.Tensor,
+    slots: torch.Tensor,
+    cu_seqlen_prefill: Optional[torch.Tensor],
+    seqlen: Seqlen,
+    block_tables: torch.Tensor,
+    max_s: int,
+    kv_scales: KVScales,
+    softmax_scale: Optional[float] = None,
+    sliding_window: Optional[int] = None,
+    softcap: Optional[float] = None,
+    **kwargs,  # This is needed to "absorb" other args passed by Transformers modeling
+):
+    kv_cache = kv_cache[module.layer_idx]
+    query_states = query_states.transpose(1, 2).squeeze(dim=0)
+    key_states = key_states.transpose(1, 2).squeeze(dim=0)
+    value_states = value_states.transpose(1, 2).squeeze(dim=0)
+
+    # Take care of updating the cache in-place
+    kv_cache.store(key=key_states, value=value_states, slots=slots, kv_scales=kv_scales)
+
+    _, num_heads, head_dim = query_states.shape
+    softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
+    sliding_window = -1 if sliding_window is None else sliding_window
+
+    if cu_seqlen_prefill is not None:
+        attn_output = attention(
+            query=query_states,
+            key=key_states,
+            value=value_states,
+            kv_cache=kv_cache,
+            kv_scales=kv_scales,
+            seqlen=seqlen,
+            block_tables=block_tables,
+            softmax_scale=softmax_scale,
+            window_size_left=sliding_window,
+            softcap=softcap,
+        )
+    else:
+        attn_output = paged_attention(
+            query_states,
+            kv_cache,
+            kv_head_mapping,
+            softmax_scale,
+            block_tables,
+            seqlen,
+            max_s,
+            kv_scales=kv_scales,
+            softcap=softcap,
+        )
+
+    attn_output = attn_output.view(-1, num_heads * head_dim)
+
+    return attn_output, None
+
+
+transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward
+
+
+class TransformersFlashCausalLM(FlashCausalLM):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        speculator: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        default_dtype=torch.float16,
+        trust_remote_code: bool = False,
+        tokenizer_class=AutoTokenizer,
+        kv_cache_dtype: Optional[torch.dtype] = None,
+    ):
+        self.quantize = quantize
+        self.process_group, rank, world_size = initialize_torch_distributed()
+
+        if speculator:
+            raise RuntimeError("Speculator decoding is not enabled for AutoModel")
+
+        if torch.cuda.is_available():
+            device = torch.device(f"cuda:{rank}")
+            dtype = default_dtype if dtype is None else dtype
+        elif hasattr(torch, "xpu") and torch.xpu.is_available():
+            device = torch.device("xpu")
+            dtype = default_dtype if dtype is None else dtype
+        else:
+            raise ValueError(
+                "Flash `Transformers` modeling backend is not available on cpu."
+            )
+
+        tokenizer = tokenizer_class.from_pretrained(
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            revision=revision,
+            torch_dtype=dtype,
+            load_in_8bit=quantize == "bitsandbytes",
+            trust_remote_code=trust_remote_code,
+            attn_implementation="tgi",
+            device_map=device if world_size == 1 else None,
+            tp_plan="auto" if world_size > 1 else None,
+        )
+
+        if tokenizer.pad_token_id is None:
+            if model.config.pad_token_id is not None:
+                tokenizer.pad_token_id = model.config.pad_token_id
+            elif model.config.eos_token_id is not None and isinstance(
+                model.config.eos_token_id, int
+            ):
+                tokenizer.pad_token_id = model.config.eos_token_id
+            elif tokenizer.eos_token_id is not None:
+                tokenizer.pad_token_id = tokenizer.eos_token_id
+            else:
+                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+        self.num_layers = model.config.num_hidden_layers
+        self.num_heads = model.config.num_attention_heads // self.process_group.size()
+        self.num_kv_heads = model.config.num_key_value_heads
+        self.num_kv_heads = (
+            self.num_kv_heads // self.process_group.size()
+            if self.num_kv_heads > 1
+            else self.num_kv_heads
+        )
+        self.head_size = model.config.hidden_size // model.config.num_attention_heads
+
+        self.cuda_graphs = {}
+        self.kv_cache = []
+        self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
+
+        if ATTENTION == "flashinfer":
+            from text_generation_server.layers.attention.flashinfer import (
+                create_prefill_state,
+                create_decode_state,
+                create_prefill_with_paged_kv_state,
+            )
+
+            self.prefill_state = create_prefill_state(device=device)
+            self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state(
+                device=device
+            )
+
+            self.decode_state = create_decode_state(
+                device=device,
+                num_heads=self.num_heads,
+                num_kv_heads=self.num_kv_heads,
+            )
+
+        self.num_groups = self.num_heads // self.num_kv_heads
+
+        # Those will never change and will be used in the forwards
+        self.kv_head_mapping = torch.arange(
+            0, self.num_kv_heads, dtype=torch.int32, device=device
+        ).repeat_interleave(self.num_groups)
+        # This means no scale
+        self.kv_scales = KVScales(
+            torch.tensor(1.0, device=device),
+            torch.tensor(1.0, device=device),
+        )
+
+        torch.distributed.barrier(group=self.process_group)
+        # Skip FlashCausalLM init.
+        super(FlashCausalLM, self).__init__(
+            model_id=model_id,
+            model=model,
+            tokenizer=tokenizer,
+            requires_padding=False,
+            dtype=dtype,
+            device=device,
+            rank=rank,
+            world_size=world_size,
+        )
+
+        # Monkey patch of `self.model.forward` to match `FlashCausalLM`. It avoids duplicating a lot of code
+        # We first copy the original model.forward because we still need it in the monkey patch
+        self.model.original_forward = self.model.forward
+        self.model.forward = self._model_forward
+
+    @classmethod
+    def fallback(
+        cls,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        speculator: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
+    ):
+        return cls(
+            model_id=model_id,
+            revision=revision,
+            quantize=quantize,
+            speculator=speculator,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
+        )
+
+    def _model_forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        kv_cache: List[KVCache],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        seqlen: Seqlen,
+        max_s: int,
+        lm_head_indices: Optional[torch.Tensor],
+        prefill_cache_indices=None,  # not used, but passed to match original signature
+        adapter_data=None,  # not supported, but passed to match original signature
+    ):
+        hidden_states = self.model.model.forward(
+            input_ids=input_ids.unsqueeze(0),  # expand dim to fit Transformers
+            position_ids=position_ids.unsqueeze(0),  # expand dim to fit Transformers
+            past_key_values=None,  # we use self.kv_cache instead of transformers cache object
+            use_cache=False,  # we use self.kv_cache instead of transformers cache object
+            return_dict=True,
+            cu_seqlen_prefill=cu_seqlen_prefill,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            seqlen=seqlen,
+            max_s=max_s,
+            kv_head_mapping=self.kv_head_mapping,
+            kv_scales=self.kv_scales,
+        )[0].squeeze(dim=0)
+
+        # And compute logits from the lm_head, slicing correctly the indices
+        # NOTE: some logits post-processing (e.g. in gemma2) may be absent here with the split of the modules
+        # To update with full Transformers support asap
+        if lm_head_indices is not None:
+            hidden_states = hidden_states[lm_head_indices]
+        logits = self.model.lm_head(hidden_states)
+
+        # For Granite while next transformers version is released and we can use `lm_head_indices` natively
+        if hasattr(self.model.config, "logits_scaling"):
+            logits = logits / self.model.config.logits_scaling
+        # For Cohere for similar reasons
+        elif hasattr(self.model, "logit_scale"):
+            logits = logits * self.model.logit_scale
+
+        return logits, None
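The new backend above registers a custom attention function under the name "tgi" and selects it through `attn_implementation`. A minimal sketch of that registration pattern, assuming a transformers release that exposes `ALL_ATTENTION_FUNCTIONS` as the file itself does; the `noop_attention` function and the "my-backend" name are hypothetical:

import torch
import transformers.modeling_utils


def noop_attention(module, query, key, value, attention_mask, **kwargs):
    # Toy stand-in: plain scaled dot-product attention, no mask or KV-cache handling.
    out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
    return out, None


transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["my-backend"] = noop_attention
# A model loaded with attn_implementation="my-backend" would then route its
# attention calls through noop_attention, the same way the "tgi" entry is used above.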
@@ -68,9 +68,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         self.quantize = model.quantize
         self.server_urls = server_urls
         # For some reason, inference_mode does not work well with GLOO which we use on CPU
-        if model.device.type == "cuda":
-            # Force inference mode for the lifetime of TextGenerationService
-            self._inference_mode_raii_guard = torch._C._InferenceMode(True)
+        # if model.device.type == "cuda":
+        #     # Force inference mode for the lifetime of TextGenerationService
+        #     self._inference_mode_raii_guard = torch._C._InferenceMode(True)

     async def Info(self, request, context):
         return self.model.info
@@ -281,6 +281,12 @@ def get_mlp_weights(i, layer):
     if hasattr(mlp, "up_proj"):
         weights[(i, "up_proj")] = (f"model.layers.{i}.mlp.up_proj", mlp.up_proj)

+    if hasattr(mlp, "c_fc"):
+        weights[(i, "c_fc")] = (f"model.layers.{i}.mlp.c_fc", mlp.c_fc)
+
+    if hasattr(mlp, "c_proj"):
+        weights[(i, "c_proj")] = (f"model.layers.{i}.mlp.c_proj", mlp.c_proj)
+
     if hasattr(mlp, "down_proj"):
         weights[(i, "down_proj")] = (
             f"model.layers.{i}.mlp.down_proj",
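The new `c_fc`/`c_proj` branches above extend the LoRA target mapping to GPT-style MLPs. A small sketch of the structure `get_mlp_weights` builds, using a hypothetical `ToyMLP` placeholder:

import torch.nn as nn


class ToyMLP(nn.Module):
    # Placeholder GPT-style MLP exposing c_fc / c_proj projections.
    def __init__(self):
        super().__init__()
        self.c_fc = nn.Linear(8, 32)
        self.c_proj = nn.Linear(32, 8)


i, mlp = 0, ToyMLP()
weights = {}
if hasattr(mlp, "c_fc"):
    weights[(i, "c_fc")] = (f"model.layers.{i}.mlp.c_fc", mlp.c_fc)
if hasattr(mlp, "c_proj"):
    weights[(i, "c_proj")] = (f"model.layers.{i}.mlp.c_proj", mlp.c_proj)
print(sorted(weights))  # [(0, 'c_fc'), (0, 'c_proj')]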
@@ -81,12 +81,14 @@ def initialize_torch_distributed():
                 pg_options=options,
             )
         else:
+            device = torch.device(f"cuda:{RANK}")
             torch.distributed.init_process_group(
                 backend=backend,
                 world_size=WORLD_SIZE,
                 rank=RANK,
                 timeout=timedelta(seconds=120),
                 pg_options=options,
+                device_id=device,
             )
     else:
         logger.warning("torch.distributed is already initialized.")
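Passing `device_id` above binds the process group to the local GPU at initialization time. A hedged sketch of the same pattern outside TGI; the environment variables and backend choice are assumptions, and `device_id` requires a recent PyTorch release:

import os
from datetime import timedelta

import torch
import torch.distributed as dist

# Assumes the usual torchrun environment (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT).
rank = int(os.environ.get("RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))

if torch.cuda.is_available() and not dist.is_initialized():
    device = torch.device(f"cuda:{rank}")
    dist.init_process_group(
        backend="nccl",
        rank=rank,
        world_size=world_size,
        timeout=timedelta(seconds=120),
        device_id=device,  # bind the group to this GPU instead of lazy binding later
    )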
@@ -5,13 +5,12 @@ import torch
 from typing import List, Optional, DefaultDict

 from loguru import logger
-from typing import Dict, Union
+from typing import Dict
 from text_generation_server.pb.generate_pb2 import GrammarType

 from outlines.fsm.guide import RegexGuide

 from transformers import (
-    LogitsWarper,
     LogitsProcessor,
     PreTrainedTokenizerBase,
     TemperatureLogitsWarper,
@@ -219,7 +218,7 @@ class HeterogeneousTemperatureLogitsWarper:
         return None


-class HeterogeneousTopPLogitsWarper(LogitsWarper):
+class HeterogeneousTopPLogitsWarper(LogitsProcessor):
     """
     [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
     This version allows for a separate value for each sample and runs inplace when possible.
@@ -278,7 +277,7 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
         return None


-class HeterogeneousTopKLogitsWarper(LogitsWarper):
+class HeterogeneousTopKLogitsWarper(LogitsProcessor):
     r"""
     [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
     This version allows for a separate value for each sample and runs inplace when possible.
@@ -359,7 +358,7 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
         return None


-class HeterogeneousTypicalLogitsWarper(LogitsWarper):
+class HeterogeneousTypicalLogitsWarper(LogitsProcessor):
     r"""
     [`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language
     Generation](https://arxiv.org/abs/2202.00666) for more information.
@@ -453,13 +452,13 @@ class HeterogeneousProcessorWrapper(LogitsProcessor):
     r"""
     A wrapper for logit warpers or processors without heterogeneous parameter support.
     Args:
-        processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`):
+        processors (`Dict[int, LogitsProcessor]`):
             A mapping of sample indices to logit warpers or processors, to be run sequentially.
     """

     def __init__(
         self,
-        processors: Dict[int, Union[LogitsProcessor, LogitsWarper]],
+        processors: Dict[int, LogitsProcessor],
     ):
         self.processors = processors

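The hunks above drop the `LogitsWarper` import and make the heterogeneous warpers subclass `LogitsProcessor` instead, which shares the same `__call__(input_ids, scores)` contract. A minimal illustrative processor (hypothetical, not part of the diff):

import torch
from transformers import LogitsProcessor


class TemperatureProcessor(LogitsProcessor):
    """Toy processor with the same call contract the warpers above rely on."""

    def __init__(self, temperature: float):
        self.temperature = temperature

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        # Scale the scores and return them, exactly like a warper would.
        return scores / self.temperature


proc = TemperatureProcessor(0.7)
print(proc(torch.tensor([[1, 2]]), torch.randn(1, 5)).shape)  # torch.Size([1, 5])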
@@ -1,7 +1,7 @@
 import json
 import os
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, List

 from huggingface_hub import hf_hub_download
 from text_generation_server.layers.marlin.gptq import can_use_gptq_marlin
@@ -20,6 +20,7 @@ class _QuantizerConfig:
     groupsize: int
     quant_method: str
     sym: bool
+    weight_block_size: Optional[List[int]]


 @dataclass
@@ -49,16 +50,17 @@ def _get_quantizer_config(model_id, revision):
     checkpoint_format = None
     sym = False
     desc_act = False
+    weight_block_size = None

     filename = "config.json"
     try:
         data = _get_config_json(model_id, revision, filename)

         # FP8 config
         if data["quantization_config"]["quant_method"] == "fbgemm_fp8":
             return _FP8QuantizerConfig(
                 activation_scale_ub=data["quantization_config"]["activation_scale_ub"]
             )
+        weight_block_size = data["quantization_config"].get("weight_block_size", None)

         if "zero_point" in data["quantization_config"]:
             sym = not data["quantization_config"]["zero_point"]
@@ -107,6 +109,7 @@ def _get_quantizer_config(model_id, revision):
         checkpoint_format=checkpoint_format,
         sym=sym,
         desc_act=desc_act,
+        weight_block_size=weight_block_size,
     )


@@ -196,9 +199,14 @@ def get_loader(
         # Since the default for the quantize config is _QuantizerConfig,
         # we need to add this check to not get an attribute error
         activation_scale_ub = None
+        weight_block_size = quantizer_config.weight_block_size
         if isinstance(quantizer_config, _FP8QuantizerConfig):
             activation_scale_ub = quantizer_config.activation_scale_ub

-        return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8")
+        return HybridFP8UnquantLoader(
+            activation_scale_ub,
+            to_fp8=quantize == "fp8",
+            weight_block_size=weight_block_size,
+        )
     else:
         raise ValueError(f"Unknown quantization method: {quantize}")
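`weight_block_size` above is read from the checkpoint's `quantization_config` and forwarded to the FP8 loader for block-quantized weights. A sketch of where the value comes from; the config dict and the [128, 128] block shape are hypothetical examples:

# Hypothetical excerpt of a block-quantized checkpoint's config.json.
config = {
    "quantization_config": {
        "quant_method": "fp8",
        "weight_block_size": [128, 128],
    }
}

# Same lookup as the hunk above: absent keys fall back to None.
weight_block_size = config["quantization_config"].get("weight_block_size", None)
print(weight_block_size)  # [128, 128]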
server/uv.lock (new file, 3268 lines): diff suppressed because it is too large.