Compare commits

..

1 Commits
main ... v3.1.1

Author SHA1 Message Date
Nicolas Patry
c34bd9d8d9
3.1.1 Release. 2025-03-04 18:11:30 +01:00
269 changed files with 28699 additions and 27266 deletions

View File

@ -45,7 +45,7 @@ jobs:
uses: actions/github-script@v7
with:
script: |
core.exportVariable('ACTIONS_RESULTS_URL', process.env.ACTIONS_RESULTS_URL || '');
core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || '');
core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || '');
- name: Extract TensorRT-LLM version
@ -124,15 +124,6 @@ jobs:
export extra_pytest="--neuron"
export target=""
;;
gaudi)
export dockerfile="Dockerfile_gaudi"
export label_extension="-gaudi"
export docker_volume="/mnt/cache"
export docker_devices=""
export runs_on="ubuntu-latest"
export platform=""
export extra_pytest=""
export target=""
esac
echo $dockerfile
echo "Dockerfile=${dockerfile}"
@ -223,7 +214,7 @@ jobs:
PLATFORM=${{ env.PLATFORM }}
build_type=${{ env.BUILD_TYPE }}
sccache_gha_enabled=on
actions_results_url=${{ env.ACTIONS_RESULTS_URL }}
actions_cache_url=${{ env.ACTIONS_CACHE_URL }}
actions_runtime_token=${{ env.ACTIONS_RUNTIME_TOKEN }}
target: ${{ env.TARGET }}
tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
@ -233,12 +224,7 @@ jobs:
- name: Final
id: final
run: |
if [ "${{ github.event_name }}" = "pull_request" ]; then
echo "docker_image=docker.io/huggingface/text-generation-inference-ci:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL_EXTENSION }}" >> "$GITHUB_OUTPUT"
else
echo "docker_image=ghcr.io/huggingface/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL_EXTENSION }}" >> "$GITHUB_OUTPUT"
fi
echo "docker_image=docker.io/huggingface/text-generation-inference-ci:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL_EXTENSION }}" >> "$GITHUB_OUTPUT"
echo "docker_devices=${{ env.DOCKER_DEVICES }}" >> "$GITHUB_OUTPUT"
echo "docker_volume=${{ env.DOCKER_VOLUME }}" >> "$GITHUB_OUTPUT"
echo "runs_on=${{ env.RUNS_ON }}" >> "$GITHUB_OUTPUT"

View File

@ -20,8 +20,6 @@ on:
- "Dockerfile"
- "Dockerfile_amd"
- "Dockerfile_intel"
- "Dockerfile.neuron"
- "Dockerfile_gaudi"
branches:
- "main"
workflow_dispatch:
@ -39,7 +37,7 @@ jobs:
# fail-fast is true by default
fail-fast: false
matrix:
hardware: ["cuda", "cuda-trtllm", "rocm", "intel-xpu", "intel-cpu", "neuron", "gaudi"]
hardware: ["cuda", "cuda-trtllm", "rocm", "intel-xpu", "intel-cpu", "neuron"]
uses: ./.github/workflows/build.yaml # calls the one above ^
permissions:
contents: write

View File

@ -1,53 +0,0 @@
name: "Nix Build Docker image"
on:
pull_request:
push:
branches:
- 'main'
tags:
- 'v*'
concurrency:
group: nix-image-${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
build_nix_image:
runs-on:
group: aws-highmemory-32-plus-priv
steps:
- uses: actions/checkout@v4
- uses: cachix/install-nix-action@v27
with:
nix_path: nixpkgs=channel:nixos-unstable
- uses: cachix/cachix-action@v14
with:
name: text-generation-inference
# If you chose signing key for write access
authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}'
env:
USER: github_runner
- name: Build
run: nix build .#dockerImage
- name: Initialize Docker Buildx
uses: docker/setup-buildx-action@v3
with:
install: true
buildkitd-config: /tmp/buildkitd.toml
- name: Inject slug/short variables
uses: rlespinasse/github-slug-action@v4.4.1
- name: Login to internal Container Registry
# if: github.event_name != 'pull_request'
uses: docker/login-action@v3
with:
username: ${{ secrets.REGISTRY_USERNAME }}
password: ${{ secrets.REGISTRY_PASSWORD }}
registry: registry.internal.huggingface.tech
- name: Push to docker
run: |
if [ "${{ github.event_name }}" = "pull_request" ]; then
export TAG=nix-sha-${{ env.GITHUB_SHA_SHORT }}
else
export TAG=${{ github.ref_name }}-nix
fi
export IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:$TAG
nix-shell -p skopeo --command "skopeo --insecure-policy copy docker-archive:$(readlink -f ./result) docker://$IMAGE --dest-compress-format zstd"

View File

@ -7,7 +7,6 @@ on:
- "proto/**"
- "router/**"
- "launcher/**"
- "backends/**"
- "Cargo.lock"
- "rust-toolchain.toml"
concurrency:

View File

@ -8,7 +8,6 @@ on:
- "proto/**"
- "router/**"
- "launcher/**"
- "backends/**"
- "Cargo.lock"
- "rust-toolchain.toml"
@ -47,7 +46,7 @@ jobs:
- name: Download locked kernels
run: |
source ./.venv/bin/activate
kernels download server
hf-kernels download server
- name: Run server tests
run: |
source ./.venv/bin/activate

1
.gitignore vendored
View File

@ -28,4 +28,3 @@ server/fbgemmm
hl-smi_log*.txt
.graph_dumps
out
hqt_output

774
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -21,7 +21,7 @@ default-members = [
resolver = "2"
[workspace.package]
version = "3.2.3-dev0"
version = "3.1.1"
edition = "2021"
authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference"
@ -29,7 +29,7 @@ homepage = "https://github.com/huggingface/text-generation-inference"
[workspace.dependencies]
base64 = "0.22.0"
tokenizers = { version = "0.20.0", features = ["http"] }
hf-hub = { version = "0.4.2", features = ["tokio"] }
hf-hub = { version = "0.4.1", features = ["tokio"] }
metrics = { version = "0.23.0" }
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
minijinja = { version = "2.2.0", features = ["json"] }

View File

@ -1,5 +1,5 @@
# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
FROM lukemathwalker/cargo-chef:latest-rust-1.85.0 AS chef
WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@ -65,7 +65,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
COPY --from=ghcr.io/astral-sh/uv:0.5.31 /uv /uvx /bin/
ENV PATH="$PATH:/root/.local/bin"
RUN uv python install ${PYTHON_VERSION}
RUN uv venv --python ${PYTHON_VERSION} && uv pip install torch==${PYTORCH_VERSION} torchvision pip setuptools packaging
RUN uv venv --python ${PYTHON_VERSION} && uv pip install torch==${PYTORCH_VERSION} pip setuptools packaging
ENV VIRTUAL_ENV=/usr/src/.venv/
ENV PATH="$PATH:/usr/src/.venv/bin/"
@ -165,9 +165,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
git \
&& rm -rf /var/lib/apt/lists/*
# RUN curl -LsSf https://astral.sh/uv/install.sh | sh
# ENV PATH="$PATH:/root/.local/bin"
COPY --from=ghcr.io/astral-sh/uv:0.5.31 /uv /uvx /bin/
RUN curl -LsSf https://astral.sh/uv/install.sh | sh
ENV PATH="$PATH:/root/.local/bin"
# Install flash-attention dependencies
# RUN pip install einops --no-cache-dir
@ -184,12 +183,12 @@ COPY server server
COPY server/Makefile server/Makefile
ENV HF_KERNELS_CACHE=/kernels
RUN cd server && \
uv sync --frozen --extra gen --extra bnb --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines --extra torch --no-install-project --active && \
uv sync --frozen --extra gen --extra attention --extra bnb --extra accelerate --extra compressed-tensors --extra marlin --extra moe --extra quantize --extra peft --extra outlines --no-install-project --active && \
make gen-server-raw && \
kernels download .
hf-kernels download .
RUN cd server && \
uv sync --frozen --extra gen --extra bnb --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines --extra torch --active --python=${PYTHON_VERSION} && \
uv sync --frozen --extra gen --extra attention --extra bnb --extra accelerate --extra compressed-tensors --extra marlin --extra moe --extra quantize --extra peft --extra outlines --active --python=${PYTHON_VERSION} && \
uv pip install nvidia-nccl-cu12==2.25.1 && \
pwd && \
text-generation-server --help

View File

@ -5,7 +5,7 @@ RUN mkdir -p /tgi
# Fetch the optimum-neuron sources directly to avoid relying on pypi deployments
FROM alpine AS optimum-neuron
RUN mkdir -p /optimum-neuron
ADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.1.0.tar.gz /optimum-neuron/sources.tar.gz
ADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.0.28.tar.gz /optimum-neuron/sources.tar.gz
RUN tar -C /optimum-neuron -xf /optimum-neuron/sources.tar.gz --strip-components=1
# Build cargo components (adapted from TGI original Dockerfile)
@ -18,7 +18,8 @@ RUN apt-get update -y \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain 1.85.1 --profile minimal -y
COPY rust-toolchain.toml rust-toolchain.toml
RUN curl -sSf https://sh.rustup.rs | sh -s -- -y --no-modify-path --default-toolchain none
ENV PATH="/root/.cargo/bin:${PATH}"
RUN cargo install cargo-chef --locked
@ -108,10 +109,10 @@ RUN wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEU
# Install neuronx packages
RUN apt-get update -y \
&& apt-get install -y --no-install-recommends \
aws-neuronx-dkms=2.19.64.0 \
aws-neuronx-collectives=2.23.135.0-3e70920f2 \
aws-neuronx-runtime-lib=2.23.112.0-9b5179492 \
aws-neuronx-tools=2.20.204.0 \
aws-neuronx-dkms=2.18.20.0 \
aws-neuronx-collectives=2.22.33.0-d2128d1aa \
aws-neuronx-runtime-lib=2.22.19.0-5856c0b42 \
aws-neuronx-tools=2.19.0.0 \
libxml2 \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
@ -120,16 +121,16 @@ ENV PATH="/opt/bin/:/opt/aws/neuron/bin:${PATH}"
# Install manually torch CPU version to avoid pulling CUDA
RUN pip3 install \
torch==2.5.1 \
torchvision==0.20.1 \
torch==2.1.2 \
torchvision==0.16.2 \
--index-url https://download.pytorch.org/whl/cpu
RUN pip3 install \
neuronx-cc==2.16.372.0 \
torch-neuronx==2.5.1.2.4.0 \
transformers-neuronx==0.13.322 \
neuronx-distributed==0.10.1 \
libneuronxla==2.1.681.0 \
neuronx-cc==2.15.143.0 \
torch-neuronx==2.1.2.2.3.2 \
transformers-neuronx==0.12.313 \
neuronx-distributed==0.9.0 \
libneuronxla==2.0.5347.0 \
--extra-index-url=https://pip.repos.neuron.amazonaws.com
# Install HuggingFace packages

View File

@ -1,5 +1,5 @@
# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
FROM lukemathwalker/cargo-chef:latest-rust-1.85.0 AS chef
WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@ -41,244 +41,303 @@ COPY backends backends
COPY launcher launcher
RUN cargo build --profile release-opt --frozen
FROM rocm/dev-ubuntu-22.04:6.3.1-complete AS base
# Text Generation Inference base image for RoCm
FROM rocm/dev-ubuntu-22.04:6.2 AS base
ARG HIPBLASLT_BRANCH="4d40e36"
ARG HIPBLAS_COMMON_BRANCH="7c1566b"
ARG LEGACY_HIPBLASLT_OPTION=
ARG RCCL_BRANCH="648a58d"
ARG RCCL_REPO="https://github.com/ROCm/rccl"
ARG TRITON_BRANCH="e5be006"
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
ARG PYTORCH_BRANCH="3a585126"
ARG PYTORCH_VISION_BRANCH="v0.19.1"
ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG FA_BRANCH="b7d29fb"
ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
ARG AITER_BRANCH="21d47a9"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
ENV PATH=/opt/rocm/llvm/bin:$PATH
ENV ROCM_PATH=/opt/rocm
ENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib:
ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942
ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
ARG PYTHON_VERSION=3.11
RUN mkdir -p /app
WORKDIR /app
ENV DEBIAN_FRONTEND=noninteractive
# Install Python and other dependencies
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \
ca-certificates \
ccache \
curl \
git \
ninja-build \
cmake \
software-properties-common \
python3.11-dev \
python3.11-venv && \
rm -rf /var/lib/apt/lists/*
build-essential \
ca-certificates \
ccache \
curl \
git \
make \
libmsgpack-dev \
libssl-dev \
llvm-dev \
g++ \
# Needed to build VLLM & flash.
rocthrust-dev \
hipsparse-dev \
hipblas-dev \
hipcub-dev \
rocblas-dev \
hiprand-dev \
hipfft-dev \
rocrand-dev \
miopen-hip-dev \
hipsolver-dev \
rccl-dev \
cmake \
python3.11-venv && \
rm -rf /var/lib/apt/lists/*
COPY --from=ghcr.io/astral-sh/uv:0.5.31 /uv /uvx /bin/
ENV PATH="$PATH:/root/.local/bin"
RUN uv python install ${PYTHON_VERSION}
RUN uv venv --python ${PYTHON_VERSION} && uv pip install pip setuptools packaging
ENV VIRTUAL_ENV=/usr/src/.venv/
ENV PATH="$PATH:/usr/src/.venv/bin/"
# Keep in sync with `server/pyproject.toml
ARG MAMBA_VERSION=23.1.0-1
ARG PYTHON_VERSION='3.11.10'
# Automatically set by buildx
ARG TARGETPLATFORM
ENV PATH=/opt/conda/bin:$PATH
RUN . .venv/bin/activate && pip install -U packaging cmake ninja wheel setuptools pybind11 Cython
ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
# Install mamba
# translating Docker's TARGETPLATFORM into mamba arches
RUN case ${TARGETPLATFORM} in \
"linux/arm64") MAMBA_ARCH=aarch64 ;; \
*) MAMBA_ARCH=x86_64 ;; \
esac && \
curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
RUN chmod +x ~/mambaforge.sh && \
bash ~/mambaforge.sh -b -p /opt/conda && \
mamba init && \
rm ~/mambaforge.sh
# RUN conda install intel::mkl-static intel::mkl-include
# Install pytorch
# On arm64 we exit with an error code
RUN case ${TARGETPLATFORM} in \
"linux/arm64") exit 1 ;; \
*) /opt/conda/bin/conda update -y conda && \
/opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
esac && \
/opt/conda/bin/conda clean -ya
# Install flash-attention, torch dependencies
RUN python3 -m pip install --upgrade pip uv && pip install numpy einops ninja joblib msgpack cmake --no-cache-dir && rm -rf /var/lib/apt/lists/*
RUN conda install mkl=2021
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/opt/conda/lib/python3.11/site-packages/torch/lib:/opt/conda/lib/
ARG COMMON_WORKDIR=/
WORKDIR ${COMMON_WORKDIR}
# Install HIPBLASLt
FROM base AS build_hipblaslt
ARG HIPBLASLT_BRANCH
ARG HIPBLAS_COMMON_BRANCH
# Set to "--legacy_hipblas_direct" for ROCm<=6.2
ARG LEGACY_HIPBLASLT_OPTION
RUN git clone https://github.com/ROCm/hipBLAS-common.git
RUN . .venv/bin/activate && cd hipBLAS-common \
&& git checkout ${HIPBLAS_COMMON_BRANCH} \
&& mkdir build \
&& cd build \
&& cmake .. \
&& make package \
&& dpkg -i ./*.deb
RUN git clone https://github.com/ROCm/hipBLASLt
RUN . .venv/bin/activate && cd hipBLASLt \
ARG HIPBLASLT_BRANCH="e6da924"
RUN git clone https://github.com/ROCm/hipBLASLt.git \
&& cd hipBLASLt \
&& git checkout ${HIPBLASLT_BRANCH} \
&& ./install.sh -d --architecture ${PYTORCH_ROCM_ARCH} ${LEGACY_HIPBLASLT_OPTION} \
&& SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} --legacy_hipblas_direct \
&& cd build/release \
&& make package
RUN mkdir -p /app/install && cp /app/hipBLASLt/build/release/*.deb /app/hipBLAS-common/build/*.deb /app/install
FROM scratch AS export_hipblaslt
ARG COMMON_WORKDIR
COPY --from=build_hipblaslt ${COMMON_WORKDIR}/hipBLASLt/build/release/*.deb /
# RCCL build stages
FROM base AS build_rccl
ARG RCCL_BRANCH
ARG RCCL_REPO
RUN git clone ${RCCL_REPO}
RUN . .venv/bin/activate && cd rccl \
ARG RCCL_BRANCH="rocm-6.2.0"
RUN git clone https://github.com/ROCm/rccl \
&& cd rccl \
&& git checkout ${RCCL_BRANCH} \
&& ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH}
RUN mkdir -p /app/install && cp /app/rccl/build/release/*.deb /app/install
FROM scratch AS export_rccl
ARG COMMON_WORKDIR
COPY --from=build_rccl ${COMMON_WORKDIR}/rccl/build/release/*.deb /
# Triton build stages
FROM base AS build_triton
ARG TRITON_BRANCH
ARG TRITON_REPO
RUN git clone ${TRITON_REPO}
RUN . .venv/bin/activate && cd triton \
ARG TRITON_BRANCH="e192dba"
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
RUN python3 -m pip install ninja cmake wheel pybind11 && git clone ${TRITON_REPO} \
&& cd triton \
&& git checkout ${TRITON_BRANCH} \
&& cd python \
&& python3 setup.py bdist_wheel --dist-dir=dist
RUN mkdir -p /app/install && cp /app/triton/python/dist/*.whl /app/install
FROM scratch AS export_triton
ARG COMMON_WORKDIR
COPY --from=build_triton ${COMMON_WORKDIR}/triton/python/dist/*.whl /
# # AMD-SMI build stages
FROM base AS build_amdsmi
RUN . .venv/bin/activate && cd /opt/rocm/share/amd_smi \
RUN cd /opt/rocm/share/amd_smi \
&& pip wheel . --wheel-dir=dist
RUN mkdir -p /app/install && cp /opt/rocm/share/amd_smi/dist/*.whl /app/install
FROM scratch AS export_amdsmi
COPY --from=build_amdsmi /opt/rocm/share/amd_smi/dist/*.whl /
FROM base AS build_pytorch
ARG PYTORCH_BRANCH
ARG PYTORCH_VISION_BRANCH
ARG PYTORCH_REPO
ARG PYTORCH_VISION_REPO
ARG FA_BRANCH
ARG FA_REPO
RUN git clone ${PYTORCH_REPO} pytorch
RUN . .venv/bin/activate && cd pytorch && git checkout ${PYTORCH_BRANCH} && \
pip install -r requirements.txt && git submodule update --init --recursive \
&& python3 tools/amd_build/build_amd.py \
&& CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \
&& pip install dist/*.whl
RUN git clone ${PYTORCH_VISION_REPO} vision
RUN . .venv/bin/activate && cd vision && git checkout ${PYTORCH_VISION_BRANCH} \
&& python3 setup.py bdist_wheel --dist-dir=dist \
&& pip install dist/*.whl
RUN git clone ${FA_REPO}
RUN . .venv/bin/activate && cd flash-attention \
&& git checkout ${FA_BRANCH} \
&& git submodule update --init \
&& MAX_JOBS=64 GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist
RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \
&& cp /app/vision/dist/*.whl /app/install \
&& cp /app/flash-attention/dist/*.whl /app/install
FROM base AS final
RUN --mount=type=bind,from=build_hipblaslt,src=/app/install/,target=/install \
dpkg -i /install/*deb \
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status
RUN --mount=type=bind,from=build_rccl,src=/app/install/,target=/install \
dpkg -i /install/*deb \
&& sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
&& sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status
RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
. .venv/bin/activate && \
pip install /install/*.whl
RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
. .venv/bin/activate && \
pip install /install/*.whl
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
. .venv/bin/activate && \
pip install /install/*.whl
FROM base as build_pytorch
ARG AITER_REPO
ARG AITER_BRANCH
RUN git clone --recursive ${AITER_REPO}
RUN . .venv/bin/activate && cd aiter \
&& git checkout ${AITER_BRANCH} \
&& git submodule update --init --recursive \
&& pip install -r requirements.txt \
&& PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop && pip show aiter
RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
if ls /install/*.deb; then \
dpkg -i /install/*.deb \
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
fi
RUN rm -rf /var/lib/apt/lists/*
ARG BUILD_ENVIRONMENT=pytorch-linux-jammy-rocm6.2-py3.11
ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
# A commit to fix the output scaling factor issue in _scaled_mm
# Not yet in 2.5.0-rc1
ARG PYTORCH_BRANCH="cedc116"
ARG PYTORCH_VISION_BRANCH="v0.19.1"
ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
RUN git clone ${PYTORCH_REPO} pytorch \
&& cd pytorch && git checkout ${PYTORCH_BRANCH} && git submodule update --init --recursive \
&& pip install -r requirements.txt --no-cache-dir \
&& python tools/amd_build/build_amd.py \
&& CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist
FROM scratch as export_pytorch
ARG COMMON_WORKDIR
COPY --from=build_pytorch ${COMMON_WORKDIR}/pytorch/dist/*.whl /
FROM base AS install_deps
ARG COMMON_WORKDIR
# Install hipblaslt
RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
if ls /install/*.deb; then \
dpkg -i /install/*.deb \
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
fi
RUN --mount=type=bind,from=export_rccl,src=/,target=/install \
if ls /install/*.deb; then \
dpkg -i /install/*.deb \
# RCCL needs to be installed twice
&& dpkg -i /install/*.deb \
&& sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
&& sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status; \
fi
RUN --mount=type=bind,from=export_triton,src=/,target=/install \
if ls /install/*.whl; then \
# Preemptively uninstall to prevent pip same-version no-installs
pip uninstall -y triton \
&& pip install /install/*.whl; \
fi
RUN --mount=type=bind,from=export_amdsmi,src=/,target=/install \
# Preemptively uninstall to prevent pip same-version no-installs
pip uninstall -y amdsmi \
&& pip install /install/*.whl;
RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \
if ls /install/*.whl; then \
# Preemptively uninstall to prevent pip same-version no-installs
pip uninstall -y torch torchvision \
&& pip install /install/*.whl; \
fi
FROM install_deps AS kernel-builder
FROM final AS kernel-builder
# # Build vllm kernels
FROM kernel-builder AS vllm-builder
WORKDIR /usr/src
COPY server/Makefile-vllm Makefile
RUN . .venv/bin/activate && pip install setuptools_scm
RUN pip install setuptools_scm
# Build specific version of vllm
RUN . .venv/bin/activate && make build-vllm-rocm
RUN make build-vllm-rocm
# Build Flash Attention v2 kernels
FROM kernel-builder AS flash-att-v2-builder
WORKDIR /usr/src
COPY server/Makefile-flash-att-v2 Makefile
# Build specific version of flash attention v2
RUN make build-flash-attention-v2-rocm
# Build Transformers CUDA kernels (gpt-neox and bloom)
FROM kernel-builder AS custom-kernels-builder
WORKDIR /usr/src
COPY server/custom_kernels/ .
RUN . .venv/bin/activate && python3 setup.py bdist_wheel --dist-dir=dist
RUN python setup.py build
# Build exllama kernels
FROM kernel-builder AS exllama-kernels-builder
WORKDIR /usr/src
COPY server/exllama_kernels/ .
RUN . .venv/bin/activate && python3 setup.py bdist_wheel --dist-dir=dist
RUN python setup.py build
# Build exllama v2 kernels
FROM kernel-builder AS exllamav2-kernels-builder
WORKDIR /usr/src
COPY server/exllamav2_kernels/ .
RUN . .venv/bin/activate && python3 setup.py bdist_wheel --dist-dir=dist
RUN python setup.py build
FROM kernel-builder AS marlin-kernels
WORKDIR /usr/src
ENV MARLIN_KERNELS_BRANCH=v0.3.6
ENV VLLM_TARGET_DEVICE=rocm
RUN . .venv/bin/activate && git clone https://github.com/danieldk/marlin-kernels.git && \
RUN git clone https://github.com/danieldk/marlin-kernels.git && \
cd marlin-kernels && \
git checkout ${MARLIN_KERNELS_BRANCH} && \
python3 setup.py bdist_wheel --dist-dir=dist
python setup.py install
FROM kernel-builder AS moe-kernels
WORKDIR /usr/src
ENV MOE_KERNELS_BRANCH=v0.8.2
ENV VLLM_TARGET_DEVICE=rocm
RUN . .venv/bin/activate && git clone https://github.com/danieldk/moe-kernels.git && \
RUN git clone https://github.com/danieldk/moe-kernels.git && \
cd moe-kernels && \
git checkout ${MOE_KERNELS_BRANCH} && \
python3 setup.py bdist_wheel --dist-dir=dist
python setup.py install
FROM final AS base-copy
FROM install_deps AS base-copy
# Text Generation Inference base env
ENV HF_HOME=/data \
HF_HUB_ENABLE_HF_TRANSFER=1 \
PORT=80
ENV VIRTUAL_ENV=/app/.venv/
ENV PATH="$PATH:/app/.venv/bin/"
# Copy builds artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from flash attention v2 builder
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from exllama kernels builder
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from exllamav2 kernels builder
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from marlin kernels
COPY --from=marlin-kernels /usr/src/marlin-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from moe kernels
COPY --from=moe-kernels /usr/src/moe-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Install server
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
ENV UV_SYSTEM_PYTHON=1
RUN cd server && \
uv pip install grpcio-tools mypy-protobuf && \
uv pip install -e ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir && \
pip install -U pip uv && \
uv sync --frozen --extra gen --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines --no-install-project && \
. ./.venv/bin/activate && \
make gen-server-raw
RUN cd server && \
uv sync --frozen --extra gen --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines && \
. ./.venv/bin/activate && \
pwd && \
text-generation-server --help
RUN --mount=type=bind,from=vllm-builder,src=/app/vllm/dist,target=/install \
uv pip install /install/*.whl
RUN --mount=type=bind,from=custom-kernels-builder,src=/app/dist,target=/install \
uv pip install /install/*.whl
RUN --mount=type=bind,from=custom-kernels-builder,src=/app/dist,target=/install \
uv pip install /install/*.whl
RUN --mount=type=bind,from=exllama-kernels-builder,src=/app/dist,target=/install \
uv pip install /install/*.whl
RUN --mount=type=bind,from=exllamav2-kernels-builder,src=/app/dist,target=/install \
uv pip install /install/*.whl
RUN --mount=type=bind,from=marlin-kernels,src=/app/marlin-kernels/dist,target=/install \
uv pip install /install/*.whl
RUN --mount=type=bind,from=moe-kernels,src=/app/moe-kernels/dist,target=/install \
uv pip install /install/*.whl
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
# AWS Sagemaker compatible image
FROM base AS sagemaker
@ -309,6 +368,4 @@ COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh
ENTRYPOINT ["/tgi-entrypoint.sh"]
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/root/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib"
ENV PYTHONPATH=/app/.venv/lib/python3.11/site-packages
# CMD ["--json-output"]
CMD ["--json-output"]

View File

@ -1,9 +1,9 @@
# Those arguments are required to build the image
ARG HABANA_VERSION=1.20.0
ARG PYTORCH_VERSION=2.6.0
ARG HABANA_VERSION
ARG PYTORCH_VERSION
# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
FROM lukemathwalker/cargo-chef:latest-rust-1.85.0 AS chef
WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@ -92,10 +92,11 @@ RUN cd server && \
make gen-server && \
pip install --no-deps -r requirements.txt && \
bash ./dill-0.3.8-patch.sh && \
pip install outlines~=0.0.34 && \
pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
pip install . --no-cache-dir
RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router

View File

@ -1,6 +1,6 @@
ARG PLATFORM=xpu
FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
FROM lukemathwalker/cargo-chef:latest-rust-1.85.0 AS chef
WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@ -45,7 +45,7 @@ RUN cargo build --profile release-opt --frozen
# Text Generation Inference base image for Intel
FROM intel/oneapi-basekit:2025.0.1-0-devel-ubuntu22.04 AS xpu
FROM intel/oneapi-basekit:2024.2.1-0-devel-ubuntu22.04 AS xpu
USER root
@ -87,7 +87,7 @@ RUN echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https:/
RUN mv /tmp/intel-for-pytorch-gpu-dev.list /etc/apt/sources.list.d
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-ocloc libnl-genl-3-200
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-pti-dev-0.9
# Text Generation Inference base env
ENV HF_HOME=/data \
@ -96,9 +96,13 @@ ENV HF_HOME=/data \
WORKDIR /usr/src
RUN pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/xpu
RUN pip install https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/torch-2.5.0a0%2Bgite84e33f-cp311-cp311-linux_x86_64.whl --no-cache-dir
RUN pip install https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/torchaudio-2.5.0a0%2B56bc006-cp311-cp311-linux_x86_64.whl --no-cache-dir
RUN pip install https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/torchvision-0.20.0a0%2B8e8a208-cp311-cp311-linux_x86_64.whl --no-cache-dir
RUN pip install https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/oneccl_bind_pt-2.5.0%2Bxpu-cp311-cp311-linux_x86_64.whl --no-cache-dir
RUN pip install triton-xpu==3.0.0b2 --no-cache-dir
# Install server
COPY proto proto
@ -110,14 +114,15 @@ RUN cd server && \
pip install -U pip uv && \
uv pip install -e ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/conda/lib
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/pti/0.9/lib:/opt/conda/lib
ENV CCL_ZE_IPC_EXCHANGE=sockets
ENV TORCH_LLM_ALLREDUCE=1
ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
#ENV TORCH_LLM_ALLREDUCE=1
#ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
ENV TORCH_DEVICE_BACKEND_AUTOLOAD=0
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/oneccl_bind_pt-2.6.0%2Bxpu-cp311-cp311-linux_x86_64.whl
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/intel_extension_for_pytorch-2.6.10%2Bxpu-cp311-cp311-linux_x86_64.whl
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout 1ccf72b2d11cd00b47aef6d6cd054c088aa6f083
RUN cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc,ats-m150' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router

View File

@ -1,15 +1,13 @@
FROM nvidia/cuda:12.8.0-cudnn-devel-ubuntu24.04 AS deps
ARG llamacpp_version=b4827
ARG llamacpp_version=b4651
ARG llamacpp_cuda=OFF
ARG llamacpp_native=ON
ARG llamacpp_cpu_arm_arch=native
ARG cuda_arch=75-real;80-real;86-real;89-real;90-real
WORKDIR /opt/src
ENV DEBIAN_FRONTEND=noninteractive
RUN apt update && apt upgrade -y && apt install -y \
RUN apt update && apt install -y \
clang \
cmake \
curl \
@ -19,10 +17,9 @@ RUN apt update && apt upgrade -y && apt install -y \
pkg-config \
tar
ADD https://github.com/ggml-org/llama.cpp/archive/refs/tags/${llamacpp_version}.tar.gz /opt/src/
RUN mkdir -p llama.cpp \
&& tar -xzf ${llamacpp_version}.tar.gz -C llama.cpp --strip-components=1 \
&& cd llama.cpp \
ADD https://github.com/ggerganov/llama.cpp/archive/refs/tags/${llamacpp_version}.tar.gz /opt/src/
RUN tar -xzf ${llamacpp_version}.tar.gz \
&& cd llama.cpp-${llamacpp_version} \
&& cmake -B build \
-DCMAKE_INSTALL_PREFIX=/usr \
-DCMAKE_INSTALL_LIBDIR=/usr/lib \
@ -30,8 +27,6 @@ RUN mkdir -p llama.cpp \
-DCMAKE_CXX_COMPILER=clang++ \
-DCMAKE_CUDA_ARCHITECTURES=${cuda_arch} \
-DGGML_CUDA=${llamacpp_cuda} \
-DGGML_NATIVE=${llamacpp_native} \
-DGGML_CPU_ARM_ARCH=${llamacpp_cpu_arm_arch} \
-DLLAMA_BUILD_COMMON=OFF \
-DLLAMA_BUILD_TESTS=OFF \
-DLLAMA_BUILD_EXAMPLES=OFF \
@ -41,7 +36,7 @@ RUN mkdir -p llama.cpp \
WORKDIR /app
COPY rust-toolchain.toml rust-toolchain.toml
RUN curl -sSf https://sh.rustup.rs | sh -s -- --no-modify-path --default-toolchain 1.85.1 --profile minimal -y
RUN curl -sSf https://sh.rustup.rs | sh -s -- --no-modify-path --default-toolchain 1.85.0 --profile minimal -y
ENV PATH="/root/.cargo/bin:$PATH"
RUN cargo install cargo-chef --locked
@ -53,18 +48,16 @@ FROM deps AS builder
COPY --from=planner /app/recipe.json recipe.json
RUN cargo chef cook \
--recipe-path recipe.json \
--profile release \
--profile release-opt \
--package text-generation-router-llamacpp
COPY . .
RUN cargo build \
--profile release \
--profile release-opt \
--package text-generation-router-llamacpp --frozen
FROM nvidia/cuda:12.8.0-cudnn-runtime-ubuntu24.04
WORKDIR /app
ENV DEBIAN_FRONTEND=noninteractive
RUN apt update && apt upgrade -y && apt install -y \
RUN apt update && apt install -y \
python3-venv \
python3-pip
@ -72,16 +65,11 @@ RUN python3 -m venv /venv
ENV PATH="/venv/bin:$PATH"
COPY backends/llamacpp/requirements.txt requirements.txt
COPY --from=builder /opt/src/llama.cpp/gguf-py gguf-py
COPY --from=builder /opt/src/llama.cpp/convert_hf_to_gguf.py /bin/
RUN pip3 install --no-cache-dir \
-r requirements.txt \
-e gguf-py
RUN pip3 install --no-cache-dir -r requirements.txt
COPY --from=builder /usr/lib/libllama.so /usr/lib/
COPY --from=builder /usr/lib/libggml*.so /usr/lib/
COPY --from=builder /app/target/release/text-generation-router-llamacpp /usr/bin/
COPY --from=builder /app/target/release-opt/text-generation-router-llamacpp /usr/bin/
ENV HF_HUB_ENABLE_HF_TRANSFER=1

View File

@ -3,9 +3,10 @@ ARG cuda_base=12.8.0
ARG build_type=release
ARG ompi_version=4.1.7
ARG sccache_gha_enabled=off
ARG actions_results_url=""
ARG actions_cache_url=""
ARG actions_runtime_token=""
# CUDA dependent dependencies resolver stage
FROM nvidia/cuda:${cuda_base}-cudnn-devel-ubuntu24.04 AS cuda-builder
@ -65,15 +66,15 @@ WORKDIR /usr/src/text-generation-inference
ARG cuda_arch_list
ARG build_type
ARG sccache_gha_enabled
ARG actions_results_url
ARG actions_cache_url
ARG actions_runtime_token
# Install Rust
ENV PATH="/root/.cargo/bin:$PATH"
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain 1.85.1 --profile minimal -y && \
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain 1.85.0 --profile minimal -y && \
chmod -R a+w /root/.rustup && \
chmod -R a+w /root/.cargo && \
cargo install sccache --version ">=0.10.0" --locked
cargo install sccache --locked
ENV LD_LIBRARY_PATH="/usr/local/mpi/lib:$LD_LIBRARY_PATH"
ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig"
@ -84,7 +85,7 @@ ENV CUDA_ARCH_LIST=${cuda_arch_list}
# SCCACHE Specifics args - before finding a better, more generic, way...
ENV SCCACHE_GHA_ENABLED=${sccache_gha_enabled}
ENV ACTIONS_RESULTS_URL=${actions_results_url}
ENV ACTIONS_CACHE_URL=${actions_cache_url}
ENV ACTIONS_RUNTIME_TOKEN=${actions_runtime_token}
COPY Cargo.lock Cargo.lock

View File

@ -53,6 +53,3 @@ run-falcon-7b-instruct-quantize:
clean:
rm -rf target aml
preview_doc:
doc-builder preview text-generation-inference docs/source --not_python_module

View File

@ -84,7 +84,7 @@ model=HuggingFaceH4/zephyr-7b-beta
volume=$PWD/data
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.2.3 --model-id $model
ghcr.io/huggingface/text-generation-inference:3.1.1 --model-id $model
```
And then you can make requests like
@ -121,7 +121,7 @@ curl localhost:8080/v1/chat/completions \
**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.2.3-rocm --model-id $model` instead of the command above.
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.1.1-rocm --model-id $model` instead of the command above.
To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
```
@ -152,7 +152,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
token=<your cli READ token>
docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.2.3 --model-id $model
ghcr.io/huggingface/text-generation-inference:3.1.1 --model-id $model
```
### A note on Shared Memory (shm)
@ -264,7 +264,7 @@ After that you can run TGI with `nix run`:
```shell
cd text-generation-inference
nix run --extra-experimental-features nix-command --extra-experimental-features flakes . -- --model-id meta-llama/Llama-3.1-8B-Instruct
nix run --extra-experimental-features nix-command --extra-experimental-features flakes . -- --model-idmeta-llama/Llama-3.1-8B-Instruct
```
**Note:** when you are using Nix on a non-NixOS system, you have to [make some symlinks](https://danieldk.eu/Nix-CUDA-on-non-NixOS-systems#make-runopengl-driverlib-and-symlink-the-driver-library)

View File

@ -1,9 +1,9 @@
mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
mkfile_dir := $(dir $(mkfile_path))
root_dir := ${mkfile_dir}/../..
root_dir := "${mkfile_dir}/../.."
HABANA_VERSION := 1.20.0
PYTORCH_VERSION := 2.6.0
HABANA_VERSION := 1.19.0
PYTORCH_VERSION := 2.5.1
.PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install
@ -47,16 +47,3 @@ local-dev-install: install-dependencies
make install-server && \
make install-router && \
make install-launcher'
# In order to run the integration tests, you need to first build the image (make -C backends/gaudi image)
run-integration-tests:
uv pip install -r ${root_dir}/backends/gaudi/server/integration-tests/requirements.txt
DOCKER_VOLUME=${root_dir}/data \
HF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \
uv run pytest --durations=0 -sv ${root_dir}/backends/gaudi/server/integration-tests
# This is used to capture the expected outputs for the integration tests offering an easy way to add more models to the integration tests
capture-expected-outputs-for-integration-tests:
DOCKER_VOLUME=${root_dir}/data \
HF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \
uv run pytest --durations=0 -sv ${root_dir}/backends/gaudi/server/integration-tests/capture_expected_outputs.py

View File

@ -96,47 +96,3 @@ curl 127.0.0.1:8080/generate \
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
-H 'Content-Type: application/json'
```
### Integration tests
To run the integration tests, you need to first build the image:
```bash
make -C backends/gaudi image
```
Then run the following command to run the integration tests:
```bash
make -C backends/gaudi run-integration-tests
```
To capture the expected outputs for the integration tests, you can run the following command:
```bash
make -C backends/gaudi capture-expected-outputs-for-integration-tests
```
#### How the integration tests works
The integration tests works as follows:
1. Start a tgi server in a container, similar to the command:
```bash
docker run --runtime=habana --ipc=host --cap-add=sys_nice \
-p 8080:80 -v $volume:/data \
-e LOG_LEVEL=debug -e HF_TOKEN=$hf_token \
tgi-gaudi --model-id $model \
--max-input-tokens 512 --max-total-tokens 1024 --max-batch-size 4 --max-batch-prefill-tokens 2048
```
2. Do a /generate request to the server, similar to the command:
```bash
curl 127.0.0.1:8080/generate \
-X POST \
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
-H 'Content-Type: application/json'
```
3. Check the output of the server against the expected output:
```python
assert curl_output == expected_output
```
This is the repeated for a set of models and configurations.

View File

@ -1,283 +0,0 @@
# Examples of Docker Commands for Gaudi Backend
This page gives a list of examples of docker run commands for some of the most popular models.
> **Note:** The parameters are chosen for Gaudi2 hardware to maximize performance on this given hardware, please adjust the parameters based on your hardware. For example, if you are using Gaudi3, you may want to increase the batch size.
## Default Precision (BF16)
### Llama3.1-8B on 1 card (BF16)
```bash
model=meta-llama/Meta-Llama-3.1-8B-Instruct
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-e HF_TOKEN=$hf_token \
-e MAX_TOTAL_TOKENS=2048 \
-e PREFILL_BATCH_BUCKET_SIZE=2 \
-e BATCH_BUCKET_SIZE=32 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 2048 --max-batch-size 32 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
```
### Llama3.1-70B 8 cards (BF16)
```bash
model=meta-llama/Meta-Llama-3.1-70B-Instruct
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-e HF_TOKEN=$hf_token \
-e MAX_TOTAL_TOKENS=2048 \
-e BATCH_BUCKET_SIZE=256 \
-e PREFILL_BATCH_BUCKET_SIZE=4 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--sharded true --num-shard 8 \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 4096 --max-batch-size 256 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
```
### Llama2-7B on 1 Card (BF16)
```bash
model=meta-llama/Llama-2-7b-chat-hf
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-e HF_TOKEN=$hf_token \
-e MAX_TOTAL_TOKENS=2048 \
-e PREFILL_BATCH_BUCKET_SIZE=2 \
-e BATCH_BUCKET_SIZE=32 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 2048 --max-batch-size 32 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
```
### Llama2-70B on 8 cards (BF16)
```bash
model=meta-llama/Llama-2-70b-chat-hf
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-e HF_TOKEN=$hf_token \
-e MAX_TOTAL_TOKENS=2048 \
-e BATCH_BUCKET_SIZE=256 \
-e PREFILL_BATCH_BUCKET_SIZE=4 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--sharded true --num-shard 8 \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 4096 --max-batch-size 256 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
```
### Llava-v1.6-Mistral-7B on 1 card (BF16)
```bash
model=llava-hf/llava-v1.6-mistral-7b-hf
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-e PREFILL_BATCH_BUCKET_SIZE=1 \
-e BATCH_BUCKET_SIZE=1 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
--max-total-tokens 8192 --max-batch-size 4
```
## FP8 Precision
Please refer to the [FP8 Precision](https://huggingface.co/docs/text-generation-inference/backends/gaudi_new#how-to-use-different-precision-formats) section for more details. You need to measure the statistics of the model first before running the model in FP8 precision.
## Llama3.1-8B on 1 Card (FP8)
```bash
model=meta-llama/Meta-Llama-3.1-8B-Instruct
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-v $PWD/quantization_config:/usr/src/quantization_config \
-v $PWD/hqt_output:/usr/src/hqt_output \
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
-e HF_TOKEN=$hf_token \
-e MAX_TOTAL_TOKENS=2048 \
-e PREFILL_BATCH_BUCKET_SIZE=2 \
-e BATCH_BUCKET_SIZE=32 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 2048 --max-batch-size 32 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
```
## Llama3.1-70B on 8 cards (FP8)
```bash
model=meta-llama/Meta-Llama-3.1-70B-Instruct
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-v $PWD/quantization_config:/usr/src/quantization_config \
-v $PWD/hqt_output:/usr/src/hqt_output \
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
-e HF_TOKEN=$hf_token \
-e MAX_TOTAL_TOKENS=2048 \
-e BATCH_BUCKET_SIZE=256 \
-e PREFILL_BATCH_BUCKET_SIZE=4 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--sharded true --num-shard 8 \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 4096 --max-batch-size 256 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
```
## Llama2-7B on 1 Card (FP8)
```bash
model=meta-llama/Llama-2-7b-chat-hf
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-v $PWD/quantization_config:/usr/src/quantization_config \
-v $PWD/hqt_output:/usr/src/hqt_output \
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
-e HF_TOKEN=$hf_token \
-e MAX_TOTAL_TOKENS=2048 \
-e PREFILL_BATCH_BUCKET_SIZE=2 \
-e BATCH_BUCKET_SIZE=32 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 2048 --max-batch-size 32 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
```
## Llama2-70B on 8 Cards (FP8)
```bash
model=meta-llama/Llama-2-70b-chat-hf
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-v $PWD/quantization_config:/usr/src/quantization_config \
-v $PWD/hqt_output:/usr/src/hqt_output \
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
-e HF_TOKEN=$hf_token \
-e MAX_TOTAL_TOKENS=2048 \
-e BATCH_BUCKET_SIZE=256 \
-e PREFILL_BATCH_BUCKET_SIZE=4 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--sharded true --num-shard 8 \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 4096 --max-batch-size 256 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
```
## Llava-v1.6-Mistral-7B on 1 Card (FP8)
```bash
model=llava-hf/llava-v1.6-mistral-7b-hf
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-v $PWD/quantization_config:/usr/src/quantization_config \
-v $PWD/hqt_output:/usr/src/hqt_output \
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
-e PREFILL_BATCH_BUCKET_SIZE=1 \
-e BATCH_BUCKET_SIZE=1 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
--max-total-tokens 8192 --max-batch-size 4
```
## Llava-v1.6-Mistral-7B on 8 Cards (FP8)
```bash
model=llava-hf/llava-v1.6-mistral-7b-hf
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-v $PWD/quantization_config:/usr/src/quantization_config \
-v $PWD/hqt_output:/usr/src/hqt_output \
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
-e PREFILL_BATCH_BUCKET_SIZE=1 \
-e BATCH_BUCKET_SIZE=1 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--sharded true --num-shard 8 \
--max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
--max-total-tokens 8192 --max-batch-size 4
```

View File

@ -1,85 +0,0 @@
import json
import os
from typing import Dict, Any, Generator
import pytest
from test_model import TEST_CONFIGS
UNKNOWN_CONFIGS = {
name: config
for name, config in TEST_CONFIGS.items()
if config["expected_greedy_output"] == "unknown"
or config["expected_batch_output"] == "unknown"
}
@pytest.fixture(scope="module", params=UNKNOWN_CONFIGS.keys())
def test_config(request) -> Dict[str, Any]:
"""Fixture that provides model configurations for testing."""
test_config = UNKNOWN_CONFIGS[request.param]
test_config["test_name"] = request.param
return test_config
@pytest.fixture(scope="module")
def test_name(test_config):
yield test_config["test_name"]
@pytest.fixture(scope="module")
def tgi_service(launcher, test_config, test_name) -> Generator:
"""Fixture that provides a TGI service for testing."""
with launcher(test_config["model_id"], test_name) as service:
yield service
@pytest.mark.asyncio
async def test_capture_expected_outputs(tgi_service, test_config, test_name):
"""Test that captures expected outputs for models with unknown outputs."""
print(f"Testing {test_name} with {test_config['model_id']}")
# Wait for service to be ready
await tgi_service.health(1000)
client = tgi_service.client
# Test single request (greedy)
print("Testing single request...")
response = await client.generate(
test_config["input"],
max_new_tokens=32,
)
greedy_output = response.generated_text
# Test multiple requests (batch)
print("Testing batch requests...")
responses = []
for _ in range(4):
response = await client.generate(
test_config["input"],
max_new_tokens=32,
)
responses.append(response.generated_text)
# Store results in a JSON file
output_file = "server/integration-tests/expected_outputs.json"
results = {}
# Try to load existing results if file exists
if os.path.exists(output_file):
with open(output_file, "r") as f:
results = json.load(f)
# Update results for this model
results[test_name] = {
"model_id": test_config["model_id"],
"input": test_config["input"],
"greedy_output": greedy_output,
"batch_outputs": responses,
"args": test_config["args"],
}
# Save updated results
with open(output_file, "w") as f:
json.dump(results, f, indent=2)
print(f"\nResults for {test_name} saved to {output_file}")

View File

@ -1,292 +0,0 @@
import asyncio
import contextlib
import os
import shlex
import subprocess
import sys
import threading
import time
from tempfile import TemporaryDirectory
from typing import List
import socket
import docker
import pytest
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
from docker.errors import NotFound
from loguru import logger
from test_model import TEST_CONFIGS
from text_generation import AsyncClient
from text_generation.types import Response
# Use the latest image from the local docker build
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", "tgi-gaudi")
DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", None)
HF_TOKEN = os.getenv("HF_TOKEN", None)
assert (
HF_TOKEN is not None
), "HF_TOKEN is not set, please set it as some models are gated and thus the test will fail without it"
if DOCKER_VOLUME is None:
logger.warning(
"DOCKER_VOLUME is not set, this will lead to the tests redownloading the models on each run, consider setting it to speed up testing"
)
LOG_LEVEL = os.getenv("LOG_LEVEL", "info")
BASE_ENV = {
"HF_HUB_ENABLE_HF_TRANSFER": "1",
"LOG_LEVEL": LOG_LEVEL,
"HF_TOKEN": os.getenv("HF_TOKEN", None),
}
HABANA_RUN_ARGS = {
"runtime": "habana",
"ipc_mode": "host",
"cap_add": ["sys_nice"],
}
logger.add(
sys.stderr,
format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
level="INFO",
)
def stream_container_logs(container, test_name):
"""Stream container logs in a separate thread."""
try:
for log in container.logs(stream=True, follow=True):
print(
f"[TGI Server Logs - {test_name}] {log.decode('utf-8')}",
end="",
file=sys.stderr,
flush=True,
)
except Exception as e:
logger.error(f"Error streaming container logs: {str(e)}")
class LauncherHandle:
def __init__(self, port: int):
self.client = AsyncClient(f"http://localhost:{port}", timeout=3600)
def _inner_health(self):
raise NotImplementedError
async def health(self, timeout: int = 60):
assert timeout > 0
start_time = time.time()
logger.info(f"Starting health check with timeout of {timeout}s")
for attempt in range(timeout):
if not self._inner_health():
logger.error("Launcher crashed during health check")
raise RuntimeError("Launcher crashed")
try:
await self.client.generate("test")
elapsed = time.time() - start_time
logger.info(f"Health check passed after {elapsed:.1f}s")
return
except (ClientConnectorError, ClientOSError, ServerDisconnectedError) as e:
if attempt == timeout - 1:
logger.error(f"Health check failed after {timeout}s: {str(e)}")
raise RuntimeError(f"Health check failed: {str(e)}")
if attempt % 10 == 0 and attempt != 0: # Only log every 10th attempt
logger.debug(
f"Connection attempt {attempt}/{timeout} failed: {str(e)}"
)
time.sleep(1)
except Exception as e:
logger.error(f"Unexpected error during health check: {str(e)}")
# Get full traceback for debugging
import traceback
logger.error(f"Full traceback:\n{traceback.format_exc()}")
raise
class ContainerLauncherHandle(LauncherHandle):
def __init__(self, docker_client, container_name, port: int):
super(ContainerLauncherHandle, self).__init__(port)
self.docker_client = docker_client
self.container_name = container_name
def _inner_health(self) -> bool:
try:
container = self.docker_client.containers.get(self.container_name)
status = container.status
if status not in ["running", "created"]:
logger.warning(f"Container status is {status}")
# Get container logs for debugging
logs = container.logs().decode("utf-8")
logger.debug(f"Container logs:\n{logs}")
return status in ["running", "created"]
except Exception as e:
logger.error(f"Error checking container health: {str(e)}")
return False
class ProcessLauncherHandle(LauncherHandle):
def __init__(self, process, port: int):
super(ProcessLauncherHandle, self).__init__(port)
self.process = process
def _inner_health(self) -> bool:
return self.process.poll() is None
@pytest.fixture(scope="module")
def data_volume():
tmpdir = TemporaryDirectory()
yield tmpdir.name
try:
# Cleanup the temporary directory using sudo as it contains root files created by the container
subprocess.run(shlex.split(f"sudo rm -rf {tmpdir.name}"), check=True)
except subprocess.CalledProcessError as e:
logger.error(f"Error cleaning up temporary directory: {str(e)}")
@pytest.fixture(scope="module")
def launcher(data_volume):
@contextlib.contextmanager
def docker_launcher(
model_id: str,
test_name: str,
):
logger.info(
f"Starting docker launcher for model {model_id} and test {test_name}"
)
# Get a random available port
def get_free_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
s.listen(1)
port = s.getsockname()[1]
return port
port = get_free_port()
logger.debug(f"Using port {port}")
client = docker.from_env()
container_name = f"tgi-gaudi-test-{test_name.replace('/', '-')}"
try:
container = client.containers.get(container_name)
logger.info(
f"Stopping existing container {container_name} for test {test_name}"
)
container.stop()
container.wait()
except NotFound:
pass
except Exception as e:
logger.error(f"Error handling existing container: {str(e)}")
model_name = next(
name for name, cfg in TEST_CONFIGS.items() if cfg["model_id"] == model_id
)
tgi_args = TEST_CONFIGS[model_name]["args"].copy()
env = BASE_ENV.copy()
# Add model_id to env
env["MODEL_ID"] = model_id
# Add env config that is definied in the fixture parameter
if "env_config" in TEST_CONFIGS[model_name]:
env.update(TEST_CONFIGS[model_name]["env_config"].copy())
volumes = [f"{DOCKER_VOLUME}:/data"]
logger.debug(f"Using volume {volumes}")
try:
logger.info(f"Creating container with name {container_name}")
# Log equivalent docker run command for debugging, this is not actually executed
container = client.containers.run(
DOCKER_IMAGE,
command=tgi_args,
name=container_name,
environment=env,
detach=True,
volumes=volumes,
ports={"80/tcp": port},
**HABANA_RUN_ARGS,
)
logger.info(f"Container {container_name} started successfully")
# Start log streaming in a background thread
log_thread = threading.Thread(
target=stream_container_logs,
args=(container, test_name),
daemon=True, # This ensures the thread will be killed when the main program exits
)
log_thread.start()
# Add a small delay to allow container to initialize
time.sleep(2)
# Check container status after creation
status = container.status
logger.debug(f"Initial container status: {status}")
if status not in ["running", "created"]:
logs = container.logs().decode("utf-8")
logger.error(f"Container failed to start properly. Logs:\n{logs}")
yield ContainerLauncherHandle(client, container.name, port)
except Exception as e:
logger.error(f"Error starting container: {str(e)}")
# Get full traceback for debugging
import traceback
logger.error(f"Full traceback:\n{traceback.format_exc()}")
raise
finally:
try:
container = client.containers.get(container_name)
logger.info(f"Stopping container {container_name}")
container.stop()
container.wait()
container_output = container.logs().decode("utf-8")
print(container_output, file=sys.stderr)
container.remove()
logger.info(f"Container {container_name} removed successfully")
except NotFound:
pass
except Exception as e:
logger.warning(f"Error cleaning up container: {str(e)}")
return docker_launcher
@pytest.fixture(scope="module")
def generate_load():
async def generate_load_inner(
client: AsyncClient, prompt: str, max_new_tokens: int, n: int
) -> List[Response]:
try:
futures = [
client.generate(
prompt,
max_new_tokens=max_new_tokens,
decoder_input_details=True,
)
for _ in range(n)
]
return await asyncio.gather(*futures)
except Exception as e:
logger.error(f"Error generating load: {str(e)}")
raise
return generate_load_inner

View File

@ -1,2 +0,0 @@
[pytest]
asyncio_mode = auto

View File

@ -1,7 +0,0 @@
pytest >= 8.3.5
pytest-asyncio >= 0.26.0
docker >= 7.1.0
Levenshtein >= 0.27.1
loguru >= 0.7.3
aiohttp >= 3.11.14
text-generation

View File

@ -1,276 +0,0 @@
from typing import Any, Dict
from text_generation import AsyncClient
import pytest
from Levenshtein import distance as levenshtein_distance
# The "args" config is not optimized for speed but only check that the inference is working for the different models architectures
TEST_CONFIGS = {
"meta-llama/Llama-3.1-8B-Instruct-shared": {
"model_id": "meta-llama/Llama-3.1-8B-Instruct",
"input": "What is Deep Learning?",
"expected_greedy_output": " A Beginners Guide\nDeep learning is a subset of machine learning that involves the use",
"expected_batch_output": " A Beginners Guide\nDeep learning is a subset of machine learning that involves the use",
"args": [
"--sharded",
"true",
"--num-shard",
"8",
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"8",
"--max-batch-prefill-tokens",
"2048",
],
},
"meta-llama/Llama-3.1-8B-Instruct": {
"model_id": "meta-llama/Llama-3.1-8B-Instruct",
"input": "What is Deep Learning?",
"expected_greedy_output": " A Beginners Guide\nDeep learning is a subset of machine learning that involves the use of artificial neural networks to analyze and interpret data. It is a type of",
"expected_batch_output": " A Beginners Guide\nDeep learning is a subset of machine learning that involves the use of artificial neural networks to analyze and interpret data. It is a type of",
"env_config": {},
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
"--max-batch-prefill-tokens",
"2048",
],
},
"meta-llama/Llama-2-7b-chat-hf": {
"model_id": "meta-llama/Llama-2-7b-chat-hf",
"input": "What is Deep Learning?",
"expected_greedy_output": "\n\nDeep learning (also known as deep structured learning) is part of a broader family of machine learning techniques based on artificial neural networks\u2014specific",
"expected_batch_output": "\n\nDeep learning (also known as deep structured learning) is part of a broader family of machine learning techniques based on artificial neural networks\u2014specific",
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
"--max-batch-prefill-tokens",
"2048",
],
},
"mistralai/Mistral-7B-Instruct-v0.3": {
"model_id": "mistralai/Mistral-7B-Instruct-v0.3",
"input": "What is Deep Learning?",
"expected_greedy_output": "\n\nDeep learning is a subset of machine learning in artificial intelligence (AI) that has networks capable of learning unsupervised from data that is unstructured",
"expected_batch_output": "\n\nDeep learning is a subset of machine learning in artificial intelligence (AI) that has networks capable of learning unsupervised from data that is unstructured",
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
"--max-batch-prefill-tokens",
"2048",
],
},
"bigcode/starcoder2-3b": {
"model_id": "bigcode/starcoder2-3b",
"input": "What is Deep Learning?",
"expected_greedy_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to perform tasks.\n\nNeural networks are a type of machine learning algorithm that",
"expected_batch_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to perform tasks.\n\nNeural networks are a type of machine learning algorithm that",
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
"--max-batch-prefill-tokens",
"2048",
],
},
"google/gemma-7b-it": {
"model_id": "google/gemma-7b-it",
"input": "What is Deep Learning?",
"expected_greedy_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to learn from large amounts of data. Neural networks are inspired by the structure and function of",
"expected_batch_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to learn from large amounts of data. Neural networks are inspired by the structure and function of",
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
"--max-batch-prefill-tokens",
"2048",
],
},
"Qwen/Qwen2-0.5B-Instruct": {
"model_id": "Qwen/Qwen2-0.5B-Instruct",
"input": "What is Deep Learning?",
"expected_greedy_output": " Deep Learning is a type of machine learning that is based on the principles of artificial neural networks. It is a type of machine learning that is used to train models",
"expected_batch_output": " Deep Learning is a type of machine learning that is based on the principles of artificial neural networks. It is a type of machine learning that is used to train models",
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
"--max-batch-prefill-tokens",
"2048",
],
},
"tiiuae/falcon-7b-instruct": {
"model_id": "tiiuae/falcon-7b-instruct",
"input": "What is Deep Learning?",
"expected_greedy_output": "\nDeep learning is a branch of machine learning that uses artificial neural networks to learn and make decisions. It is based on the concept of hierarchical learning, where a",
"expected_batch_output": "\nDeep learning is a branch of machine learning that uses artificial neural networks to learn and make decisions. It is based on the concept of hierarchical learning, where a",
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
],
},
"microsoft/phi-1_5": {
"model_id": "microsoft/phi-1_5",
"input": "What is Deep Learning?",
"expected_greedy_output": "\n\nDeep Learning is a subfield of Machine Learning that focuses on building neural networks with multiple layers of interconnected nodes. These networks are designed to learn from large",
"expected_batch_output": "\n\nDeep Learning is a subfield of Machine Learning that focuses on building neural networks with multiple layers of interconnected nodes. These networks are designed to learn from large",
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
],
},
"openai-community/gpt2": {
"model_id": "openai-community/gpt2",
"input": "What is Deep Learning?",
"expected_greedy_output": "\n\nDeep learning is a new field of research that has been around for a long time. It is a new field of research that has been around for a",
"expected_batch_output": "\n\nDeep learning is a new field of research that has been around for a long time. It is a new field of research that has been around for a",
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
],
},
"facebook/opt-125m": {
"model_id": "facebook/opt-125m",
"input": "What is Deep Learning?",
"expected_greedy_output": "\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout",
"expected_batch_output": "\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout",
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
],
},
"EleutherAI/gpt-j-6b": {
"model_id": "EleutherAI/gpt-j-6b",
"input": "What is Deep Learning?",
"expected_greedy_output": "\n\nDeep learning is a subset of machine learning that is based on the idea of neural networks. Neural networks are a type of artificial intelligence that is inspired by",
"expected_batch_output": "\n\nDeep learning is a subset of machine learning that is based on the idea of neural networks. Neural networks are a type of artificial intelligence that is inspired by",
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
],
},
}
print(f"Testing {len(TEST_CONFIGS)} models")
@pytest.fixture(scope="module", params=TEST_CONFIGS.keys())
def test_config(request) -> Dict[str, Any]:
"""Fixture that provides model configurations for testing."""
test_config = TEST_CONFIGS[request.param]
test_config["test_name"] = request.param
return test_config
@pytest.fixture(scope="module")
def model_id(test_config):
yield test_config["model_id"]
@pytest.fixture(scope="module")
def test_name(test_config):
yield test_config["test_name"]
@pytest.fixture(scope="module")
def expected_outputs(test_config):
return {
"greedy": test_config["expected_greedy_output"],
# "sampling": model_config["expected_sampling_output"],
"batch": test_config["expected_batch_output"],
}
@pytest.fixture(scope="module")
def input(test_config):
return test_config["input"]
@pytest.fixture(scope="module")
def tgi_service(launcher, model_id, test_name):
with launcher(model_id, test_name) as tgi_service:
yield tgi_service
@pytest.fixture(scope="module")
async def tgi_client(tgi_service) -> AsyncClient:
await tgi_service.health(1000)
return tgi_service.client
@pytest.mark.asyncio
async def test_model_single_request(
tgi_client: AsyncClient, expected_outputs: Dict[str, Any], input: str
):
# Bounded greedy decoding without input
response = await tgi_client.generate(
input,
max_new_tokens=32,
)
assert response.details.generated_tokens == 32
assert response.generated_text == expected_outputs["greedy"]
@pytest.mark.asyncio
async def test_model_multiple_requests(
tgi_client, generate_load, expected_outputs, input
):
num_requests = 4
responses = await generate_load(
tgi_client,
input,
max_new_tokens=32,
n=num_requests,
)
assert len(responses) == 4
expected = expected_outputs["batch"]
for r in responses:
assert r.details.generated_tokens == 32
# Compute the similarity with the expectation using the levenshtein distance
# We should not have more than two substitutions or additions
assert levenshtein_distance(r.generated_text, expected) < 3

File diff suppressed because it is too large Load Diff

View File

@ -9,30 +9,30 @@ text-generation-server = 'text_generation_server.cli:app'
[tool.poetry.dependencies]
python = ">=3.9,<3.13"
protobuf = "^5.0"
grpcio = "^1.71.1"
protobuf = "^3.20.3"
grpcio = "^1.51.1"
grpcio-status = "*"
grpcio-reflection = "*"
grpc-interceptor = "^0.15.0"
typer = "^0.15.0"
loguru = "^0.7.3"
opentelemetry-api = "^1.32.0"
opentelemetry-exporter-otlp = "^1.32.0"
opentelemetry-instrumentation-grpc = "^0.53b0"
hf-transfer = "^0.1.9"
sentencepiece = "^0.2.0"
peft = "^0.15"
optimum-habana = "1.17"
transformers = "^4.49"
numpy = "^1.26"
accelerate = "^0.33"
typer = "^0.7.0"
loguru = "^0.6.0"
opentelemetry-api = "^1.15.0"
opentelemetry-exporter-otlp = "^1.15.0"
opentelemetry-instrumentation-grpc = "^0.36b0"
hf-transfer = "^0.1.2"
sentencepiece = "^0.1.97"
peft = "^0.10"
optimum-habana = "1.15.0"
transformers = "4.45.2"
numpy = "1.26.4"
accelerate = "0.33.0"
outlines= { version = "^0.0.36", optional = true }
prometheus-client = "^0.21.1"
prometheus-client = "^0.20.0"
py-cpuinfo = "^9.0.0"
[tool.poetry.group.dev.dependencies]
grpcio-tools = "*"
pytest = "^8.3.5"
pytest = "^7.3.0"
[tool.pytest.ini_options]
markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]
@ -40,6 +40,3 @@ markers = ["private: marks tests as requiring an admin hf token (deselect with '
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.poetry.requires-plugins]
poetry-plugin-export = ">=1.8"

View File

@ -1,101 +1,89 @@
accelerate==0.33.0 ; python_version >= "3.9" and python_version < "3.13"
annotated-types==0.7.0 ; python_version >= "3.9" and python_version < "3.13"
attrs==25.3.0 ; python_version >= "3.9" and python_version < "3.13"
certifi==2025.1.31 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.4.1 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.8 ; python_version >= "3.9" and python_version < "3.13"
cloudpickle==3.1.1 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Windows" or python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
deprecated==1.2.18 ; python_version >= "3.9" and python_version < "3.13"
aiohappyeyeballs==2.4.3 ; python_version >= "3.9" and python_version < "3.13"
aiohttp==3.10.10 ; python_version >= "3.9" and python_version < "3.13"
aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "3.13"
async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.11"
attrs==24.2.0 ; python_version >= "3.9" and python_version < "3.13"
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
coloredlogs==15.0.1 ; python_version >= "3.9" and python_version < "3.13"
datasets==3.0.1 ; python_version >= "3.9" and python_version < "3.13"
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
diffusers==0.31.0 ; python_version >= "3.9" and python_version < "3.13"
diskcache==5.6.3 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.18.0 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2025.3.2 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.70.0 ; python_version >= "3.9" and python_version < "3.13"
dill==0.3.7 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13"
frozenlist==1.4.1 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13"
fsspec[http]==2024.6.1 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.71.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.71.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.72.0rc1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.9 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.30.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.48.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.48.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.67.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.26.1 ; python_version >= "3.9" and python_version < "3.13"
humanfriendly==10.0 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==8.6.1 ; python_version >= "3.9" and python_version < "3.13"
interegular==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
jinja2==3.1.6 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==8.5.0 ; python_version >= "3.9" and python_version < "3.13"
jinja2==3.1.4 ; python_version >= "3.9" and python_version < "3.13"
joblib==1.4.2 ; python_version >= "3.9" and python_version < "3.13"
jsonschema-specifications==2024.10.1 ; python_version >= "3.9" and python_version < "3.13"
jsonschema==4.23.0 ; python_version >= "3.9" and python_version < "3.13"
lark==1.2.2 ; python_version >= "3.9" and python_version < "3.13"
llvmlite==0.43.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.7.3 ; python_version >= "3.9" and python_version < "3.13"
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
markupsafe==3.0.2 ; python_version >= "3.9" and python_version < "3.13"
mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13"
mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
nest-asyncio==1.6.0 ; python_version >= "3.9" and python_version < "3.13"
multidict==6.1.0 ; python_version >= "3.9" and python_version < "3.13"
multiprocess==0.70.15 ; python_version >= "3.9" and python_version < "3.13"
networkx==3.2.1 ; python_version >= "3.9" and python_version < "3.13"
numba==0.60.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
nvidia-cublas-cu12==12.4.5.8 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cuda-cupti-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cuda-nvrtc-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cuda-runtime-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cudnn-cu12==9.1.0.70 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cufft-cu12==11.2.1.3 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-curand-cu12==10.3.5.147 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cusolver-cu12==11.6.1.9 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cusparse-cu12==12.3.1.170 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cusparselt-cu12==0.6.2 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-nccl-cu12==2.21.5 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-nvjitlink-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-nvtx-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
opentelemetry-api==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-common==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.53b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.53b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.53b0 ; python_version >= "3.9" and python_version < "3.13"
optimum-habana==1.17.0 ; python_version >= "3.9" and python_version < "3.13"
optimum==1.24.0 ; python_version >= "3.9" and python_version < "3.13"
outlines==0.0.36 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.2 ; python_version >= "3.9" and python_version < "3.13"
peft==0.15.1 ; python_version >= "3.9" and python_version < "3.13"
pillow==11.2.1 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.21.1 ; python_version >= "3.9" and python_version < "3.13"
protobuf==5.29.4 ; python_version >= "3.9" and python_version < "3.13"
psutil==7.0.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
optimum-habana==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
optimum==1.23.2 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
pandas==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
peft==0.10.0 ; python_version >= "3.9" and python_version < "3.13"
pillow==11.0.0 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
propcache==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==3.20.3 ; python_version >= "3.9" and python_version < "3.13"
psutil==6.1.0 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pydantic-core==2.33.1 ; python_version >= "3.9" and python_version < "3.13"
pydantic==2.11.3 ; python_version >= "3.9" and python_version < "3.13"
pygments==2.19.1 ; python_version >= "3.9" and python_version < "3.13"
pyarrow==17.0.0 ; python_version >= "3.9" and python_version < "3.13"
pyreadline3==3.5.4 ; sys_platform == "win32" and python_version >= "3.9" and python_version < "3.13"
python-dateutil==2.9.0.post0 ; python_version >= "3.9" and python_version < "3.13"
pytz==2024.2 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
referencing==0.36.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.11.6 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==14.0.0 ; python_version >= "3.9" and python_version < "3.13"
rpds-py==0.24.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.5.3 ; python_version >= "3.9" and python_version < "3.13"
scikit-learn==1.6.1 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scikit-learn==1.5.2 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentence-transformers==3.3.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==78.1.0 ; python_version >= "3.12" and python_version < "3.13"
shellingham==1.5.4 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
threadpoolctl==3.6.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.21.1 ; python_version >= "3.9" and python_version < "3.13"
torch==2.6.0 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.67.1 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.49.0 ; python_version >= "3.9" and python_version < "3.13"
triton==3.2.0 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
typer==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.13.2 ; python_version >= "3.9" and python_version < "3.13"
typing-inspection==0.4.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.4.0 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.2.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.17.2 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.21.0 ; python_version >= "3.9" and python_version < "3.13"
sentence-transformers[train]==3.2.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13"
six==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.12.1 ; python_version >= "3.9" and python_version < "3.13"
threadpoolctl==3.5.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
transformers[sentencepiece]==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
triton==3.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version < "3.13" and python_version >= "3.9"
typer==0.7.0 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
tzdata==2024.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
xxhash==3.5.0 ; python_version >= "3.9" and python_version < "3.13"
yarl==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -0,0 +1,23 @@
import pytest
import os
from text_generation_server.pb import generate_pb2
os.environ["USE_PREFIX_CACHING"] = "1"
os.environ["ATTENTION"] = "flashinfer"
@pytest.fixture
def default_pb_parameters():
return generate_pb2.NextTokenChooserParameters(
temperature=1.0,
repetition_penalty=1.0,
top_k=0,
top_p=1.0,
typical_p=1.0,
do_sample=False,
)
@pytest.fixture
def default_pb_stop_parameters():
return generate_pb2.StoppingCriteriaParameters(stop_sequences=[], max_new_tokens=10)

View File

@ -0,0 +1,363 @@
import pytest
import torch
from copy import copy
from transformers import AutoTokenizer
from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.utils import weight_hub_files, download_weights
from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded
from text_generation_server.models.custom_modeling.bloom_modeling import (
BloomForCausalLM,
)
@pytest.fixture(scope="session")
def default_bloom():
model_id = "bigscience/bloom-560m"
revision = "main"
filenames = weight_hub_files(model_id, revision, ".safetensors")
download_weights(filenames, model_id, revision)
return BLOOMSharded(
model_id,
model_class=BloomForCausalLM,
)
@pytest.fixture(scope="session")
def bloom_560m_tokenizer():
return AutoTokenizer.from_pretrained("bigscience/bloom-560m", padding_side="left")
@pytest.fixture
def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
)
@pytest.fixture
def default_pb_batch(default_pb_request):
return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)
@pytest.fixture
def default_bloom_batch(default_pb_batch, bloom_560m_tokenizer):
return BloomCausalLMBatch.from_pb(
default_pb_batch, bloom_560m_tokenizer, torch.float32, torch.device("cpu")
)
@pytest.fixture
def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer):
req_0 = copy(default_pb_request)
req_0.id = 1
req_1 = default_pb_request
req_1.id = 2
req_1.stopping_parameters.max_new_tokens = 5
batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
return BloomCausalLMBatch.from_pb(
batch_pb, bloom_560m_tokenizer, torch.float32, torch.device("cpu")
)
def test_batch_from_pb(default_pb_batch, default_bloom_batch):
batch = default_bloom_batch
assert batch.batch_id == default_pb_batch.id
assert batch.requests == default_pb_batch.requests
assert len(batch.input_ids) == default_pb_batch.size
assert batch.input_ids[0][-1] == 10264
assert torch.all(batch.input_ids[0][:-1] == 3)
assert batch.attention_mask[0][0] == 1
assert torch.all(batch.attention_mask[0][1:] == 0)
assert batch.past_key_values is None
assert all(
[
torch.equal(input_ids, all_input_ids[:, 0])
for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)
]
)
assert batch.input_lengths == [1]
assert len(batch) == default_pb_batch.size
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)
assert batch.max_input_length == batch.input_lengths[0]
def test_batch_concatenate_no_prefill(default_bloom_batch):
with pytest.raises(ValueError):
BloomCausalLMBatch.concatenate([default_bloom_batch, default_bloom_batch])
def test_causal_lm_batch_type(default_bloom):
assert default_bloom.batch_type == BloomCausalLMBatch
def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
sequence_length = len(default_bloom_batch.all_input_ids[0])
generations, next_batch, _ = default_bloom.generate_token(default_bloom_batch)
assert len(generations) == len(default_bloom_batch)
assert isinstance(next_batch, CausalLMBatch)
assert not next_batch.keys_head_dim_last
assert len(next_batch.all_input_ids) == len(next_batch)
assert len(next_batch.all_input_ids[0]) == sequence_length + 1
assert len(next_batch.attention_mask[0]) == 11
assert torch.all(next_batch.all_input_ids[0][-2:] == 10264)
assert torch.all(next_batch.all_input_ids[0][:-2] == 3)
assert torch.all(next_batch.attention_mask[0][:2] == 1)
assert torch.all(next_batch.attention_mask[0][2:] == 0)
assert next_batch.input_ids.shape == (len(next_batch), 1)
assert next_batch.input_ids[0, 0] == 10264
assert next_batch.input_lengths == [2]
assert next_batch.max_input_length == next_batch.input_lengths[0]
assert next_batch.past_key_values is not None
assert all(
[p[0].shape == (16, 64, sequence_length) for p in next_batch.past_key_values]
)
assert all(
[p[1].shape == (16, sequence_length, 64) for p in next_batch.past_key_values]
)
assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
assert all(
[
token_id.item() == 10264
for generation in generations
for token_id in generation.tokens.token_ids
]
)
assert all(
[
token_text == "Test"
for generation in generations
for token_text in generation.tokens.texts
]
)
assert generations[0].request_id == 0
def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch):
next_batch = default_bloom_batch
for _ in range(default_bloom_batch.stopping_criterias[0].max_new_tokens - 1):
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert len(generations) == len(default_bloom_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
assert (
generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
)
assert generations[0].request_id == default_bloom_batch.requests[0].id
assert (
generations[0].generated_text.generated_tokens
== default_bloom_batch.stopping_criterias[0].max_new_tokens
)
def test_causal_lm_generate_token_completion_multi(
default_bloom, default_multi_requests_bloom_batch
):
next_batch = default_multi_requests_bloom_batch
for i in range(
default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1
):
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert len(generations) == len(default_multi_requests_bloom_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 2
assert generations[1].generated_text.text == "TestTestTestTestTest"
assert (
generations[1].request_id == default_multi_requests_bloom_batch.requests[1].id
)
assert (
generations[1].generated_text.generated_tokens
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
)
# Copy stopping_criterias before filtering
stopping_criterias = default_multi_requests_bloom_batch.stopping_criterias.copy()
next_batch = next_batch.filter([next_batch.requests[0].id])
for _ in range(
stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
):
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
assert (
generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
)
assert (
generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
)
assert (
generations[0].generated_text.generated_tokens
== default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
)
def test_batch_concatenate(
default_bloom, default_bloom_batch, default_multi_requests_bloom_batch
):
next_batch_0 = default_bloom_batch
_, next_batch_0, _ = default_bloom.generate_token(next_batch_0)
_, next_batch_0, _ = default_bloom.generate_token(next_batch_0)
next_batch_1 = default_multi_requests_bloom_batch
_, next_batch_1, _ = default_bloom.generate_token(next_batch_1)
# Clone past_key_values before concatenating to compare after,
# because they are removed from the concatenated batches
next_batch_0_past_key_values = [
(k.clone(), v.clone()) for (k, v) in next_batch_0.past_key_values
]
next_batch_1_past_key_values = [
(k.clone(), v.clone()) for (k, v) in next_batch_1.past_key_values
]
next_batch = BloomCausalLMBatch.concatenate([next_batch_0, next_batch_1])
assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0])
assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0])
assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1])
assert torch.all(
next_batch.attention_mask[0, : -next_batch.padding_right_offset] == 1
)
assert torch.all(
next_batch.attention_mask[1:, 1 : -next_batch.padding_right_offset] == 1
)
assert torch.all(next_batch.attention_mask[1:, 3:] == 0)
assert next_batch.batch_id == 0
assert torch.all(next_batch.input_ids == 10264)
assert next_batch.input_lengths == [3, 2, 2]
assert next_batch.max_input_length == 3
assert next_batch.requests[0] == next_batch_0.requests[0]
assert next_batch.requests[1:] == list(next_batch_1.requests)
assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers
assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0]
assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias
assert next_batch.past_key_values is not None
assert all([p[0].shape == (3, 16, 64, 2) for p in next_batch.past_key_values])
assert all([p[1].shape == (3, 16, 2, 64) for p in next_batch.past_key_values])
for i, past in enumerate(next_batch.past_key_values):
assert torch.equal(next_batch_0_past_key_values[i][0][:, :, -2:], past[0][0])
assert torch.equal(
next_batch_1_past_key_values[i][0][:, :, -1:],
past[0][1:, :, :, -1].reshape(-1, 64, 1),
)
assert torch.equal(next_batch_0_past_key_values[i][1][:, -2:, :], past[1][0])
assert torch.equal(
next_batch_1_past_key_values[i][1][:, -1:, :],
past[1][1:, :, -1, :].reshape(-1, 1, 64),
)
for _ in range(
default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2
):
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 3
assert generations[2].generated_text.text == "TestTestTestTestTest"
assert (
generations[2].request_id == default_multi_requests_bloom_batch.requests[1].id
)
assert (
generations[2].generated_text.generated_tokens
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
)
next_batch = next_batch.filter(
[next_batch.requests[0].id, next_batch.requests[1].id]
)
for _ in range(
default_bloom_batch.stopping_criterias[0].max_new_tokens
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
- 2
):
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 2
assert (
generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
)
assert generations[0].request_id == default_bloom_batch.requests[0].id
assert (
generations[0].generated_text.generated_tokens
== default_bloom_batch.stopping_criterias[0].max_new_tokens
)
next_batch = next_batch.filter([next_batch.requests[1].id])
for _ in range(
default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
- default_bloom_batch.stopping_criterias[0].max_new_tokens
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
- 4
):
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
assert (
generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
)
assert (
generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
)
assert (
generations[0].generated_text.generated_tokens
== default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
)

View File

@ -0,0 +1,354 @@
import pytest
import torch
from copy import copy
from transformers import AutoTokenizer
from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch
@pytest.fixture(scope="session")
def default_causal_lm():
return CausalLM.fallback("gpt2")
@pytest.fixture(scope="session")
def gpt2_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
tokenizer.pad_token_id = 50256
return tokenizer
@pytest.fixture
def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
)
@pytest.fixture
def default_pb_batch(default_pb_request):
return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)
@pytest.fixture
def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer):
return CausalLMBatch.from_pb(
default_pb_batch, gpt2_tokenizer, torch.float32, torch.device("cpu")
)
@pytest.fixture
def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
req_0 = copy(default_pb_request)
req_0.id = 1
req_1 = default_pb_request
req_1.id = 2
req_1.stopping_parameters.max_new_tokens = 5
batch_pb = generate_pb2.Batch(id=1, requests=[req_0, req_1], size=2)
return CausalLMBatch.from_pb(
batch_pb, gpt2_tokenizer, torch.float32, torch.device("cpu")
)
def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
batch = default_causal_lm_batch
assert batch.batch_id == default_pb_batch.id
assert batch.requests == default_pb_batch.requests
assert len(batch.input_ids) == default_pb_batch.size
assert batch.input_ids[0][-1] == 14402
assert torch.all(batch.input_ids[0][:-1] == 50256)
assert batch.attention_mask[0, 0] == 1
assert torch.all(batch.attention_mask[0, 1:] == 0)
assert batch.past_key_values is None
assert all(
[
torch.equal(input_ids, all_input_ids[:, 0])
for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)
]
)
assert batch.input_lengths == [1]
assert len(batch) == default_pb_batch.size
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)
assert batch.max_input_length == batch.input_lengths[0]
def test_batch_concatenate_no_prefill(default_causal_lm_batch):
with pytest.raises(ValueError):
CausalLMBatch.concatenate([default_causal_lm_batch, default_causal_lm_batch])
def test_causal_lm_batch_type(default_causal_lm):
assert default_causal_lm.batch_type == CausalLMBatch
def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
sequence_length = len(default_causal_lm_batch.all_input_ids[0])
generations, next_batch, _ = default_causal_lm.generate_token(
default_causal_lm_batch
)
assert len(generations) == len(next_batch)
assert isinstance(next_batch, CausalLMBatch)
assert len(next_batch.all_input_ids) == len(next_batch)
assert len(next_batch.all_input_ids[0]) == sequence_length + 1
assert len(next_batch.attention_mask[0]) == 11
assert next_batch.all_input_ids[0][-1] == 13
assert next_batch.all_input_ids[0][-2] == 14402
assert torch.all(next_batch.all_input_ids[0][:-2] == 50256)
assert torch.all(next_batch.attention_mask[0][0:2] == 1)
assert torch.all(next_batch.attention_mask[0][2:] == 0)
assert next_batch.input_ids.shape == (len(next_batch), 1)
assert next_batch.input_ids[0, 0] == 13
assert next_batch.input_lengths == [2]
assert next_batch.max_input_length == next_batch.input_lengths[0]
assert next_batch.past_key_values is not None
assert all(
[p[0].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values]
)
assert all(
[p[1].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values]
)
assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
assert all(
[
token_id.item() == 13
for generation in generations
for token_id in generation.tokens.token_ids
]
)
assert all(
[
token_text == "."
for generation in generations
for token_text in generation.tokens.texts
]
)
assert generations[0].request_id == 0
def test_causal_lm_generate_token_completion(
default_causal_lm, default_causal_lm_batch
):
next_batch = default_causal_lm_batch
for _ in range(default_causal_lm_batch.stopping_criterias[0].max_new_tokens - 1):
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
assert generations[0].generated_text.text == ".java:784) at net.minecraft."
assert generations[0].request_id == default_causal_lm_batch.requests[0].id
assert (
generations[0].generated_text.generated_tokens
== default_causal_lm_batch.stopping_criterias[0].max_new_tokens
)
def test_causal_lm_generate_token_completion_multi(
default_causal_lm, default_multi_requests_causal_lm_batch
):
next_batch = default_multi_requests_causal_lm_batch
for i in range(
default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1
):
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 2
assert generations[1].generated_text.text == ".java:784)"
assert (
generations[1].request_id
== default_multi_requests_causal_lm_batch.requests[1].id
)
assert (
generations[1].generated_text.generated_tokens
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
)
# Copy stopping_criterias before filtering
stopping_criterias = (
default_multi_requests_causal_lm_batch.stopping_criterias.copy()
)
next_batch = next_batch.filter([next_batch.requests[0].id])
for _ in range(
stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
):
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
assert generations[0].generated_text.text == ".java:784) at net.minecraft."
assert (
generations[0].request_id
== default_multi_requests_causal_lm_batch.requests[0].id
)
assert (
generations[0].generated_text.generated_tokens
== default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
)
def test_batch_concatenate(
default_causal_lm, default_causal_lm_batch, default_multi_requests_causal_lm_batch
):
next_batch_0 = default_causal_lm_batch
_, next_batch_0, _ = default_causal_lm.generate_token(next_batch_0)
_, next_batch_0, _ = default_causal_lm.generate_token(next_batch_0)
next_batch_1 = default_multi_requests_causal_lm_batch
_, next_batch_1, _ = default_causal_lm.generate_token(next_batch_1)
# Clone past_key_values before concatenating to compare after,
# because they are removed from the concatenated batches
next_batch_0_past_key_values = [
(k.clone(), v.clone()) for (k, v) in next_batch_0.past_key_values
]
next_batch_1_past_key_values = [
(k.clone(), v.clone()) for (k, v) in next_batch_1.past_key_values
]
next_batch = CausalLMBatch.concatenate([next_batch_0, next_batch_1])
assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0])
assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0])
assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1])
assert torch.all(
next_batch.attention_mask[0, : -next_batch.padding_right_offset] == 1
)
assert torch.all(
next_batch.attention_mask[1:, 1 : -next_batch.padding_right_offset] == 1
)
assert torch.all(next_batch.attention_mask[1:, 3:] == 0)
assert next_batch.batch_id == 0
assert next_batch.input_ids[0, 0] == 12355
assert torch.all(next_batch.input_ids[1:] == 13)
assert next_batch.input_lengths == [3, 2, 2]
assert next_batch.max_input_length == 3
assert next_batch.requests[0] == next_batch_0.requests[0]
assert next_batch.requests[1:] == list(next_batch_1.requests)
assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers
assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0]
assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias
assert next_batch.past_key_values is not None
assert all([p[0].shape == (3, 12, 2, 64) for p in next_batch.past_key_values])
assert all([p[1].shape == (3, 12, 2, 64) for p in next_batch.past_key_values])
for i, past in enumerate(next_batch.past_key_values):
assert torch.equal(next_batch_0_past_key_values[i][0][0, :, -2:], past[0][0])
assert torch.equal(
next_batch_1_past_key_values[i][0][:, :, -1:], past[0][1:, :, -1:, :]
)
assert torch.equal(next_batch_0_past_key_values[i][1][0, :, -2:], past[1][0])
assert torch.equal(
next_batch_1_past_key_values[i][1][:, :, -1:], past[1][1:, :, -1:, :]
)
for _ in range(
default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2
):
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 3
assert generations[2].generated_text.text == ".java:784)"
assert (
generations[2].request_id
== default_multi_requests_causal_lm_batch.requests[1].id
)
assert (
generations[2].generated_text.generated_tokens
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
)
next_batch = next_batch.filter(
[next_batch.requests[0].id, next_batch.requests[1].id]
)
for _ in range(
default_causal_lm_batch.stopping_criterias[0].max_new_tokens
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
- 2
):
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 2
assert generations[0].generated_text.text == ".java:784) at net.minecraft."
assert generations[0].request_id == default_causal_lm_batch.requests[0].id
assert (
generations[0].generated_text.generated_tokens
== default_causal_lm_batch.stopping_criterias[0].max_new_tokens
)
next_batch = next_batch.filter([next_batch.requests[1].id])
for _ in range(
default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
- default_causal_lm_batch.stopping_criterias[0].max_new_tokens
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
- 4
):
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
assert generations[0].generated_text.text == ".java:784) at net.minecraft."
assert (
generations[0].request_id
== default_multi_requests_causal_lm_batch.requests[0].id
)
assert (
generations[0].generated_text.generated_tokens
== default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
)

View File

@ -0,0 +1,83 @@
import pytest
import torch
from transformers import AutoTokenizer
from text_generation_server.models import Model
def get_test_model():
class TestModel(Model):
def batch_type(self):
raise NotImplementedError
def generate_token(self, batch):
raise NotImplementedError
tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b")
model = TestModel(
"test_model_id",
torch.nn.Linear(1, 1),
tokenizer,
False,
torch.float32,
torch.device("cpu"),
)
return model
@pytest.mark.private
def test_decode_streaming_english_spaces():
model = get_test_model()
truth = "Hello here, this is a simple test"
all_input_ids = [15043, 1244, 29892, 445, 338, 263, 2560, 1243]
assert (
all_input_ids == model.tokenizer(truth, add_special_tokens=False)["input_ids"]
)
decoded_text = ""
offset = 0
token_offset = 0
for i in range(len(all_input_ids)):
text, offset, token_offset = model.decode_token(
all_input_ids[: i + 1], offset, token_offset
)
decoded_text += text
assert decoded_text == truth
@pytest.mark.private
def test_decode_streaming_chinese_utf8():
model = get_test_model()
truth = "我很感谢你的热情"
all_input_ids = [
30672,
232,
193,
139,
233,
135,
162,
235,
179,
165,
30919,
30210,
234,
134,
176,
30993,
]
decoded_text = ""
offset = 0
token_offset = 0
for i in range(len(all_input_ids)):
text, offset, token_offset = model.decode_token(
all_input_ids[: i + 1], offset, token_offset
)
decoded_text += text
assert decoded_text == truth

View File

@ -0,0 +1,108 @@
import pytest
from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch, CausalLM
@pytest.fixture(scope="session")
def default_santacoder():
return CausalLM.fallback(model_id="bigcode/santacoder")
@pytest.fixture
def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="def",
input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="def")]),
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
)
@pytest.fixture
def default_pb_batch(default_pb_request):
return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)
@pytest.fixture
def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
input_chunks=generate_pb2.Input(
chunks=[
generate_pb2.InputChunk(
text="<fim-prefix>def<fim-suffix>world<fim-middle>"
)
]
),
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
)
@pytest.fixture
def default_fim_pb_batch(default_fim_pb_request):
return generate_pb2.Batch(id=0, requests=[default_fim_pb_request], size=1)
@pytest.mark.skip
def test_santacoder_generate_token_completion(default_santacoder, default_pb_batch):
batch = CausalLMBatch.from_pb(
default_pb_batch,
default_santacoder.tokenizer,
default_santacoder.dtype,
default_santacoder.device,
)
next_batch = batch
for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
generations, next_batch, _ = default_santacoder.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_santacoder.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
assert generations[0].generated_text.text == " test_get_all_users_with_"
assert generations[0].request_id == batch.requests[0].id
assert (
generations[0].generated_text.generated_tokens
== batch.stopping_criterias[0].max_new_tokens
)
@pytest.mark.skip
def test_fim_santacoder_generate_token_completion(
default_santacoder, default_fim_pb_batch
):
batch = CausalLMBatch.from_pb(
default_fim_pb_batch,
default_santacoder.tokenizer,
default_santacoder.dtype,
default_santacoder.device,
)
next_batch = batch
for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
generations, next_batch, _ = default_santacoder.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_santacoder.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
assert (
generations[0].generated_text.text
== """ineProperty(exports, "__esModule", { value"""
)
assert generations[0].request_id == batch.requests[0].id
assert (
generations[0].generated_text.generated_tokens
== batch.stopping_criterias[0].max_new_tokens
)

View File

@ -0,0 +1,365 @@
import pytest
import torch
from copy import copy
from transformers import AutoTokenizer
from text_generation_server.pb import generate_pb2
from text_generation_server.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
@pytest.fixture(scope="session")
def mt0_small_tokenizer():
tokenizer = AutoTokenizer.from_pretrained(
"bigscience/mt0-small", padding_side="left"
)
tokenizer.bos_token_id = 0
return tokenizer
@pytest.fixture(scope="session")
def default_seq2seq_lm():
return Seq2SeqLM.fallback("bigscience/mt0-small")
@pytest.fixture
def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
)
@pytest.fixture
def default_pb_batch(default_pb_request):
return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)
@pytest.fixture
def default_seq2seq_lm_batch(default_pb_batch, mt0_small_tokenizer):
return Seq2SeqLMBatch.from_pb(
default_pb_batch, mt0_small_tokenizer, torch.float32, torch.device("cpu")
)
@pytest.fixture
def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokenizer):
req_0 = copy(default_pb_request)
req_0.id = 1
req_1 = default_pb_request
req_1.id = 2
req_1.stopping_parameters.max_new_tokens = 5
batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
return Seq2SeqLMBatch.from_pb(
batch_pb, mt0_small_tokenizer, torch.float32, torch.device("cpu")
)
def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
batch = default_seq2seq_lm_batch
sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
assert batch.batch_id == default_pb_batch.id
assert batch.requests == default_pb_batch.requests
assert batch.input_ids.shape == (default_pb_batch.size, sequence_length)
assert batch.input_ids[0][-2] == 4268
assert batch.input_ids[0][-1] == 1
assert torch.all(batch.input_ids[0][:-2] == 0)
assert torch.all(batch.attention_mask[0][-2:] == 1)
assert torch.all(batch.attention_mask[0][:-2] == 0)
assert len(batch.decoder_input_ids) == default_pb_batch.size
assert batch.decoder_attention_mask is None
assert batch.encoder_last_hidden_state is None
assert batch.past_key_values is None
assert batch.input_lengths == [2]
assert batch.decoder_input_lengths == [1]
assert len(batch) == default_pb_batch.size
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)
assert batch.max_input_length == batch.input_lengths[0]
assert batch.max_decoder_input_length == batch.decoder_input_lengths[0]
def test_batch_concatenate_no_prefill(default_seq2seq_lm_batch):
with pytest.raises(ValueError):
Seq2SeqLMBatch.concatenate([default_seq2seq_lm_batch, default_seq2seq_lm_batch])
def test_seq2seq_lm_batch_type(default_seq2seq_lm):
assert default_seq2seq_lm.batch_type == Seq2SeqLMBatch
def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
generations, next_batch, _ = default_seq2seq_lm.generate_token(
default_seq2seq_lm_batch
)
assert len(generations) == len(next_batch)
assert isinstance(next_batch, Seq2SeqLMBatch)
assert next_batch.input_ids is None
assert torch.equal(
next_batch.attention_mask, default_seq2seq_lm_batch.attention_mask
)
assert next_batch.input_lengths == default_seq2seq_lm_batch.input_lengths
assert next_batch.max_input_length == default_seq2seq_lm_batch.max_input_length
assert (
next_batch.next_token_choosers == default_seq2seq_lm_batch.next_token_choosers
)
assert next_batch.stopping_criterias == default_seq2seq_lm_batch.stopping_criterias
assert len(next_batch.decoder_input_ids) == len(next_batch)
assert next_batch.all_decoder_input_ids[0][0] == 0
assert next_batch.all_decoder_input_ids[0][1] == 259
assert next_batch.decoder_attention_mask is None
assert next_batch.encoder_last_hidden_state.shape == (1, sequence_length, 512)
assert next_batch.decoder_input_lengths == [2]
assert next_batch.max_decoder_input_length == 2
assert next_batch.past_key_values is not None
assert all(
[p[0].shape == (len(next_batch), 6, 1, 64) for p in next_batch.past_key_values]
)
assert all(
[p[1].shape == (len(next_batch), 6, 1, 64) for p in next_batch.past_key_values]
)
assert all(
[
p[2].shape == (len(next_batch), 6, sequence_length, 64)
for p in next_batch.past_key_values
]
)
assert all(
[
p[3].shape == (len(next_batch), 6, sequence_length, 64)
for p in next_batch.past_key_values
]
)
assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
assert all(
[
token_id.item() == 259
for generation in generations
for token_id in generation.tokens.token_ids
]
)
assert all(
[
token_text == " "
for generation in generations
for token_text in generation.tokens.texts
]
)
assert generations[0].request_id == 0
def test_seq2seq_lm_generate_token_completion(
default_seq2seq_lm, default_seq2seq_lm_batch
):
next_batch = default_seq2seq_lm_batch
for _ in range(6):
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
assert generations[0].generated_text.text == "a few weeks"
assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
assert generations[0].generated_text.generated_tokens == 7
def test_seq2seq_lm_generate_token_completion_multi(
default_seq2seq_lm, default_multi_requests_seq2seq_lm_batch
):
next_batch = default_multi_requests_seq2seq_lm_batch
for i in range(4):
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 2
assert generations[1].generated_text.text == "a few "
assert (
generations[1].request_id
== default_multi_requests_seq2seq_lm_batch.requests[1].id
)
assert generations[1].generated_text.generated_tokens == 5
next_batch = next_batch.filter([next_batch.requests[0].id])
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
assert generations[0].generated_text.text == "a few weeks"
assert (
generations[0].request_id
== default_multi_requests_seq2seq_lm_batch.requests[0].id
)
assert generations[0].generated_text.generated_tokens == 7
def test_batch_concatenate(
default_seq2seq_lm,
default_seq2seq_lm_batch,
default_multi_requests_seq2seq_lm_batch,
):
next_batch_0 = default_seq2seq_lm_batch
_, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0)
_, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0)
next_batch_1 = default_multi_requests_seq2seq_lm_batch
_, next_batch_1, _ = default_seq2seq_lm.generate_token(next_batch_1)
# Copy hidden state because it is removed from the concatenated branches
next_batch_0_encoder_last_hidden_state = next_batch_0.encoder_last_hidden_state
next_batch_1_encoder_last_hidden_state = next_batch_1.encoder_last_hidden_state
# Clone past_key_values before concatenating to compare after,
# because they are removed from the concatenated batches
next_batch_0_past_key_values = [
[t.clone() for t in layer] for layer in next_batch_0.past_key_values
]
next_batch_1_past_key_values = [
[t.clone() for t in layer] for layer in next_batch_1.past_key_values
]
next_batch = Seq2SeqLMBatch.concatenate([next_batch_0, next_batch_1])
assert next_batch.batch_id == 0
assert torch.equal(
next_batch.decoder_input_ids[0], next_batch_0.decoder_input_ids[0]
)
assert next_batch.all_decoder_input_ids[1][0] == 0
assert next_batch.all_decoder_input_ids[2][0] == 0
assert torch.equal(
next_batch.decoder_input_ids[1:, -2:], next_batch_1.decoder_input_ids
)
assert torch.all(next_batch.decoder_attention_mask[0, :3] == 1)
assert torch.all(next_batch.decoder_attention_mask[0, 3:] == 0)
assert torch.all(next_batch.decoder_attention_mask[1:, 0] == 0)
assert torch.all(next_batch.decoder_attention_mask[1:, 1:3] == 1)
assert torch.equal(
next_batch.encoder_last_hidden_state[0],
next_batch_0_encoder_last_hidden_state[0, -2:],
)
assert torch.equal(
next_batch.encoder_last_hidden_state[1:],
next_batch_1_encoder_last_hidden_state[:, -2:],
)
assert next_batch.input_lengths == [2, 2, 2]
assert next_batch.decoder_input_lengths == [3, 2, 2]
assert next_batch.max_input_length == 2
assert next_batch.max_decoder_input_length == 3
assert next_batch.requests[0] == next_batch_0.requests[0]
assert next_batch.requests[1:] == list(next_batch_1.requests)
assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers
assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0]
assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias
assert next_batch.past_key_values is not None
assert all(
[p[0].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
)
assert all(
[p[1].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
)
assert all(
[p[2].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
)
assert all(
[p[3].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
)
for i, past in enumerate(next_batch.past_key_values):
assert torch.equal(next_batch_0_past_key_values[i][0][0, :, -2:, :], past[0][0])
assert torch.equal(
next_batch_1_past_key_values[i][0][:, :, -1:, :], past[0][1:, :, -1:, :]
)
assert torch.equal(next_batch_0_past_key_values[i][1][0, :, -2:, :], past[1][0])
assert torch.equal(
next_batch_1_past_key_values[i][1][:, :, -1:, :], past[1][1:, :, -1:, :]
)
assert torch.equal(next_batch_0_past_key_values[i][2][0, :, -2:, :], past[2][0])
assert torch.equal(
next_batch_1_past_key_values[i][2][:, :, -2:, :], past[2][1:]
)
assert torch.equal(next_batch_0_past_key_values[i][3][0, :, -2:, :], past[3][0])
assert torch.equal(
next_batch_1_past_key_values[i][3][:, :, -2:, :], past[3][1:]
)
for _ in range(3):
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 3
assert generations[2].generated_text.text == "a few "
assert (
generations[2].request_id
== default_multi_requests_seq2seq_lm_batch.requests[1].id
)
assert generations[2].generated_text.generated_tokens == 5
next_batch = next_batch.filter(
[next_batch.requests[0].id, next_batch.requests[1].id]
)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 2
assert generations[0].generated_text.text == "a few weeks"
assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
assert generations[0].generated_text.generated_tokens == 7
next_batch = next_batch.filter([next_batch.requests[1].id])
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
assert generations[0].generated_text.text == "a few weeks"
assert (
generations[0].request_id
== default_multi_requests_seq2seq_lm_batch.requests[0].id
)
assert generations[0].generated_text.generated_tokens == 7

View File

@ -0,0 +1,247 @@
import pytest
from unittest.mock import Mock
from text_generation_server.utils.adapter import (
get_attn_weights,
get_mlp_weights,
parse_lora_adapters,
AdapterInfo,
)
def test_parse_lora_adapters_empty():
assert parse_lora_adapters(None) == []
assert parse_lora_adapters("") == []
def test_parse_lora_adapters_single():
result = parse_lora_adapters("adapter1")
assert result == [AdapterInfo(id="adapter1", path=None, revision=None)]
def test_parse_lora_adapters_with_path():
result = parse_lora_adapters("adapter1=path/to/adapter1")
assert result == [
AdapterInfo(id="adapter1", path="path/to/adapter1", revision=None)
]
def test_parse_lora_adapters_with_path_and_revision():
result = parse_lora_adapters("adapter1=path/to/adapter1@main")
assert result == [
AdapterInfo(id="adapter1", path="path/to/adapter1", revision="main")
]
def test_parse_lora_adapters_multiple():
result = parse_lora_adapters(
"adapter1,adapter2=path/to/adapter2,adapter3=path/to/adapter3@dev"
)
assert result == [
AdapterInfo(id="adapter1", path=None, revision=None),
AdapterInfo(id="adapter2", path="path/to/adapter2", revision=None),
AdapterInfo(id="adapter3", path="path/to/adapter3", revision="dev"),
]
def test_parse_lora_adapters_invalid_format():
try:
parse_lora_adapters("adapter1,invalid=format=test,adapter3")
assert False, "Should have raised ValueError"
except ValueError as e:
assert str(e) == "Invalid LoRA adapter format: invalid=format=test"
def test_get_attn_weights():
# create a mock layer
mock_layer = Mock()
mock_layer.self_attn.query_key_value = Mock()
mock_layer.self_attn.o_proj = Mock()
# call the function
result = get_attn_weights(2, mock_layer)
# assert the result
expected = {
(2, "q_proj"): (
"model.layers.2.self_attn.q_proj",
mock_layer.self_attn.query_key_value,
),
(2, "k_proj"): (
"model.layers.2.self_attn.k_proj",
mock_layer.self_attn.query_key_value,
),
(2, "qkv_proj"): (
"model.layers.2.self_attn.qkv_proj",
mock_layer.self_attn.query_key_value,
),
(2, "v_proj"): (
"model.layers.2.self_attn.v_proj",
mock_layer.self_attn.query_key_value,
),
(2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj),
}
assert result == expected
def test_get_mlp_weights_with_gate_up_proj():
# create a mock layer with gate_up_proj
mock_layer = Mock()
mock_layer.mlp.gate_up_proj = Mock()
mock_layer.mlp.down_proj = Mock()
# call the function
result = get_mlp_weights(3, mock_layer)
# assert the result
expected = {
(3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
(3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
(3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
}
assert result == expected
def test_get_mlp_weights_without_gate_up_proj():
# create a mock layer without gate_up_proj
mock_layer = Mock()
mock_layer.mlp = Mock(spec=[])
# call the function
result = get_mlp_weights(1, mock_layer)
# assert the result
assert result == {}
@pytest.mark.parametrize("layer_index", [0, 1, 5])
def test_get_attn_weights_different_layers(layer_index):
mock_layer = Mock()
mock_layer.self_attn.query_key_value = Mock()
mock_layer.self_attn.o_proj = Mock()
result = get_attn_weights(layer_index, mock_layer)
for k in ["q", "k", "v"]:
assert (layer_index, f"{k}_proj") in result
assert (
result[(layer_index, f"{k}_proj")][0]
== f"model.layers.{layer_index}.self_attn.{k}_proj"
)
assert (layer_index, "o_proj") in result
assert (
result[(layer_index, "o_proj")][0]
== f"model.layers.{layer_index}.self_attn.o_proj"
)
@pytest.mark.parametrize("layer_index", [0, 1, 5])
def test_get_mlp_weights_different_layers(layer_index):
mock_layer = Mock()
mock_layer.mlp.gate_up_proj = Mock()
mock_layer.mlp.down_proj = Mock()
result = get_mlp_weights(layer_index, mock_layer)
for k in ["gate", "up", "down"]:
assert (layer_index, f"{k}_proj") in result
assert (
result[(layer_index, f"{k}_proj")][0]
== f"model.layers.{layer_index}.mlp.{k}_proj"
)
def test_get_attn_weights_llama_compatibility():
mock_layer = Mock()
mock_layer.self_attn.query_key_value = Mock()
mock_layer.self_attn.o_proj = Mock()
result = get_attn_weights(2, mock_layer)
expected = {
(2, "q_proj"): (
"model.layers.2.self_attn.q_proj",
mock_layer.self_attn.query_key_value,
),
(2, "k_proj"): (
"model.layers.2.self_attn.k_proj",
mock_layer.self_attn.query_key_value,
),
(2, "qkv_proj"): (
"model.layers.2.self_attn.qkv_proj",
mock_layer.self_attn.query_key_value,
),
(2, "v_proj"): (
"model.layers.2.self_attn.v_proj",
mock_layer.self_attn.query_key_value,
),
(2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj),
}
assert result == expected
def test_get_mlp_weights_llama_compatibility():
mock_layer = Mock()
mock_layer.mlp.gate_up_proj = Mock()
mock_layer.mlp.down_proj = Mock()
result = get_mlp_weights(3, mock_layer)
expected = {
(3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
(3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
(3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
}
assert result == expected
def test_get_attn_weights_gemma_compatibility():
mock_layer = Mock()
mock_layer.self_attn.query_key_value = Mock()
mock_layer.self_attn.o_proj = Mock()
result = get_attn_weights(2, mock_layer)
expected = {
(2, "q_proj"): (
"model.layers.2.self_attn.q_proj",
mock_layer.self_attn.query_key_value,
),
(2, "k_proj"): (
"model.layers.2.self_attn.k_proj",
mock_layer.self_attn.query_key_value,
),
(2, "qkv_proj"): (
"model.layers.2.self_attn.qkv_proj",
mock_layer.self_attn.query_key_value,
),
(2, "v_proj"): (
"model.layers.2.self_attn.v_proj",
mock_layer.self_attn.query_key_value,
),
(2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj),
}
assert result == expected
def test_get_mlp_weights_gemma_compatibility():
mock_layer = Mock()
mock_layer.mlp.gate_proj = Mock()
mock_layer.mlp.up_proj = Mock()
mock_layer.mlp.down_proj = Mock()
# ensure that the mock_layer.mlp.gate_up_proj attribute does not exist.
# This is necessary because the use of `Mock` automatically creates any
# attributes that are accessed, even if they don't exist in the actual
# implementation. If `gate_up_proj` were created, `get_mlp_weights` might
# follow the wrong execution path and return an incorrect result.
del mock_layer.mlp.gate_up_proj
result = get_mlp_weights(3, mock_layer)
expected = {
(3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_proj),
(3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.up_proj),
(3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
}
assert result == expected

View File

@ -0,0 +1,21 @@
from text_generation_server.utils.hub import (
download_weights,
weight_hub_files,
weight_files,
)
from text_generation_server.utils.convert import convert_files
def test_convert_files():
model_id = "bigscience/bloom-560m"
pt_filenames = weight_hub_files(model_id, extension=".bin")
local_pt_files = download_weights(pt_filenames, model_id)
local_st_files = [
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files
]
convert_files(local_pt_files, local_st_files, discard_names=[])
found_st_files = weight_files(model_id)
assert all([p in found_st_files for p in local_st_files])

View File

@ -0,0 +1,103 @@
import os
import tempfile
import pytest
import huggingface_hub.constants
import text_generation_server.utils.hub
from text_generation_server.utils.hub import (
weight_hub_files,
download_weights,
weight_files,
EntryNotFoundError,
LocalEntryNotFoundError,
RevisionNotFoundError,
)
@pytest.fixture()
def offline():
current_value = text_generation_server.utils.hub.HF_HUB_OFFLINE
text_generation_server.utils.hub.HF_HUB_OFFLINE = True
yield "offline"
text_generation_server.utils.hub.HF_HUB_OFFLINE = current_value
@pytest.fixture()
def fresh_cache():
with tempfile.TemporaryDirectory() as d:
current_value = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE
huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = d
text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = d
os.environ["HUGGINGFACE_HUB_CACHE"] = d
yield
huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = current_value
os.environ["HUGGINGFACE_HUB_CACHE"] = current_value
text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = current_value
@pytest.fixture()
def prefetched():
model_id = "bert-base-uncased"
huggingface_hub.snapshot_download(
repo_id=model_id,
revision="main",
local_files_only=False,
repo_type="model",
allow_patterns=["*.safetensors"],
)
yield model_id
def test_weight_hub_files_offline_error(offline, fresh_cache):
# If the model is not prefetched then it will raise an error
with pytest.raises(EntryNotFoundError):
weight_hub_files("gpt2")
def test_weight_hub_files_offline_ok(prefetched, offline):
# If the model is prefetched then we should be able to get the weight files from local cache
filenames = weight_hub_files(prefetched)
root = None
assert len(filenames) == 1
for f in filenames:
curroot, filename = os.path.split(f)
if root is None:
root = curroot
else:
assert root == curroot
assert filename == "model.safetensors"
def test_weight_hub_files():
filenames = weight_hub_files("bigscience/bloom-560m")
assert filenames == ["model.safetensors"]
def test_weight_hub_files_llm():
filenames = weight_hub_files("bigscience/bloom")
assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)]
def test_weight_hub_files_empty():
with pytest.raises(EntryNotFoundError):
weight_hub_files("bigscience/bloom", extension=".errors")
def test_download_weights():
model_id = "bigscience/bloom-560m"
filenames = weight_hub_files(model_id)
files = download_weights(filenames, model_id)
local_files = weight_files("bigscience/bloom-560m")
assert files == local_files
def test_weight_files_revision_error():
with pytest.raises(RevisionNotFoundError):
weight_files("bigscience/bloom-560m", revision="error")
def test_weight_files_not_cached_error(fresh_cache):
with pytest.raises(LocalEntryNotFoundError):
weight_files("bert-base-uncased")

View File

@ -0,0 +1,82 @@
import torch
from text_generation_server.layers import (
TensorParallelEmbedding,
)
class ProcessGroup:
def __init__(self, rank: int, world_size: int):
self._rank = rank
self.world_size = world_size
def size(self) -> int:
return self.world_size
def rank(self) -> int:
return self._rank
class Weights:
def __init__(self, rank: int, world_size: int, vocab_size: int, hidden_dim: int):
self.weight = (
torch.arange(vocab_size * hidden_dim).float().view(vocab_size, hidden_dim)
)
self.process_group = ProcessGroup(rank, world_size)
def get_partial_sharded(self, name: str, dim: int):
assert dim == 0
rank = self.process_group.rank()
world_size = self.process_group.size()
size = self.weight.shape[dim]
block_size = (size + world_size - 1) // world_size
start = rank * block_size
stop = (rank + 1) * block_size
return self.weight[start:stop]
def get_shape(self, name: str):
return self.weight.shape
def test_weight_hub_files_offline_error():
vocab_size = 17
weights = Weights(
rank=0,
world_size=1,
vocab_size=vocab_size,
hidden_dim=256,
)
embeddings = TensorParallelEmbedding("", weights)
input_ids = torch.arange(vocab_size)
output = embeddings.forward(input_ids)
assert embeddings.min_id == 0
assert embeddings.max_id == 17
torch.testing.assert_close(output, torch.arange(256 * 17).float().view(17, 256))
weights_0_2 = Weights(rank=0, world_size=2, vocab_size=vocab_size, hidden_dim=256)
weights_1_2 = Weights(rank=1, world_size=2, vocab_size=vocab_size, hidden_dim=256)
embeddings_0_2 = TensorParallelEmbedding("", weights_0_2, reduce=False)
assert embeddings_0_2.min_id == 0
assert embeddings_0_2.max_id == 9
torch.testing.assert_close(
embeddings_0_2.weight,
torch.cat([torch.arange(9 * 256), torch.zeros(256)], dim=0)
.view(10, 256)
.float(),
)
embeddings_1_2 = TensorParallelEmbedding("", weights_1_2, reduce=False)
assert embeddings_1_2.min_id == 9
assert embeddings_1_2.max_id == 17
torch.testing.assert_close(
embeddings_1_2.weight,
torch.cat([torch.arange(8 * 256) + 9 * 256, torch.zeros(256)], dim=0)
.view(9, 256)
.float(),
)
output_tp_0 = embeddings_0_2.forward(input_ids)
output_tp_1 = embeddings_1_2.forward(input_ids)
torch.testing.assert_close(output, output_tp_0 + output_tp_1)

View File

@ -0,0 +1,88 @@
import torch
from text_generation_server.utils.tokens import (
StopSequenceCriteria,
StoppingCriteria,
FinishReason,
batch_top_tokens,
)
def test_stop_sequence_criteria():
criteria = StopSequenceCriteria("/test;")
assert not criteria("/")
assert not criteria("/test")
assert criteria("/test;")
assert not criteria("/test; ")
def test_stop_sequence_criteria_escape():
criteria = StopSequenceCriteria("<|stop|>")
assert not criteria("<")
assert not criteria("<|stop")
assert criteria("<|stop|>")
assert not criteria("<|stop|> ")
def test_stopping_criteria():
criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
assert criteria(65827, "/test") == (False, None)
assert criteria(30, ";") == (True, FinishReason.FINISH_REASON_STOP_SEQUENCE)
def test_stopping_criteria_eos():
criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
assert criteria(1, "") == (False, None)
assert criteria(0, "") == (True, FinishReason.FINISH_REASON_EOS_TOKEN)
def test_stopping_criteria_max():
criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
assert criteria(1, "") == (False, None)
assert criteria(1, "") == (False, None)
assert criteria(1, "") == (False, None)
assert criteria(1, "") == (False, None)
assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
def test_batch_top_tokens():
top_n_tokens = [0, 2, 3, 4, 5]
top_n_tokens_tensor = torch.tensor(top_n_tokens)
inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5)
accepted_ids = torch.ones_like(top_n_tokens_tensor)
topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids
)
assert topn_tok_ids[0] == [[]]
assert topn_tok_ids[1] == [[0, 3]]
assert topn_tok_ids[2] == [[0, 3, 1, 4]]
assert topn_tok_ids[3] == [[0, 3, 1, 4]]
assert topn_tok_ids[4] == [[0, 3, 1, 4, 2]]
assert topn_tok_logprobs[0] == [[]]
assert topn_tok_logprobs[1] == [[-1, -2]]
assert topn_tok_logprobs[2] == [[-1, -2, -3, -3]]
assert topn_tok_logprobs[3] == [[-1, -2, -3, -3]]
assert topn_tok_logprobs[4] == [[-1, -2, -3, -3, -4]]
# Now let's make second member of the batch be speculated
inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5 * 2)
accepted_ids[1] = 2
topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids
)
assert topn_tok_ids[0] == [[]]
assert topn_tok_ids[1] == [[0, 3], [0, 3]]
assert topn_tok_ids[2] == [[0, 3, 1, 4]]
assert topn_tok_ids[3] == [[0, 3, 1, 4]]
assert topn_tok_ids[4] == [[0, 3, 1, 4, 2]]
assert topn_tok_logprobs[0] == [[]]
assert topn_tok_logprobs[1] == [[-1, -2], [-1, -2]]
assert topn_tok_logprobs[2] == [[-1, -2, -3, -3]]
assert topn_tok_logprobs[3] == [[-1, -2, -3, -3]]
assert topn_tok_logprobs[4] == [[-1, -2, -3, -3, -4]]

View File

@ -0,0 +1,54 @@
# test_watermark_logits_processor.py
import os
import numpy as np
import torch
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
GAMMA = os.getenv("WATERMARK_GAMMA", 0.5)
DELTA = os.getenv("WATERMARK_DELTA", 2.0)
def test_seed_rng():
input_ids = [101, 2036, 3731, 102, 2003, 103]
processor = WatermarkLogitsProcessor()
processor._seed_rng(input_ids)
assert isinstance(processor.rng, torch.Generator)
def test_get_greenlist_ids():
input_ids = [101, 2036, 3731, 102, 2003, 103]
processor = WatermarkLogitsProcessor()
result = processor._get_greenlist_ids(input_ids, 10, torch.device("cpu"))
assert max(result) <= 10
assert len(result) == int(10 * 0.5)
def test_calc_greenlist_mask():
processor = WatermarkLogitsProcessor()
scores = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]])
greenlist_token_ids = torch.tensor([2, 3])
result = processor._calc_greenlist_mask(scores, greenlist_token_ids)
assert result.tolist() == [[False, False, False, False], [False, False, True, True]]
assert result.shape == scores.shape
def test_bias_greenlist_logits():
processor = WatermarkLogitsProcessor()
scores = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]])
green_tokens_mask = torch.tensor(
[[False, False, True, True], [False, False, False, True]]
)
greenlist_bias = 2.0
result = processor._bias_greenlist_logits(scores, green_tokens_mask, greenlist_bias)
assert np.allclose(result.tolist(), [[0.5, 0.3, 2.2, 2.8], [0.1, 0.2, 0.7, 2.9]])
assert result.shape == scores.shape
def test_call():
input_ids = [101, 2036, 3731, 102, 2003, 103]
processor = WatermarkLogitsProcessor()
scores = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]])
result = processor(input_ids, scores)
assert result.shape == scores.shape

File diff suppressed because it is too large Load Diff

View File

@ -16,9 +16,15 @@ app = typer.Typer()
class Quantization(str, Enum):
bitsandbytes = "bitsandbytes"
bitsandbytes_nf4 = "bitsandbytes-nf4"
bitsandbytes_fp4 = "bitsandbytes-fp4"
gptq = "gptq"
awq = "awq"
eetq = "eetq"
exl2 = "exl2"
fp8 = "fp8"
marlin = "marlin"
class Dtype(str, Enum):
@ -99,9 +105,6 @@ def serve(
"bitsandbytes",
"bitsandbytes-nf4",
"bitsandbytes-fp4",
"gptq",
"awq",
"fp8",
}:
raise RuntimeError(
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
@ -109,7 +112,7 @@ def serve(
logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype))
if sharded and os.getenv("ATTENTION", "default") not in {"paged"}:
if sharded:
tgi_file = Path(__file__).resolve().parent / "tgi_service.py"
num_shard = int(os.getenv("WORLD_SIZE", "1"))
logger.info("CLI SHARDED = {}".format(num_shard))

View File

@ -1,28 +1,43 @@
from .common import (
Seqlen,
HPUPagedAttentionMetadata,
trim_attn_metadata,
trim_seqlen_metadata,
)
from text_generation_server.utils.import_utils import SYSTEM
import os
from .hpu import (
SUPPORTS_WINDOWING,
attention,
paged_attention,
)
from .common import Seqlen
if os.getenv("USE_FLASH_ATTENTION", "true").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.")
if SYSTEM == "cuda":
from .cuda import (
attention,
paged_attention,
reshape_and_cache,
SUPPORTS_WINDOWING,
PREFILL_IN_KV_CACHE,
)
elif SYSTEM == "rocm":
from .rocm import (
attention,
paged_attention,
reshape_and_cache,
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
)
elif SYSTEM == "ipex":
from .ipex import (
attention,
paged_attention,
reshape_and_cache,
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
)
else:
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
from .kv_cache import KVCache, get_kv_scales
__all__ = [
"attention",
"get_kv_scales",
"paged_attention",
"reshape_and_cache",
"PREFILL_IN_KV_CACHE",
"SUPPORTS_WINDOWING",
"KVCache",
"Seqlen",
"HPUPagedAttentionMetadata",
"trim_seqlen_metadata",
"trim_attn_metadata",
]

View File

@ -1,147 +1,72 @@
from dataclasses import dataclass
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import ATTENTION
import torch
from typing import Optional, List, Dict
import collections
_TYPE_CACHE = {}
from typing import Optional
@dataclass
class HPUPagedAttentionMetadata:
"""Metadata for PagedAttention."""
if ATTENTION in {"flashinfer", "flashdecoding"}:
block_list: Optional[torch.Tensor]
block_mapping: Optional[torch.Tensor]
block_usage: Optional[torch.Tensor]
block_scales: Optional[torch.Tensor]
block_groups: Optional[torch.Tensor]
attn_bias: Optional[torch.Tensor]
@dataclass
class Seqlen:
input_lengths: torch.Tensor
prefix_lengths: torch.Tensor
cu_seqlen_q: Optional[torch.Tensor]
cu_seqlen_k: Optional[torch.Tensor]
max_q: int
max_k: int
def __init__(
self,
input_lengths,
prefix_lengths,
cu_seqlen_q=None,
max_q=None,
max_k=None,
):
self.input_lengths = input_lengths
self.prefix_lengths = prefix_lengths
device = self.input_lengths.device
shape = self.input_lengths.shape
if cu_seqlen_q is None:
cu_seqlen_q = torch.arange(
shape[0] + 1,
device=device,
dtype=torch.int32,
)
max_q = 1
else:
assert max_q is not None
assert max_k is not None
cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
def subtuple(
obj: object,
typename: str,
to_copy: List[str],
to_override: Optional[Dict[str, object]] = None,
):
if obj is None:
return None
if to_override is None:
to_override = {}
fields = set(to_copy) | set(to_override.keys())
if isinstance(obj, dict):
values = {key: obj[key] for key in fields if key in obj}
else:
values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
if typename not in _TYPE_CACHE:
_TYPE_CACHE[typename] = collections.namedtuple(typename, " ".join(fields))
return _TYPE_CACHE[typename](**values)
# cuda graphs don't like this and this is necessary to clamp within mistral
# Although FA2 might not want the clamping
# cu_seqlen_k[0] = 0
total = self.input_lengths + self.prefix_lengths
torch.cumsum(total, -1, out=cu_seqlen_k[1:])
self.cu_seqlen_q = cu_seqlen_q
self.cu_seqlen_k = cu_seqlen_k
self.max_q = max_q
self.max_k = max_k
def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object:
# NOTE(kzawora): To anyone working on this in the future:
# Trimming metadata is required when using HPUGraphs.
# Attention metadata is going to be hashed by PT bridge, and
# appropriate HPUGraphs will be matched based on all inputs' hash.
def clamp(self, max):
# Flash decoding doesn't need to clamp
return self
# Before you put more keys in here, make sure you know their
# value type and make sure you know how it's going to be hashed.
# You can find that information in input_hash function
# in habana_frameworks/torch/hpu/graphs.py. You can also hash
# it manually with torch.hpu.graphs.input_hash(attention_metadata)
else:
# If you use primitive types here - they will get hashed based
# on their value. You *will* get lots of excessive graph captures
# (and an OOM eventually) if you decide to put something like
# seq_len int here.
# If you absolutely need a scalar, put it in a tensor. Tensors
# get hashed using their metadata, not their values:
# input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
# input_hash(123) != input_hash(321)
# input_hash("abc") != input_hash("cba")
attention_metadata = subtuple(
metadata,
"TrimmedAttentionMetadata",
[
"block_list",
"block_mapping",
"block_usage",
"block_scales",
"block_groups",
"attn_bias",
],
)
return attention_metadata
@dataclass
class Seqlen:
input_lengths: torch.Tensor
prefix_lengths: torch.Tensor
cu_seqlen_q: torch.Tensor
max_q: int
max_k: int
@dataclass
class Seqlen:
input_lengths: torch.Tensor
cache_lengths: torch.Tensor
cu_seqlen_q: Optional[torch.Tensor]
cu_seqlen_k: Optional[torch.Tensor]
def __init__(
self,
input_lengths,
cache_lengths,
cu_seqlen_q=None,
):
self.input_lengths = input_lengths
self.cache_lengths = cache_lengths
device = self.input_lengths.device
shape = self.input_lengths.shape
if cu_seqlen_q is None:
cu_seqlen_q = torch.arange(
shape[0] + 1,
device=device,
dtype=torch.int32,
)
cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
# cuda graphs don't like this and this is necessary to clamp within mistral
# Although FA2 might not want the clamping
# cu_seqlen_k[0] = 0
total = self.input_lengths + self.cache_lengths
torch.cumsum(total, -1, out=cu_seqlen_k[1:])
self.cu_seqlen_q = cu_seqlen_q
self.cu_seqlen_k = cu_seqlen_k
def clamp(self, max):
# Flash decoding doesn't need to clamp
return self
def trim_seqlen_metadata(metadata: Seqlen) -> object:
# NOTE(kzawora): To anyone working on this in the future:
# Trimming metadata is required when using HPUGraphs.
# Attention metadata is going to be hashed by PT bridge, and
# appropriate HPUGraphs will be matched based on all inputs' hash.
# Before you put more keys in here, make sure you know their
# value type and make sure you know how it's going to be hashed.
# You can find that information in input_hash function
# in habana_frameworks/torch/hpu/graphs.py. You can also hash
# it manually with torch.hpu.graphs.input_hash(attention_metadata)
# If you use primitive types here - they will get hashed based
# on their value. You *will* get lots of excessive graph captures
# (and an OOM eventually) if you decide to put something like
# seq_len int here.
# If you absolutely need a scalar, put it in a tensor. Tensors
# get hashed using their metadata, not their values:
# input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
# input_hash(123) != input_hash(321)
# input_hash("abc") != input_hash("cba")
attention_metadata = subtuple(
metadata,
"TrimmedSeqlen",
[
"input_lengths",
"cache_lengths",
"cu_seqlen_q",
"cu_seqlen_k",
],
)
return attention_metadata
def clamp(self, max):
if SYSTEM == "rocm":
return self
raise NotImplementedError("Not implemented seqlen for paged")
return Seqlen(torch.clamp(self.input_lengths, max=max))

View File

@ -0,0 +1,357 @@
import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import (
ATTENTION,
BLOCK_SIZE,
)
from text_generation_server.layers.attention import Seqlen
from typing import Optional
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
_PARTITION_SIZE = 512
try:
from vllm._C import cache_ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
)
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
):
if ATTENTION in {"flashdecoding", "flashinfer"}:
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0
)
def paged_attention(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
seqlen: Seqlen,
max_s: int,
softcap: Optional[float] = None,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
# Copyright 2023 The vLLM team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# value_cache => [num_blocks, num_heads, head_size, block_size]
# block_size = value_cache.shape[3]
block_size = BLOCK_SIZE
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
if ATTENTION == "flashinfer":
from text_generation_server.layers.attention.flashinfer import decode_state
return decode_state.get().forward(
query.contiguous(),
paged_kv_cache=(key_cache, value_cache),
logits_soft_cap=softcap,
sm_scale=softmax_scale,
)
elif ATTENTION == "flashdecoding":
max_q = 1
max_k = max_s
import flash_attn_2_cuda
# TODO fixme when flash contains the fix.
# Number of splits is not correctly handled
# by the current path
# https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577
# This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied.
if softcap is None:
softcap = 0.0
out = flash_attn_2_cuda.varlen_fwd(
query,
key_cache,
value_cache,
None,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
None, # pad_k
None,
block_tables,
None,
max_q,
max_k,
0.0, # dropout
softmax_scale,
False, # zero_tensors
True, # causal
-1, # Window_left
-1, # Window right
softcap,
False, # return softmax
None, # generator
)
return out[0]
else:
if softcap is not None:
raise RuntimeError("Paged attention doesn't support softcapping")
input_lengths = seqlen.input_lengths
from vllm._C import ops
out = torch.empty_like(query)
use_v1 = max_s <= 8192 and (
max_num_partitions == 1 or num_seqs * num_heads > 512
)
if use_v1:
ops.paged_attention_v1(
out,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=out.dtype,
device=out.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=out.device,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_v2(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
return out
try:
is_ampere_or_newer = major >= 8 and minor >= 0
if not is_ampere_or_newer:
raise ImportError("FlashAttention only supports Ampere GPUs or newer.")
import flash_attn_2_cuda
V2 = True
except ImportError:
try:
import flash_attn_cuda
V2 = False
except ImportError as e:
if major >= 8:
architecture_suffix = f"-{SYSTEM}"
raise ImportError(
"Flash Attention V2 is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
)
elif is_sm75:
raise ImportError(
"Flash Attention is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
"or install flash attention with `cd server && make install install-flash-attention`"
) from e
else:
raise ImportError(
f"GPU with CUDA capability {major} {minor} is not supported"
) from e
SUPPORTS_WINDOWING = V2
if ATTENTION == "flashinfer":
def attention(
q: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap=0.0,
):
from text_generation_server.layers.attention.flashinfer import (
prefill_with_paged_kv_state,
)
return prefill_with_paged_kv_state.get().forward(
q.contiguous(),
causal=causal,
paged_kv_cache=(key_cache, value_cache),
logits_soft_cap=softcap,
sm_scale=softmax_scale,
window_left=window_size_left,
)
elif V2:
def attention(
q,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap=0.0,
):
out = torch.empty_like(q)
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
return flash_attn_2_cuda.varlen_fwd(
q,
key_cache,
value_cache,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
None,
None,
block_tables,
None,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,
causal,
window_size_left,
0,
softcap,
False,
None,
)[0]
else:
def attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap=None,
):
if window_size_left != -1:
raise NotImplementedError(
"window_size_left is only available with flash attn v2"
)
if softcap is not None:
raise NotImplementedError("softcap is only available with flash attn v2")
# Flash attention v1 requires q, k and v to have the same number of heads
if k.shape[1] != q.shape[1]:
# MQA expand
if k.shape[1] == 1:
k = k.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = k.shape
k = (
k.unsqueeze(2)
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
if v.shape[1] != q.shape[1]:
# MQA expand
if v.shape[1] == 1:
v = v.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = v.shape
v = (
v.unsqueeze(2)
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
out = torch.empty_like(q)
flash_attn_cuda.fwd(
q,
k,
v,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,
causal,
False,
0,
None,
)
return out
# Prefill in the cache with every kind of attention, unless we
# have a configuration that requires flash-attention v1, which
# does not support block tables.
PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2

View File

@ -0,0 +1,813 @@
#!/usr/bin/env python
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
(https://tridao.me/publications/flash2/flash2.pdf)
Credits: OpenAI kernel team, AMD ML Frameworks Triton team
Features supported:
1) Fwd with causal masking
2) Any sequence lengths without padding (currently fwd kernel only)
3) Support for different sequence lengths for q and k
4) Nested tensor API currently does not support dropout or bias.
Not currently supported:
1) Non power of two head dims
"""
import torch
import triton
import triton.language as tl
torch_dtype: tl.constexpr = torch.float16
@triton.jit
def cdiv_fn(x, y):
return (x + y - 1) // y
@triton.jit
def max_fn(x, y):
return tl.math.max(x, y)
@triton.jit
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
ms = tl.arange(0, m)
ns = tl.arange(0, n)
return philox_offset + ms[:, None] * stride + ns[None, :]
@triton.jit
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
rng_offsets = dropout_offsets(
philox_seed, philox_offset, dropout_p, m, n, stride
).to(tl.uint32)
# TODO: use tl.randint for better performance
return tl.rand(philox_seed, rng_offsets)
@triton.jit
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride)
rng_keep = rng_output > dropout_p
return rng_keep
@triton.jit
def load_fn(block_ptr, first, second, pad):
if first and second:
tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)
elif first:
tensor = tl.load(block_ptr, boundary_check=(0,), padding_option=pad)
elif second:
tensor = tl.load(block_ptr, boundary_check=(1,), padding_option=pad)
else:
tensor = tl.load(block_ptr)
return tensor
@triton.jit
def _attn_fwd_inner(
acc,
l_i,
m_i,
q,
K_block_ptr,
V_block_ptr,
start_m,
actual_seqlen_k,
dropout_p,
philox_seed,
batch_philox_offset,
encoded_softmax_block_ptr,
block_min,
block_max,
offs_n_causal,
masked_blocks,
n_extra_tokens,
bias_ptr,
IS_CAUSAL: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
OFFS_M: tl.constexpr,
OFFS_N: tl.constexpr,
PRE_LOAD_V: tl.constexpr,
MASK_STEPS: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr,
RETURN_ENCODED_SOFTMAX: tl.constexpr,
PADDED_HEAD: tl.constexpr,
):
# loop over k, v, and update accumulator
for start_n in range(block_min, block_max, BLOCK_N):
# For padded blocks, we will overrun the tensor size if
# we load all BLOCK_N. For others, the blocks are all within range.
k = load_fn(
K_block_ptr,
PADDED_HEAD,
MASK_STEPS and (n_extra_tokens != 0),
"zero",
)
if PRE_LOAD_V:
v = load_fn(
V_block_ptr,
MASK_STEPS and (n_extra_tokens != 0),
PADDED_HEAD,
"zero",
)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
# We start from end of seqlen_k so only the first iteration would need
# to be checked for padding if it is not a multiple of block_n
# TODO: This can be optimized to only be true for the padded block.
if MASK_STEPS: # noqa: SIM102
# If this is the last block / iteration, we want to
# mask if the sequence length is not a multiple of block size
# a solution is to always do BLOCK_M // BLOCK_N + 1 steps
# if not is_modulo_mn. last step might get wasted but that is okay.
# check if this masking works for that case.
if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32)
size_n = start_n + OFFS_N[None, :]
mask = size_n < boundary_m[:, None]
qk = tl.where(mask, qk, float("-inf"))
if IS_CAUSAL:
causal_boundary = start_n + offs_n_causal
causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
qk = tl.where(causal_mask, qk, float("-inf"))
# -- compute qk ----
qk += tl.dot(q, k)
if bias_ptr is not None:
bias = load_fn(
bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero"
)
# While bias is added after multiplying qk with sm_scale, our
# optimization to use 2^x instead of e^x results in an additional
# scale factor of log2(e) which we must also multiply the bias with.
qk += bias * 1.44269504089
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk = qk - m_ij[:, None]
p = tl.math.exp2(qk)
# CAVEAT: Must update l_ij before applying dropout
l_ij = tl.sum(p, 1)
if ENABLE_DROPOUT:
philox_offset = (
batch_philox_offset
+ start_m * BLOCK_M * actual_seqlen_k
+ start_n
- BLOCK_N
)
keep = dropout_mask(
philox_seed,
philox_offset,
dropout_p,
BLOCK_M,
BLOCK_N,
actual_seqlen_k,
)
if RETURN_ENCODED_SOFTMAX:
tl.store(
encoded_softmax_block_ptr,
tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty),
)
p = tl.where(keep, p, 0.0)
elif RETURN_ENCODED_SOFTMAX:
tl.store(
encoded_softmax_block_ptr,
p.to(encoded_softmax_block_ptr.type.element_ty),
)
# -- update output accumulator --
alpha = tl.math.exp2(m_i - m_ij)
acc = acc * alpha[:, None]
if not PRE_LOAD_V:
v = load_fn(
V_block_ptr,
MASK_STEPS and (n_extra_tokens != 0),
PADDED_HEAD,
"zero",
)
# -- update m_i and l_i
l_i = l_i * alpha + l_ij
# update m_i and l_i
m_i = m_ij
acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
if bias_ptr is not None:
bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
if RETURN_ENCODED_SOFTMAX:
encoded_softmax_block_ptr = tl.advance(
encoded_softmax_block_ptr, (0, BLOCK_N)
)
return acc, l_i, m_i
@triton.autotune(
configs=[
triton.Config(
{
"BLOCK_M": 256,
"BLOCK_N": 64,
"waves_per_eu": 2,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=8,
),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 128,
"waves_per_eu": 2,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=4,
),
triton.Config(
{
"BLOCK_M": 256,
"BLOCK_N": 128,
"waves_per_eu": 2,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=8,
),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 64,
"waves_per_eu": 3,
"PRE_LOAD_V": True,
},
num_stages=1,
num_warps=4,
),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 64,
"waves_per_eu": 3,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=4,
),
triton.Config(
{
"BLOCK_M": 64,
"BLOCK_N": 64,
"waves_per_eu": 4,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=8,
),
triton.Config(
{
"BLOCK_M": 32,
"BLOCK_N": 32,
"waves_per_eu": 4,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=8,
),
# TODO: This config fails with head_size not pow2 with data mismatches.
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
# 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
triton.Config(
{
"BLOCK_M": 16,
"BLOCK_N": 16,
"waves_per_eu": 1,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=4,
),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 64,
"waves_per_eu": 1,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=4,
),
],
key=["IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"],
)
@triton.jit
def attn_fwd(
Q,
K,
V,
bias,
sm_scale,
L,
Out,
stride_qz,
stride_qh,
stride_qm,
stride_qk,
stride_kz,
stride_kh,
stride_kn,
stride_kk,
stride_vz,
stride_vh,
stride_vk,
stride_vn,
stride_oz,
stride_oh,
stride_om,
stride_on,
stride_bz,
stride_bh,
stride_bm,
stride_bn,
cu_seqlens_q,
cu_seqlens_k,
dropout_p,
philox_seed,
philox_offset_base,
encoded_softmax,
HQ: tl.constexpr,
HK: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr,
MAX_SEQLENS_Q: tl.constexpr,
MAX_SEQLENS_K: tl.constexpr,
VARLEN: tl.constexpr,
IS_CAUSAL: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
PRE_LOAD_V: tl.constexpr,
BIAS_TYPE: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr,
RETURN_ENCODED_SOFTMAX: tl.constexpr,
):
start_m = tl.program_id(0)
off_h_q = tl.program_id(1)
off_z = tl.program_id(2)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
if VARLEN:
cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
# We have a one-size-fits-all grid in id(0). Some seqlens might be too
# small for all start_m so for those we return early.
if start_m * BLOCK_M > seqlen_q:
return
cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
else:
cu_seqlens_q_start = 0
cu_seqlens_k_start = 0
seqlen_q = MAX_SEQLENS_Q
seqlen_k = MAX_SEQLENS_K
# Now we compute whether we need to exit early due to causal masking.
# This is because for seqlen_q > seqlen_k, M rows of the attn scores
# are completely masked, resulting in 0s written to the output, and
# inf written to LSE. We don't need to do any GEMMs in this case.
# This block of code determines what N is, and if this WG is operating
# on those M rows.
n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
if IS_CAUSAL:
# If seqlen_q == seqlen_k, the attn scores are a square matrix.
# If seqlen_q != seqlen_k, attn scores are rectangular which means
# the causal mask boundary is bottom right aligned, and ends at either
# the top edge (seqlen_q < seqlen_k) or left edge.
# This captures the decrease in n_blocks if we have a rectangular attn
# matrix
n_blocks_seqlen = cdiv_fn(
(start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N
)
# This is what adjusts the block_max for the current WG, only
# if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
n_blocks = min(n_blocks, n_blocks_seqlen)
# If we have no blocks after adjusting for seqlen deltas, this WG is
# part of the blocks that are all 0. We exit early.
if n_blocks <= 0:
o_offset = (
off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh
)
O_block_ptr = tl.make_block_ptr(
base=Out + o_offset,
shape=(seqlen_q, BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
# We still need to write 0s to the result
# tl.store(O_block_ptr,
# acc.to(Out.type.element_ty), boundary_check=(0,1))
# l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
# + offs_m
# We store inf to LSE, not -inf because in the bwd pass,
# we subtract this
# from qk which makes it -inf, such that exp(qk - inf) = 0
# for these masked blocks.
# l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
# tl.store(l_ptrs, l)
# TODO: Should dropout and return encoded softmax be handled here?
return
# If MQA / GQA, set the K and V head offsets appropriately.
GROUP_SIZE: tl.constexpr = HQ // HK
if GROUP_SIZE != 1:
off_h_k = off_h_q // GROUP_SIZE
else:
off_h_k = off_h_q
n_extra_tokens = 0
if seqlen_k < BLOCK_N:
n_extra_tokens = BLOCK_N - seqlen_k
elif seqlen_k % BLOCK_N:
n_extra_tokens = seqlen_k % BLOCK_N
PADDED_HEAD: tl.constexpr = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL
# Compute pointers for all the tensors used in this kernel.
q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm
Q_block_ptr = tl.make_block_ptr(
base=Q + q_offset,
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn
K_block_ptr = tl.make_block_ptr(
base=K + k_offset,
shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1),
)
v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk
V_block_ptr = tl.make_block_ptr(
base=V + v_offset,
shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),
strides=(stride_vk, stride_vn),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0),
)
if BIAS_TYPE != 0:
bias_ptr = tl.make_block_ptr(
base=bias + off_h_q * stride_bh,
shape=(seqlen_q, seqlen_k),
strides=(stride_bm, stride_bn),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_N),
order=(1, 0),
)
else:
bias_ptr = None
if ENABLE_DROPOUT:
batch_philox_offset = (
philox_offset_base + (off_z * HQ + off_h_q) * seqlen_q * seqlen_k
)
else:
batch_philox_offset = 0
# We can ask to return the dropout mask without actually doing any dropout.
# In this case, we return an invalid pointer so indicate the mask is not i
# valid.
# TODO: Fix encoded softmax. It currently uses just h_q in the base offset.
if RETURN_ENCODED_SOFTMAX:
encoded_softmax_block_ptr = tl.make_block_ptr(
base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,
shape=(seqlen_q, seqlen_k),
strides=(seqlen_k, 1),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_N),
order=(1, 0),
)
else:
encoded_softmax_block_ptr = 0
# initialize pointer to m and l
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# scale sm_scale by log_2(e) and use 2^x in the loop as we do not
# have native e^x support in HW.
qk_scale = sm_scale * 1.44269504089
# Q is loaded once at the beginning and shared by all N blocks.
q = load_fn(Q_block_ptr, True, PADDED_HEAD, "zero")
q = (q * qk_scale).to(Q_block_ptr.type.element_ty)
# Here we compute how many full and masked blocks we have.
padded_block_k = n_extra_tokens != 0
is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
if IS_CAUSAL:
# There are always at least BLOCK_M // BLOCK_N masked blocks.
# Additionally there might be one more due to dissimilar seqlens.
masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
else:
# Padding on Q does not need to be masked in the FA loop.
masked_blocks = padded_block_k
# if IS_CAUSAL, not is_modulo_mn does not always result in an additional
# block. In this case we might exceed n_blocks so pick the min.
masked_blocks = min(masked_blocks, n_blocks)
n_full_blocks = n_blocks - masked_blocks
block_min = 0
block_max = n_blocks * BLOCK_N
# Compute for full blocks. Here we set causal to false regardless of its
# value because there is no masking. Similarly we do not need padding.
if n_full_blocks > 0:
block_max = (n_blocks - masked_blocks) * BLOCK_N
acc, l_i, m_i = _attn_fwd_inner(
acc,
l_i,
m_i,
q,
K_block_ptr,
V_block_ptr,
start_m,
seqlen_k,
dropout_p,
philox_seed,
batch_philox_offset,
encoded_softmax_block_ptr,
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
block_min,
block_max,
0,
0,
0,
bias_ptr,
# IS_CAUSAL, ....
False,
BLOCK_M,
BLOCK_DMODEL,
BLOCK_N,
offs_m,
offs_n,
# _, MASK_STEPS, ...
PRE_LOAD_V,
False,
ENABLE_DROPOUT,
RETURN_ENCODED_SOFTMAX,
PADDED_HEAD,
)
block_min = block_max
block_max = n_blocks * BLOCK_N
tl.debug_barrier()
# Remaining blocks, if any, are full / not masked.
if masked_blocks > 0:
offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0
K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0))
if bias_ptr is not None:
bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))
if RETURN_ENCODED_SOFTMAX:
encoded_softmax_block_ptr = tl.advance(
encoded_softmax_block_ptr, (0, n_full_blocks)
)
acc, l_i, m_i = _attn_fwd_inner(
acc,
l_i,
m_i,
q,
K_block_ptr,
V_block_ptr,
start_m,
seqlen_k,
dropout_p,
philox_seed,
batch_philox_offset,
encoded_softmax_block_ptr,
block_min,
block_max,
offs_n_causal,
masked_blocks,
n_extra_tokens,
bias_ptr,
IS_CAUSAL,
BLOCK_M,
BLOCK_DMODEL,
BLOCK_N,
offs_m,
offs_n,
# _, MASK_STEPS, ...
PRE_LOAD_V,
True,
ENABLE_DROPOUT,
RETURN_ENCODED_SOFTMAX,
PADDED_HEAD,
)
# epilogue
acc = acc / l_i[:, None]
if ENABLE_DROPOUT:
acc = acc / (1 - dropout_p)
# If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
# then we have one block with a row of all NaNs which come from computing
# softmax over a row of all -infs (-inf - inf = NaN). We check for that here
# and store 0s where there are NaNs as these rows should've been zeroed out.
end_m_idx = (start_m + 1) * BLOCK_M
start_m_idx = start_m * BLOCK_M
causal_start_idx = seqlen_q - seqlen_k
acc = acc.to(Out.type.element_ty)
if IS_CAUSAL: # noqa: SIM102
if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
out_mask_boundary = tl.full(
(BLOCK_DMODEL,), causal_start_idx, dtype=tl.int32
)
mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :]
z = 0.0
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
# write back LSE
# l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last
# few rows. This is only true for the last M block. For others,
# overflow_size will be -ve
# overflow_size = end_m_idx - seqlen_q
# if overflow_size > 0:
# boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
# # This is a > check because mask being 0 blocks the store.
# l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
# tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
# else:
# tl.store(l_ptrs, m_i + tl.math.log2(l_i))
# write back O
o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh
O_block_ptr = tl.make_block_ptr(
base=Out + o_offset,
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
# Need boundary check on this to make sure the padding from the
# Q and KV tensors in both dims are not part of what we store back.
# TODO: Do the boundary check optionally.
tl.store(O_block_ptr, acc, boundary_check=(0, 1))
def check_args(
q,
k,
v,
o,
varlen=True,
max_seqlens=None,
cu_seqlens_q=None,
cu_seqlens_k=None,
):
assert q.dim() == k.dim() and q.dim() == v.dim()
if varlen:
assert q.dim() == 3
total_q, nheads_q, head_size = q.shape
total_k, nheads_k, _ = k.shape
assert cu_seqlens_q is not None
assert cu_seqlens_k is not None
assert len(cu_seqlens_q) == len(cu_seqlens_k)
else:
assert q.dim() == 4
batch, nheads_q, seqlen_q, head_size = q.shape
_, nheads_k, seqlen_k, _ = k.shape
assert max_seqlens > 0
assert k.shape == v.shape
assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
# TODO: Change assert if we support qkl f8 and v f16
assert q.dtype == k.dtype and q.dtype == v.dtype
# TODO: Fix assert to check head size <=256 once supported
assert head_size <= 128
assert o.shape == q.shape
assert (nheads_q % nheads_k) == 0
class _attention(torch.autograd.Function):
@staticmethod
def forward(
ctx,
q,
k,
v,
o,
cu_seqlens_q,
cu_seqlens_k,
max_seqlens_q,
max_seqlens_k,
causal=False,
sm_scale=1.0,
bias=None,
):
if o is None:
o = torch.empty_like(q, dtype=v.dtype)
check_args(
q,
k,
v,
o,
varlen=True,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
)
if True: # varlen
total_q, nheads_q, head_size = q.shape
total_k, nheads_k, _ = k.shape
batch = len(cu_seqlens_q) - 1
q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
else:
batch, seqlen_q, nheads_q, head_size = q.shape
_, seqlen_k, nheads_k, _ = k.shape
q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
# Get closest power of 2 over or equal to 32.
padded_d_model = 1 << (head_size - 1).bit_length()
padded_d_model = max(padded_d_model, 16)
def grid(META):
return triton.cdiv(max_seqlens_q, META["BLOCK_M"]), nheads_q, batch
encoded_softmax = None
# Seed the RNG so we get reproducible results for testing.
philox_seed = 0x1BF52
philox_offset = 0x1D4B42
if bias is not None:
bias_strides = (
bias.stride(0),
bias.stride(1),
bias.stride(2),
bias.stride(3),
)
else:
bias_strides = (0, 0, 0, 0)
attn_fwd[grid](
q,
k,
v,
bias,
sm_scale,
None,
o,
*q_strides,
*k_strides,
*v_strides,
*o_strides,
*bias_strides,
cu_seqlens_q,
cu_seqlens_k,
dropout_p=0.0,
philox_seed=philox_seed,
philox_offset_base=philox_offset,
encoded_softmax=encoded_softmax,
HQ=nheads_q,
HK=nheads_k,
ACTUAL_BLOCK_DMODEL=head_size,
MAX_SEQLENS_Q=max_seqlens_q,
MAX_SEQLENS_K=max_seqlens_k,
IS_CAUSAL=causal,
VARLEN=True,
BLOCK_DMODEL=padded_d_model,
BIAS_TYPE=0 if bias is None else 1,
ENABLE_DROPOUT=False,
RETURN_ENCODED_SOFTMAX=False,
)
ctx.grid = grid
ctx.sm_scale = sm_scale
ctx.BLOCK_DMODEL = head_size
ctx.causal = causal
ctx.dropout_p = 0.0
ctx.philox_seed = philox_seed
ctx.philox_offset = philox_offset
ctx.encoded_softmax = encoded_softmax
ctx.return_encoded_softmax = False
return o, encoded_softmax
triton_attention = _attention.apply

View File

@ -0,0 +1,251 @@
from typing import Optional
from contextvars import ContextVar
from contextlib import contextmanager
import flashinfer
import torch
prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = ContextVar(
"prefill_state"
)
prefill_with_paged_kv_state: ContextVar[
flashinfer.BatchPrefillWithPagedKVCacheWrapper
] = ContextVar("prefill_with_paged_kv_state")
decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar(
"decode_state"
)
workspace: Optional[torch.Tensor] = None
def get_workspace(device):
"""Get shared flashinfer workspace."""
global workspace
if workspace is None:
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
return workspace
def create_prefill_with_paged_kv_state(
*,
device: torch.device,
):
"""Create a prefill state that uses the KV cache."""
workspace_buffer = get_workspace(device)
return flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, kv_layout="NHD", use_cuda_graph=False
)
@contextmanager
def use_prefill_with_paged_kv_state(
*,
state: flashinfer.BatchPrefillWithPagedKVCacheWrapper,
block_tables: torch.Tensor,
cu_seqlens: torch.Tensor,
input_lengths: torch.Tensor,
num_heads: int,
num_kv_heads: int,
head_size: int,
page_size: int,
dtype: torch.dtype,
window_left: int,
):
"""
Context manager to set the active flashinfer prefill state to the given
`state` and parameters. This state will be used by all calls to the
`attention` function while the context manager is active.
"""
indptr = torch.zeros(
input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32
)
# Round up to page size and then calculate the cumulative sum to get
# the indices into the block table.
torch.add(input_lengths, page_size - 1, out=indptr[1:])
indptr[1:].div_(page_size, rounding_mode="floor")
indptr[1:].cumsum_(-1)
# Get the lengths of the last page in a block.
if page_size == 1:
last_page_len = torch.ones(
input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
)
else:
last_page_len = torch.empty(
input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
)
torch.sub(input_lengths, 1, out=last_page_len)
last_page_len.remainder_(page_size)
last_page_len += 1
token = prefill_with_paged_kv_state.set(state)
try:
state.begin_forward(
qo_indptr=cu_seqlens,
paged_kv_indptr=indptr,
paged_kv_indices=block_tables,
paged_kv_last_page_len=last_page_len,
num_qo_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_size,
q_data_type=dtype,
page_size=page_size,
window_left=window_left,
)
yield
finally:
state.end_forward()
if token is not None:
prefill_with_paged_kv_state.reset(token)
def create_prefill_state(
*,
device: torch.device,
):
"""Create a prefill state."""
workspace_buffer = get_workspace(device)
return flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
workspace_buffer, kv_layout="NHD", use_cuda_graph=False
)
@contextmanager
def use_prefill_state(
*,
state: flashinfer.BatchPrefillWithRaggedKVCacheWrapper,
cu_seqlens: torch.Tensor,
num_heads: int,
num_kv_heads: int,
head_size: int,
dtype: torch.dtype,
window_left: int,
):
"""
Context manager to set the active flashinfer prefill state to the given
`state` and parameters. This state will be used by all calls to the
`attention` function while the context manager is active.
"""
token = prefill_state.set(state)
try:
state.begin_forward(
qo_indptr=cu_seqlens,
kv_indptr=cu_seqlens,
num_qo_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_size,
q_data_type=dtype,
window_left=window_left,
)
yield
finally:
state.end_forward()
if token is not None:
prefill_state.reset(token)
def create_decode_state(
*,
device: torch.device,
num_heads: int,
num_kv_heads: int,
):
"""Create a decode state."""
workspace_buffer = get_workspace(device)
num_groups = num_heads // num_kv_heads
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout="NHD",
use_cuda_graph=False,
# Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
use_tensor_cores=num_groups not in [1, 2, 4, 8],
)
def create_decode_state_cuda_graphs(
*,
device: torch.device,
block_tables: torch.Tensor,
block_tables_ptr: torch.Tensor,
last_page_len: torch.Tensor,
num_heads: int,
num_kv_heads: int,
):
"""
Create a decode state for use with CUDA Graphs. `block_tables`,
`block_tables_ptr`, and `last_page_len` are used in CUDA Graphs and are
therefore stored as part of the state.
"""
workspace_buffer = get_workspace(device)
num_groups = num_heads // num_kv_heads
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout="NHD",
use_cuda_graph=True,
paged_kv_indices_buffer=block_tables,
paged_kv_indptr_buffer=block_tables_ptr,
paged_kv_last_page_len_buffer=last_page_len,
# Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
use_tensor_cores=num_groups not in [1, 2, 4, 8],
)
@contextmanager
def use_decode_state(
*,
state: flashinfer.BatchDecodeWithPagedKVCacheWrapper,
input_lengths: torch.Tensor,
block_tables: torch.Tensor,
num_heads: int,
num_kv_heads: int,
head_size: int,
page_size: int,
dtype: torch.dtype,
window_left: int,
):
"""
Context manager to set the active flashinfer decoding state to the given
`state` and parameters. This state will be used by all calls to the
`paged_attention` function while the context manager is active.
"""
indptr = torch.zeros(
input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32
)
# Round up to page size and then calculate the cumulative sum to get
# the indices into the block table.
torch.add(input_lengths, page_size - 1, out=indptr[1:])
indptr[1:].div_(page_size, rounding_mode="floor")
indptr[1:].cumsum_(-1)
# Get the lengths of the last page in a block.
last_page_len = torch.empty(
input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
)
torch.sub(input_lengths, 1, out=last_page_len)
last_page_len.remainder_(page_size)
last_page_len += 1
token = decode_state.set(state)
try:
state.begin_forward(
indptr=indptr,
indices=block_tables,
last_page_len=last_page_len,
num_qo_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_size,
page_size=page_size,
data_type=dtype,
q_data_type=dtype,
window_left=window_left,
)
yield
finally:
state.end_forward()
if token is not None:
decode_state.reset(token)

View File

@ -1,95 +0,0 @@
import torch
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
from typing import Optional
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from vllm_hpu_extension import ops
from vllm_hpu_extension.utils import Matmul
from habana_frameworks.torch.hpex.kernels import FusedSDPA
from vllm_hpu_extension.utils import ModuleFusedSDPA
import os
SUPPORTS_WINDOWING = False
def fetch_from_cache(cache, blocks):
if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
return cache[: blocks.size(0)]
else:
return cache.index_select(0, blocks)
def attention(
*,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: KVCache,
kv_scales: KVScales,
seqlen: Seqlen,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap: Optional[float] = None,
):
fsdpa_op = ModuleFusedSDPA(FusedSDPA)
bs = seqlen.input_lengths.shape[0]
_, head_num, head_size = query.shape
_, kv_head_num, head_size = key.shape
query = query.view(bs, -1, head_num, head_size).transpose(1, 2)
key = key.view(bs, -1, kv_head_num, head_size).transpose(1, 2)
value = value.view(bs, -1, kv_head_num, head_size).transpose(1, 2)
attn_output = fsdpa_op(
query,
key,
value,
attn_mask=None,
dropout_p=0.0,
is_causal=causal,
scale=softmax_scale,
softmax_mode="None",
recompute_mode=None,
valid_sequence_lengths=seqlen.input_lengths,
padding_side="left",
)
attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
return attn_output
def paged_attention(
query: torch.Tensor,
kv_cache: KVCache,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
seqlen: Seqlen,
*,
kv_scales: KVScales,
softcap: Optional[float] = None,
hpu_attention_meta: HPUPagedAttentionMetadata,
):
batch_size, head_num, head_size = query.shape
output = ops.flat_pa(
query=query.view(batch_size, 1, head_num * head_size),
key_cache=kv_cache.key,
value_cache=kv_cache.value,
block_list=hpu_attention_meta.block_list,
block_mapping=hpu_attention_meta.block_mapping,
block_bias=hpu_attention_meta.attn_bias,
block_scales=hpu_attention_meta.block_scales,
block_groups=hpu_attention_meta.block_groups,
scale=softmax_scale,
matmul_qk_op=Matmul(),
matmul_av_op=Matmul(),
batch2block_matmul_op=Matmul(),
block2batch_matmul_op=Matmul(),
keys_fetch_func=fetch_from_cache,
values_fetch_func=fetch_from_cache,
)
# Reshape the output tensor.
return output.view(batch_size, head_num, head_size)
__all__ = [
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",
]

View File

@ -0,0 +1,82 @@
import intel_extension_for_pytorch as ipex
import torch
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen
from typing import Optional
SUPPORTS_WINDOWING = False
PREFILL_IN_KV_CACHE = False
def attention(
q: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap: Optional[float] = None,
):
out = torch.empty_like(q)
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
ipex.llm.functional.varlen_attention(
q.contiguous() if q.device.type == "xpu" else q,
key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
seqlen.max_q,
seqlen.max_q,
0.0,
softmax_scale,
False,
causal,
False,
None,
)
return out
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
):
ipex.llm.modules.PagedAttention.reshape_and_cache(
key, value, key_cache, value_cache, slots
)
def paged_attention(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
seqlen: Seqlen,
max_s: int,
softcap: Optional[float] = None,
):
out = torch.empty_like(query)
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
out,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
seqlen.input_lengths,
BLOCK_SIZE,
max_s,
None,
)
return out

View File

@ -1,139 +0,0 @@
from typing import Tuple
from dataclasses import dataclass, field
import torch
from text_generation_server.models.globals import BLOCK_SIZE
from text_generation_server.utils.weights import Weights
from vllm_hpu_extension import cache_ops
@dataclass
class KVScales:
"""
Key-value scales for FP8 KV cache.
This data class stores key and value scales both as a GPU tensor and
as a GPU float. This inconvenience is necessary because some functions
(e.g. scaling kernels) take scales as a GPU tensor, whereas others
(e.g. flashinfer) take scales as a CPU scalar.
"""
key_scale: torch.Tensor
value_scale: torch.Tensor
key_scale_cpu: float = field(init=False)
value_scale_cpu: float = field(init=False)
def __post_init__(self):
if self.key_scale.numel() != 1 or self.value_scale.numel() != 1:
raise ValueError("Key and value scales must be scalar tensors.")
self.key_scale_cpu = self.key_scale.item()
self.value_scale_cpu = self.value_scale.item()
class KVCache:
"""
Key-value cache for attention layers.
"""
kv_cache: Tuple[torch.Tensor, torch.Tensor]
def __init__(
self,
*,
num_blocks: int,
num_heads: int,
head_size: int,
dtype: torch.dtype,
device: torch.device,
):
"""Construct the key-value cache for a layer."""
## TODO FP8 kv cache support
self.kv_cache = (
torch.zeros(
(num_blocks, BLOCK_SIZE, num_heads, head_size),
dtype=dtype,
device=device,
),
torch.zeros(
(num_blocks, BLOCK_SIZE, num_heads, head_size),
dtype=dtype,
device=device,
),
)
@property
def dtype(self):
"""Get the data type of the cache."""
return self.kv_cache[0].dtype
@property
def key(self):
"""Get the key cache."""
return self.kv_cache[0]
@property
def value(self):
"""Get the value cache."""
return self.kv_cache[1]
def store(
self,
*,
key: torch.Tensor,
value: torch.Tensor,
slots: torch.Tensor,
kv_scales: KVScales,
):
"""Store the key and value at the given slots."""
## TODO FP8 kv cache support
key_cache = self.kv_cache[0]
value_cache = self.kv_cache[1]
paged_reshape_and_cache(
key,
value,
key_cache,
value_cache,
slots,
kv_scales.key_scale_cpu,
kv_scales.value_scale_cpu,
)
def paged_reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
k_scale: float = 1.0,
v_scale: float = 1.0,
):
block_idx = slots // BLOCK_SIZE
block_offset = slots % BLOCK_SIZE
cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset)
cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset)
def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
"""Load KV cache scales."""
key_scale = torch.tensor(1.0, dtype=torch.float32, device=weights.device)
value_scale = key_scale
if weights.has_tensor(f"{prefix}.k_scale") and weights.has_tensor(
f"{prefix}.v_scale"
):
key_scale = weights.get_tensor(f"{prefix}.k_scale", to_dtype=False).float()
value_scale = weights.get_tensor(f"{prefix}.v_scale", to_dtype=False).float()
elif weights.has_tensor(f"{prefix}.kv_scale"):
# Fall back to older more coarse-grained scale when available.
key_scale = weights.get_tensor(f"{prefix}.kv_scale").float()
value_scale = key_scale
return KVScales(key_scale=key_scale, value_scale=value_scale)

View File

@ -0,0 +1,308 @@
import os
from typing import Optional
import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import ATTENTION
from text_generation_server.layers.attention import Seqlen
from text_generation_server.utils.log import log_master
from loguru import logger
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
_PARTITION_SIZE_V1V2 = 512
_PARTITION_SIZE_CUSTOM = 256
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
ENGINE = "triton" if use_triton else "ck"
PREFILL_IN_KV_CACHE = False
use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"
try:
if use_rocm_custom_paged_attn:
from vllm._custom_C import paged_attention_custom
except ImportError as e:
log_master(
logger.info,
f"Custom Paged Attention not available. Complete error: {e}",
)
use_rocm_custom_paged_attn = False
try:
import vllm._custom_ops as ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
)
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
):
if ATTENTION == "flashdecoding":
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
def paged_attention(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
seqlen: Seqlen,
max_s: int,
softcap: Optional[float] = None,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
# Copyright 2023 The vLLM team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
if softcap is not None:
raise RuntimeError("Paged attention doesn't support softcapping")
# value_cache => [num_blocks, num_heads, head_size, block_size]
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
num_kv_heads = key_cache.shape[1]
gqa_ratio = num_heads // num_kv_heads
use_custom = (
use_rocm_custom_paged_attn
and (query.dtype == torch.half or query.dtype == torch.bfloat16)
and (head_size == 128 or head_size == 64)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_s <= 32768
)
if not use_custom:
_PARTITION_SIZE = _PARTITION_SIZE_V1V2
else:
_PARTITION_SIZE = _PARTITION_SIZE_CUSTOM
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
input_lengths = seqlen.input_lengths
out = torch.empty_like(query)
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
import vllm._custom_ops as ops
use_v1 = (
max_s <= 8192
and (max_num_partitions == 1 or num_seqs * num_heads > 512)
and not use_custom
)
if use_v1:
ops.paged_attention_v1(
out,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=out.dtype,
device=out.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=out.device,
)
max_logits = torch.empty_like(exp_sums)
if not use_custom:
ops.paged_attention_v2(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
else:
paged_attention_custom(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
)
return out
if ENGINE != "triton":
try:
import flash_attn_2_cuda
log_master(
logger.info,
"ROCm: using Flash Attention 2 Composable Kernel implementation.",
)
except ImportError as e:
if major >= 8:
architecture_suffix = f"-{SYSTEM}"
raise ImportError(
"Flash Attention V2 is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
)
elif is_sm75:
raise ImportError(
"Flash Attention is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
"or install flash attention with `cd server && make install install-flash-attention`"
) from e
else:
for idx in range(torch.cuda.device_count()):
name = torch.cuda.get_device_name(idx)
if "MI210" not in name and "MI250" not in name:
raise ImportError(
f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
)
raise ImportError(
f"AMD GPU with ROCm capability {major} {minor} is not supported"
) from e
SUPPORTS_WINDOWING = False
if ENGINE == "ck":
def attention(
q,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap: float = 0.0,
):
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
out = torch.empty_like(q)
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
return flash_attn_2_cuda.varlen_fwd(
q,
key_cache,
value_cache,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
None,
None,
None,
None,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,
causal,
window_size_left,
0,
softcap,
False,
None,
)[0]
elif ENGINE == "triton":
from .flash_attn_triton import triton_attention
def attention(
q,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap: Optional[float] = None,
):
if softcap is not None:
raise NotImplementedError("softcap is only available with CK flash attn")
out = torch.empty_like(q)
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
output, _ = triton_attention(
q,
key_cache,
value_cache,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
seqlen.max_q,
seqlen.max_k,
causal,
softmax_scale,
)
return output
else:
raise RuntimeError(f"Unknown attention engine {ENGINE}")

View File

@ -1,3 +0,0 @@
from .hpu import WQLinear
__all__ = ["WQLinear"]

View File

@ -1,134 +0,0 @@
from typing import Optional
import torch
import torch.nn as nn
try:
import habana_frameworks.torch.hpu # noqa: F401
convert_from_uint4 = torch.ops.hpu.convert_from_uint4
except Exception as e:
hpu_import_exception = e
def error_raiser_hpu(*args, **kwargs):
raise ValueError(
f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}"
)
convert_from_uint4 = error_raiser_hpu
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):
shifts = torch.arange(0, 32, bits, device=qzeros.device)
# unpacking columnwise
iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
torch.int8 # smallest dtype available
)
iweights = iweights.view(iweights.shape[0], -1)
# unpacking columnwise
if qzeros is not None:
izeros = torch.bitwise_right_shift(
qzeros[:, :, None], shifts[None, None, :]
).to(
torch.int8 # smallest dtype available
)
izeros = izeros.view(izeros.shape[0], -1)
else:
izeros = qzeros
return iweights, izeros
def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
reverse_order_tensor = torch.arange(
iweights.shape[-1],
dtype=torch.int32,
device=izeros.device,
)
reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
reverse_order_tensor = reverse_order_tensor.view(-1)
if izeros is not None:
izeros = izeros[:, reverse_order_tensor]
iweights = iweights[:, reverse_order_tensor]
return iweights, izeros
def unpack_weight_and_zeros(qweight, qzeros, bits):
# Unpack the qweight and qzeros tensors
iweight, izeros = unpack_awq(qweight, qzeros, bits)
# Reverse the order of the iweight and izeros tensors
iweight, izeros = reverse_awq_order(iweight, izeros, bits)
# overflow checks
iweight = torch.bitwise_and(iweight, (2**bits) - 1)
izeros = torch.bitwise_and(izeros, (2**bits) - 1)
return iweight, izeros
def pack_tensor(input, bits=4):
normal = input.to(torch.int32)
q = torch.zeros(
(normal.shape[0], normal.shape[1] // 32 * bits),
dtype=torch.int32,
device=input.device,
)
i = 0
col = 0
while col < q.shape[1]:
for j in range(i, i + (32 // bits)):
q[:, col] |= normal[:, j] << (bits * (j - i))
i += 32 // bits
col += 1
q = q.to(torch.int32)
return q
class WQLinear(nn.Module):
def __init__(
self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor]
):
super().__init__()
if w_bit not in [4]:
raise NotImplementedError("Only 4-bit are supported for now.")
self.in_features = qweight.shape[0]
self.out_features = qweight.shape[1] * 32 // w_bit
self.w_bit = w_bit
self.group_size = group_size if group_size != -1 else self.in_features
# quick sanity check (make sure aligment)
assert self.in_features % self.group_size == 0
assert self.out_features % (32 // self.w_bit) == 0
self.qweight = qweight
self.qzeros = qzeros
self.scales = scales
self.bias = bias
self._preprocessing()
def _preprocessing(self):
device = self.qweight.device
weight, zeros = unpack_weight_and_zeros(
self.qweight.cpu(), self.qzeros.cpu(), self.w_bit
)
self.qweight = pack_tensor(weight).to(device)
self.qzeros = pack_tensor(zeros).to(device)
@torch.no_grad()
def forward(self, x):
out_shape = x.shape[:-1] + (self.out_features,)
x = x.reshape(-1, x.shape[-1])
weights = convert_from_uint4(self.qweight, self.scales, self.qzeros, x.dtype)
outputs = torch.matmul(x, weights)
outputs = outputs + self.bias if self.bias is not None else outputs
outputs = outputs.reshape(out_shape)
return outputs

View File

@ -0,0 +1,49 @@
# Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py
from typing import Optional
import torch
import torch.nn as nn
import awq_inference_engine # with CUDA kernels
# class ScaledActivation(nn.Module):
# def __init__(self, module, scales):
# super().__init__()
# self.act = module
# self.scales = nn.Parameter(scales.data)
#
# def forward(self, x):
# return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
class WQLinear(nn.Module):
def __init__(
self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor]
):
super().__init__()
if w_bit not in [4]:
raise NotImplementedError("Only 4-bit are supported for now.")
self.in_features = qweight.shape[0]
self.out_features = qweight.shape[1] * 32 // w_bit
self.w_bit = w_bit
self.group_size = group_size if group_size != -1 else self.in_features
# quick sanity check (make sure aligment)
assert self.in_features % self.group_size == 0
assert self.out_features % (32 // self.w_bit) == 0
self.qweight = qweight
self.qzeros = qzeros
self.scales = scales
self.bias = bias
@torch.no_grad()
def forward(self, x):
out_shape = x.shape[:-1] + (self.out_features,)
out = awq_inference_engine.gemm_forward_cuda(
x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
)
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)

View File

@ -0,0 +1,43 @@
from dataclasses import dataclass
import torch
from EETQ import quant_weights, w8_a16_gemm
from text_generation_server.utils.weights import UnquantizedWeight
@dataclass
class EETQWeight(UnquantizedWeight):
weight: torch.Tensor
def get_linear(self, bias: torch.Tensor):
try:
from text_generation_server.layers.eetq import EETQLinear
return EETQLinear(self.weight, bias)
except ImportError:
raise ImportError(
"Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
)
class EETQLinear(torch.nn.Module):
def __init__(
self,
weight,
bias,
) -> None:
super().__init__()
device = weight.device
if weight.dtype != torch.float16:
weight = weight.to(dtype=torch.float16)
weight = torch.t(weight).contiguous().cpu()
weight, scale = quant_weights(weight, torch.int8, False)
self.weight = weight.cuda(device)
self.scale = scale.cuda(device)
self.bias = bias.cuda(device) if bias is not None else None
def forward(self, input: torch.Tensor) -> torch.Tensor:
output = w8_a16_gemm(input, self.weight, self.scale)
output = output + self.bias if self.bias is not None else output
return output

View File

@ -1,152 +1,100 @@
from dataclasses import dataclass
from typing import Optional, Tuple, Type, Union, List
import torch
from dataclasses import dataclass
from typing import Optional, Union, List
from loguru import logger
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import (
Weight,
WeightsLoader,
UnquantizedWeight,
Weights,
)
from vllm_hpu_extension.ops import scaled_fp8_quant
from vllm_hpu_extension.scales import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2
import habana_frameworks.torch.utils.experimental as htexp
w8a8_block_fp8_matmul = None
per_token_group_quant_fp8 = None
quant_dtype: torch.dtype = torch.float8_e4m3fn
from text_generation_server.utils.log import log_master, log_once
import importlib.util
def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
FBGEMM_MM_AVAILABLE = False
FBGEMM_DYN_AVAILABLE = False
def is_fbgemm_gpu_available():
try:
return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None
except ModuleNotFoundError:
return False
if is_fbgemm_gpu_available():
if SYSTEM == "cuda":
major, _ = torch.cuda.get_device_capability()
FBGEMM_MM_AVAILABLE = major == 9
FBGEMM_DYN_AVAILABLE = major >= 8
else:
log_master(logger.warning, "FBGEMM fp8 kernels are not installed.")
def get_fp8_linear() -> torch.nn.Module:
"""
Return an FP8 linear `Module` that is compatible with the current system.
"""
if SYSTEM == "cuda":
major, _ = torch.cuda.get_device_capability()
if major == 8:
from text_generation_server.layers.marlin import GPTQMarlinFP8Linear
return GPTQMarlinFP8Linear
# On other systems let Torch decide if the hardware supports FP8.
return Fp8Linear
def normalize_e4m3fn_to_native_float8(
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
return weight, weight_scale, input_scale
def per_tensor_dequantize(
tensor: torch.Tensor,
inv_scale: Union[float, torch.Tensor],
dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
device = tensor.device
dtype = torch.bfloat16
if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2:
# dequant on cpu to avoid nan on gaudi2
tensor = tensor.to("cpu")
fake_qweight = tensor.to(dtype).to(device)
dq_weight = fake_qweight * inv_scale
return dq_weight
def requantize_with_max_scale(
weight: torch.Tensor,
weight_scale: torch.Tensor,
logical_widths: int,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Max scale to be used for requanitzation.
max_w_scale = weight_scale.max()
if is_hpu_gaudi2():
max_w_scale = max_w_scale * get_hpu_gaudi2_scale_factor()
start = 0
for idx, logical_width in enumerate(logical_widths):
end = start + logical_width
weight_dq = per_tensor_dequantize(
weight[start:end, :], weight_scale[idx], dtype
)
weight[start:end, :], max_w_scale_normalized = fp8_quantize(
weight_dq, max_w_scale
)
start = end
return weight, max_w_scale_normalized
def fp8_quantize(
weight: torch.Tensor,
scale: Optional[torch.Tensor] = None,
scale_upper_bound: Optional[torch.Tensor] = None,
qdtype: torch.dtype = torch.float8_e4m3fn,
scalar: bool = False,
weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
):
"""
This function returns a reciprocal of the scale, so that a tensor can be unscaled
by multiplying it with the returned scale. If a scale is given through the `scale`
argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
be used without modification).
"""
shape = weight.shape
qweight, scale = scaled_fp8_quant(
weight.reshape(-1, shape[-1]),
scale=scale,
scale_ub=scale_upper_bound,
# TODO: don't do this when we have to use the Torch kernel.
use_per_token_if_dynamic=not scalar,
)
if FBGEMM_DYN_AVAILABLE and not scalar:
qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
)
return qweight, scale
return qweight.reshape(shape), scale
# weight, scale = quant_weights(weight, torch.int8, False)
finfo = torch.finfo(qdtype)
# Calculate the scale as dtype max divided by absmax
scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
# scale and clamp the tensor to bring it to
# the representative range of float8 data type
# (as default cast is unsaturated)
qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
# Return both float8 data and the inverse scale (as float),
# as both required as inputs to torch._scaled_mm
qweight = qweight.to(qdtype)
scale = scale.float().reciprocal()
return qweight, scale
class HybridFP8UnquantLoader(WeightsLoader):
"""Weight loader that loads FP8 and unquantized Torch tensors."""
def __init__(
self,
activation_scale_ub: Optional[float],
to_fp8: bool,
weight_block_size: Optional[List[int]] = None,
):
def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool):
self.activation_scale_ub = activation_scale_ub
self.to_fp8 = to_fp8
self.weight_block_size = weight_block_size
def get_weights(self, weights: "Weights", prefix: str):
w = weights.get_tensor(f"{prefix}.weight")
if w.dtype == torch.float8_e4m3fn:
if self.weight_block_size is not None:
scale = weights.get_tensor(f"{prefix}.weight_scale_inv")
return Fp8Weight(
weight=w,
weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
weight_block_size=self.weight_block_size,
)
# FP8 branch
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"):
input_scale = (
weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
.reshape(-1)
.max()
)
logical_widths = [w.shape[0]]
w, scale = requantize_with_max_scale(
w, scale.unsqueeze(0), logical_widths, weights.dtype
scale = (
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
.reshape(-1)
.expand(w.shape[0])
)
return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
@ -168,7 +116,6 @@ class HybridFP8UnquantLoader(WeightsLoader):
if w.dtype == torch.float8_e4m3fn:
# FP8 branch
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
if scale.numel() > 1:
scale = weights.get_packed_sharded(
f"{prefix}.weight_scale",
@ -176,29 +123,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
block_sizes=block_sizes,
to_dtype=False,
)
input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"):
input_scale = weights.get_tensor(
f"{prefix}.input_scale", to_dtype=False
)
if input_scale.numel() > 1:
input_scale = weights.get_packed_sharded(
f"{prefix}.input_scale",
dim=0,
block_sizes=block_sizes,
to_dtype=False,
)
input_scale = input_scale.reshape(-1).max()
logical_widths = [w.shape[0]]
w, scale = requantize_with_max_scale(
w, scale.unsqueeze(0), logical_widths, weights.dtype
)
scale = scale.reshape(-1).expand(w.shape[0])
return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
@ -219,48 +148,15 @@ class HybridFP8UnquantLoader(WeightsLoader):
# FP8 branch
if w.dtype == torch.float8_e4m3fn:
if self.weight_block_size is not None:
scale = [
weights.get_sharded(f"{p}.weight_scale_inv", dim=0, to_device=False)
for p in prefixes
]
scale = torch.cat(scale, dim=dim)
scale = scale.to(weights.device)
return Fp8Weight(
weight=w,
weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
weight_block_size=self.weight_block_size,
)
scale = [
_load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
for p, shape in zip(prefixes, shapes)
]
scale = torch.cat(scale, dim=0).reshape(-1)
input_scale = [
_load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
for p, shape in zip(prefixes, shapes)
if weights.has_tensor(f"{p}.input_scale")
]
assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
input_scale = (
torch.cat(input_scale, dim=0).reshape(-1).max()
if len(input_scale) != 0
else None
)
logical_widths = [x[0] for x in shapes]
w, scale = requantize_with_max_scale(
w, scale.to(weights.device), logical_widths, weights.dtype
)
return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
@ -273,35 +169,14 @@ class HybridFP8UnquantLoader(WeightsLoader):
w = weights.get_sharded(f"{prefix}.weight", dim=1)
# FP8 branch
if w.dtype == torch.float8_e4m3fn:
if self.weight_block_size is not None:
# XXX: Yes the weights is named scale_inv, but corresponds to scale it seems.
scale = weights.get_sharded(f"{prefix}.weight_scale_inv", dim=1)
return Fp8Weight(
weight=w,
weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
weight_block_size=self.weight_block_size,
)
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"):
input_scale = (
weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
.reshape(-1)
.max()
)
logical_widths = [w.shape[0]]
w, scale = requantize_with_max_scale(
w, scale.unsqueeze(0), logical_widths, weights.dtype
scale = (
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
.reshape(-1)
.expand(w.shape[0])
)
return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
@ -316,126 +191,83 @@ class Fp8Weight(Weight):
weight: torch.Tensor
dtype: torch.dtype
weight_scale: Optional[torch.Tensor] = None
input_scale: Optional[torch.Tensor] = None
activation_scale_ub: Optional[float] = None
force_w8a16: bool = False
weight_block_size: Optional[List[int]] = None
def get_linear(self, bias: torch.Tensor):
if self.weight_scale is None:
return get_fp8_linear(force_w8a16=self.force_w8a16).from_unquant(
self.weight, bias, self.dtype
)
return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
# This is not checked by the fbgemm kernels, but they require contiguous
# memory. Can be non-contiguous when we e.g. expand from scalars.
self.weight_scale = self.weight_scale.contiguous()
return get_fp8_linear(force_w8a16=self.force_w8a16).from_fp8(
weight=self.weight,
scale=self.weight_scale,
dtype=self.dtype,
bias=bias,
input_scale=self.input_scale,
scale_upper_bound=self.activation_scale_ub,
weight_block_size=self.weight_block_size,
return get_fp8_linear().from_fp8(
self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
)
class Fp8Linear(torch.nn.Module):
_device_identity_cache = {}
def __init__(
self,
qweight: torch.Tensor,
scale: torch.Tensor,
dtype: torch.dtype,
bias: Optional[torch.Tensor] = None,
input_scale: Optional[torch.Tensor] = None,
scale_upper_bound: Optional[float] = None,
weight_block_size: Optional[List[int]] = None,
qweight,
scale,
scale_upper_bound,
bias,
dtype,
) -> None:
super().__init__()
if FBGEMM_MM_AVAILABLE:
log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
self.dtype = dtype
self.qweight = qweight
self.scale = scale.float()
self.input_scale = input_scale.float() if input_scale is not None else None
self.weight_block_size = weight_block_size
self.scale_upper_bound = scale_upper_bound
self.scale = scale
self.scale_upper_bound = (
torch.tensor(
[scale_upper_bound], dtype=torch.float32, device=qweight.device
)
if scale_upper_bound is not None
else None
)
self.bias = bias if bias is not None else None
@classmethod
def from_unquant(cls, weight, bias, dtype):
qweight, scale = fp8_quantize(weight, scalar=True)
qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
return cls(
qweight=qweight,
scale=scale,
dtype=dtype,
bias=bias,
input_scale=None,
scale_upper_bound=None,
qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
)
@classmethod
def from_fp8(
cls,
weight: torch.Tensor,
scale: torch.Tensor,
dtype: torch.dtype,
bias: Optional[torch.Tensor] = None,
**kwargs,
) -> "Fp8Linear":
input_scale = kwargs.get("input_scale", None)
scale_upper_bound = kwargs.get("scale_upper_bound", None)
weight_block_size = kwargs.get("weight_block_size", None)
def from_fp8(cls, weight, scale, input_scale, bias, dtype):
if FBGEMM_DYN_AVAILABLE:
# fbgemm needs float32 scales.
scale = scale.float()
return cls(
qweight=weight,
scale=scale,
input_scale=input_scale,
scale_upper_bound=scale_upper_bound,
scale_upper_bound=input_scale,
bias=bias,
dtype=dtype,
weight_block_size=weight_block_size,
)
@classmethod
def get_shared_device_identity(cls, device):
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
if device not in cls._device_identity_cache:
cls._device_identity_cache[device] = torch.ones(1, device=device)
return cls._device_identity_cache[device]
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.weight_block_size is not None:
# https://arxiv.org/pdf/2412.19437
# At a more granular level. As illustrated in Figure 7 (a), (1) for activations, we group and
# scale elements on a 1x128 tile basis (i.e., per token per 128 channels); and (2) for weights, we
# group and scale elements on a 128x128 block basis (i.e., per 128 input channels per 128 output
# channels).
qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1])
output = w8a8_block_fp8_matmul(
if FBGEMM_MM_AVAILABLE:
qinput, scale = fp8_quantize(
input, scale_upper_bound=self.scale_upper_bound
)
y = torch.ops.fbgemm.f8f8bf16_rowwise(
qinput,
self.qweight,
scale,
self.scale,
self.weight_block_size,
output_dtype=input.dtype,
use_fast_accum=True,
bias=self.bias,
)
return y.to(self.dtype)
if self.bias is not None:
output = output + self.bias
return output.to(dtype=input.dtype)
qinput, scale = fp8_quantize(
input,
self.input_scale,
scale_upper_bound=self.scale_upper_bound,
scalar=True,
)
output = torch._scaled_mm(
qinput, scale = fp8_quantize(input, scalar=True)
output, _ = torch._scaled_mm(
qinput,
self.qweight.t(),
out_dtype=self.dtype,
@ -443,16 +275,11 @@ class Fp8Linear(torch.nn.Module):
scale_b=self.scale,
bias=self.bias,
)
if isinstance(output, tuple) and len(output) == 2:
output = output[0]
return output
def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
scale = weights.get_tensor(prefix, to_dtype=False)
if scale.numel() > 1:
scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
return scale.reshape(-1)
return scale.reshape(-1).expand(shape[0])

View File

@ -1,15 +1,14 @@
import os
from dataclasses import dataclass
from typing import List, Optional, Union
import torch
from loguru import logger
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
from .hpu import QuantLinear
@dataclass
class GPTQWeight(Weight):
qweight: torch.Tensor
@ -31,8 +30,13 @@ class GPTQWeight(Weight):
def get_linear(self, bias: torch.Tensor):
if self.use_awq_kernel:
if SYSTEM == "rocm":
raise NotImplementedError(
"AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead "
"to use Exllama/GPTQ kernels for AWQ inference."
)
try:
from text_generation_server.layers.awq.quantize import WQLinear
from text_generation_server.layers.awq.quantize.qmodule import WQLinear
return WQLinear(
w_bit=self.bits,
@ -46,7 +50,18 @@ class GPTQWeight(Weight):
raise NotImplementedError(
"You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
)
elif self.use_exllama:
try:
from text_generation_server.layers.gptq import ExllamaQuantLinear
except ImportError:
raise NotImplementedError(
"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
)
return ExllamaQuantLinear(self, bias)
else:
from text_generation_server.layers.gptq.quant_linear import QuantLinear
return QuantLinear(
self.qweight,
self.qzeros,
@ -103,6 +118,23 @@ class GPTQWeightsLoader(WeightsLoader):
else:
g_idx = None
from text_generation_server.layers.gptq import (
HAS_EXLLAMA,
CAN_EXLLAMA,
GPTQWeight,
)
if use_exllama:
if not HAS_EXLLAMA:
if CAN_EXLLAMA:
log_once(
logger.warning,
"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
)
use_exllama = False
else:
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
qzeros = weights.get_tensor(f"{prefix}.qzeros")
scales = weights.get_tensor(f"{prefix}.scales")
@ -215,7 +247,14 @@ class GPTQWeightsLoader(WeightsLoader):
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
)
use_exllama = self.bits == 4 and self.quantize == "gptq" and not self.desc_act
from text_generation_server.layers.gptq import HAS_EXLLAMA
use_exllama = (
self.bits == 4
and HAS_EXLLAMA
and self.quantize == "gptq"
and not self.desc_act
)
if self.quantize == "gptq" and self.quant_method == "gptq":
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
@ -259,7 +298,6 @@ class GPTQWeightsLoader(WeightsLoader):
self._get_gptq_params(weights)
use_exllama = True
desc_act = self.desc_act
if self.bits != 4:
use_exllama = False
@ -283,8 +321,7 @@ class GPTQWeightsLoader(WeightsLoader):
if g_idx is not None:
if (
not torch.equal(
# Remove g_idx[0] to adapt the check with TP>1.
(g_idx - g_idx[0]).cpu(),
g_idx.cpu(),
torch.tensor(
[i // self.groupsize for i in range(g_idx.shape[0])],
dtype=torch.int32,
@ -295,22 +332,34 @@ class GPTQWeightsLoader(WeightsLoader):
# Exllama implementation does not support row tensor parallelism with act-order, as
# it would require to reorder input activations that are split unto several GPUs
use_exllama = False
desc_act = True
from text_generation_server.layers.gptq import (
CAN_EXLLAMA,
HAS_EXLLAMA,
GPTQWeight,
)
if not desc_act and self.groupsize != -1:
if use_exllama:
if not HAS_EXLLAMA:
if CAN_EXLLAMA:
log_once(
logger.warning,
"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
)
use_exllama = False
else:
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
if use_exllama and self.groupsize != -1:
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
if g_idx is not None:
# qzeros, scales sharded, and g_idx must be adjusted accordingly
g_idx = g_idx - g_idx[0]
else:
qzeros = weights.get_tensor(f"{prefix}.qzeros")
scales = weights.get_tensor(f"{prefix}.scales")
if use_exllama and g_idx is not None:
g_idx = g_idx - g_idx[0]
if self.quantize == "gptq" and self.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
@ -343,7 +392,7 @@ class GPTQWeightsLoader(WeightsLoader):
)
def _get_gptq_params(self, weights: Weights):
if weights.has_tensor("gptq_bits") and weights.has_tensor("gptq_groupsize"):
if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"):
self.bits = weights.get_tensor("gptq_bits").item()
self.groupsize = weights.get_tensor("gptq_groupsize").item()
self.desc_act = False
@ -351,7 +400,41 @@ class GPTQWeightsLoader(WeightsLoader):
# before the `gptq_sym` setting tensor was added.
self.sym = (
weights.get_tensor("gptq_sym").item()
if weights.has_tensor("gptq_sym")
if weights._has_tensor("gptq_sym")
else False
)
self.quant_method = "gptq"
# Needs to be at the end because circular import.
try:
major, _minor = torch.cuda.get_device_capability()
except Exception:
major = 1
HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm"
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
if os.getenv("DISABLE_EXLLAMA") == "True":
HAS_EXLLAMA = False
elif CAN_EXLLAMA:
try:
if V2:
from text_generation_server.layers.gptq.exllamav2 import (
QuantLinear as ExllamaQuantLinear, # noqa: F401
create_exllama_buffers, # noqa: F401
set_device, # noqa: F401
)
HAS_EXLLAMA = "2"
else:
from text_generation_server.layers.gptq.exllama import (
Ex4bitLinear as ExllamaQuantLinear, # noqa: F401
create_exllama_buffers, # noqa: F401
set_device, # noqa: F401
)
HAS_EXLLAMA = "1"
except ImportError:
pass

View File

@ -0,0 +1,261 @@
# https://github.com/fpgaminer/GPTQ-triton
"""
Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100.
"""
import builtins
import math
import time
from typing import Dict
import triton
class Autotuner(triton.KernelInterface):
def __init__(
self,
fn,
arg_names,
configs,
key,
reset_to_zero,
prune_configs_by: Dict = None,
nearest_power_of_two: bool = False,
):
"""
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time
'top_k': number of configs to bench
'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs.
'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results
"""
if not configs:
self.configs = [triton.Config({}, num_warps=4, num_stages=2)]
else:
self.configs = configs
self.key_idx = [arg_names.index(k) for k in key]
self.nearest_power_of_two = nearest_power_of_two
self.cache = {}
# hook to reset all required tensor to zeros before relaunching a kernel
self.hook = lambda args: 0
if reset_to_zero is not None:
self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
def _hook(args):
for i in self.reset_idx:
args[i].zero_()
self.hook = _hook
self.arg_names = arg_names
# prune configs
if prune_configs_by:
perf_model, top_k = (
prune_configs_by["perf_model"],
prune_configs_by["top_k"],
)
if "early_config_prune" in prune_configs_by:
early_config_prune = prune_configs_by["early_config_prune"]
else:
perf_model, top_k, early_config_prune = None, None, None
self.perf_model, self.configs_top_k = perf_model, top_k
self.early_config_prune = early_config_prune
self.fn = fn
def _bench(self, *args, config, **meta):
# check for conflicts, i.e. meta-parameters both provided
# as kwargs and by the autotuner
conflicts = meta.keys() & config.kwargs.keys()
if conflicts:
raise ValueError(
f"Conflicting meta-parameters: {', '.join(conflicts)}."
" Make sure that you don't re-define auto-tuned symbols."
)
# augment meta-parameters with tunable ones
current = dict(meta, **config.kwargs)
def kernel_call():
if config.pre_hook:
config.pre_hook(self.nargs)
self.hook(args)
self.fn.run(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
**current,
)
try:
# In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses
# PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default
return triton.testing.do_bench(
kernel_call, quantiles=(0.5, 0.2, 0.8), rep=40
)
except triton.OutOfResources:
return [float("inf"), float("inf"), float("inf")]
def run(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args))
if len(self.configs) > 1:
key = tuple(args[i] for i in self.key_idx)
# This reduces the amount of autotuning by rounding the keys to the nearest power of two
# In my testing this gives decent results, and greatly reduces the amount of tuning required
if self.nearest_power_of_two:
key = tuple([2 ** int(math.log2(x) + 0.5) for x in key])
if key not in self.cache:
# prune configs
pruned_configs = self.prune_configs(kwargs)
bench_start = time.time()
timings = {
config: self._bench(*args, config=config, **kwargs)
for config in pruned_configs
}
bench_end = time.time()
self.bench_time = bench_end - bench_start
self.cache[key] = builtins.min(timings, key=timings.get)
self.hook(args)
self.configs_timings = timings
config = self.cache[key]
else:
config = self.configs[0]
self.best_config = config
if config.pre_hook is not None:
config.pre_hook(self.nargs)
return self.fn.run(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
**kwargs,
**config.kwargs,
)
def prune_configs(self, kwargs):
pruned_configs = self.configs
if self.early_config_prune:
pruned_configs = self.early_config_prune(self.configs, self.nargs)
if self.perf_model:
top_k = self.configs_top_k
if isinstance(top_k, float) and top_k <= 1.0:
top_k = int(len(self.configs) * top_k)
if len(pruned_configs) > top_k:
est_timing = {
config: self.perf_model(
**self.nargs,
**kwargs,
**config.kwargs,
num_stages=config.num_stages,
num_warps=config.num_warps,
)
for config in pruned_configs
}
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[
:top_k
]
return pruned_configs
def warmup(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args))
for config in self.prune_configs(kwargs):
self.fn.warmup(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
**kwargs,
**config.kwargs,
)
self.nargs = None
def autotune(
configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False
):
"""
Decorator for auto-tuning a :code:`triton.jit`'d function.
.. highlight:: python
.. code-block:: python
@triton.autotune(configs=[
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
],
key=['x_size'] # the two above configs will be evaluated anytime
# the value of x_size changes
)
@triton.jit
def kernel(x_ptr, x_size, **META):
BLOCK_SIZE = META['BLOCK_SIZE']
:note: When all the configurations are evaluated, the kernel will run multiple time.
This means that whatever value the kernel updates will be updated multiple times.
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
reset the value of the provided tensor to `zero` before running any configuration.
:param configs: a list of :code:`triton.Config` objects
:type configs: list[triton.Config]
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
:type key: list[str]
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time
'top_k': number of configs to bench
'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs.
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
:type reset_to_zero: list[str]
"""
def decorator(fn):
return Autotuner(
fn,
fn.arg_names,
configs,
key,
reset_to_zero,
prune_configs_by,
nearest_power_of_two,
)
return decorator
def matmul248_kernel_config_pruner(configs, nargs):
"""
The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller.
"""
m = max(2 ** int(math.ceil(math.log2(nargs["M"]))), 16)
n = max(2 ** int(math.ceil(math.log2(nargs["N"]))), 16)
k = max(2 ** int(math.ceil(math.log2(nargs["K"]))), 16)
used = set()
for config in configs:
block_size_m = min(m, config.kwargs["BLOCK_SIZE_M"])
block_size_n = min(n, config.kwargs["BLOCK_SIZE_N"])
block_size_k = min(k, config.kwargs["BLOCK_SIZE_K"])
group_size_m = config.kwargs["GROUP_SIZE_M"]
if (
block_size_m,
block_size_n,
block_size_k,
group_size_m,
config.num_stages,
config.num_warps,
) in used:
continue
used.add(
(
block_size_m,
block_size_n,
block_size_k,
group_size_m,
config.num_stages,
config.num_warps,
)
)
yield triton.Config(
{
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
"GROUP_SIZE_M": group_size_m,
},
num_stages=config.num_stages,
num_warps=config.num_warps,
)

View File

@ -0,0 +1,134 @@
from text_generation_server.layers.gptq import GPTQWeight
import torch
from exllama_kernels import make_q4, q4_matmul, prepare_buffers, set_tuning_params
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
def ext_make_q4(qweight, qzeros, scales, g_idx, device):
"""Construct Q4Matrix, return handle"""
return make_q4(
qweight, qzeros, scales, g_idx if g_idx is not None else none_tensor, device
)
def ext_q4_matmul(x, q4, q4_width):
"""Matrix multiplication, returns x @ q4"""
outshape = x.shape[:-1] + (q4_width,)
x = x.view(-1, x.shape[-1])
output = torch.empty((x.shape[0], q4_width), dtype=torch.float16, device=x.device)
q4_matmul(x, q4, output)
return output.view(outshape)
MAX_DQ = 1
MAX_INNER = 1
ACT_ORDER = False
DEVICE = None
TEMP_STATE = None
TEMP_DQ = None
def set_device(device):
global DEVICE
DEVICE = device
def create_exllama_buffers(max_total_tokens: int):
global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ
assert DEVICE is not None, "call set_device first"
if not ACT_ORDER:
max_total_tokens = 1
# This temp_state buffer is required to reorder X in the act-order case.
temp_state = torch.zeros(
(max_total_tokens, MAX_INNER), dtype=torch.float16, device=DEVICE
)
temp_dq = torch.zeros((1, MAX_DQ), dtype=torch.float16, device=DEVICE)
# This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
prepare_buffers(DEVICE, temp_state, temp_dq)
matmul_recons_thd = 8
matmul_fused_remap = False
matmul_no_half2 = False
set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
TEMP_STATE, TEMP_DQ = temp_state, temp_dq
class Ex4bitLinear(torch.nn.Module):
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
def __init__(self, weight: GPTQWeight, bias):
super().__init__()
global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE
assert weight.bits == 4
self.device = weight.qweight.device
self.qweight = weight.qweight
self.qzeros = weight.qzeros
self.scales = weight.scales
self.g_idx = weight.g_idx.cpu() if weight.g_idx is not None else None
self.bias = bias if bias is not None else None
if self.g_idx is not None and (
(self.g_idx == 0).all()
or torch.equal(
weight.g_idx.cpu(),
torch.tensor(
[i // weight.groupsize for i in range(weight.g_idx.shape[0])],
dtype=torch.int32,
),
)
):
self.empty_g_idx = True
self.g_idx = None
assert self.device.type == "cuda"
assert self.device.index is not None
self.q4 = ext_make_q4(
self.qweight, self.qzeros, self.scales, self.g_idx, self.device.index
)
self.height = weight.qweight.shape[0] * 8
self.width = weight.qweight.shape[1]
# Infer groupsize from height of qzeros
self.groupsize = None
if self.qzeros.shape[0] > 1:
self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])
if self.groupsize is not None:
assert weight.groupsize == self.groupsize
# Handle act-order matrix
if self.g_idx is not None:
if self.groupsize is None:
raise ValueError("Found group index but no groupsize. What do?")
self.act_order = True
else:
self.act_order = False
DEVICE = self.qweight.device
MAX_DQ = max(MAX_DQ, self.qweight.numel() * 8)
if self.act_order:
MAX_INNER = max(MAX_INNER, self.height, self.width)
ACT_ORDER = True
def forward(self, x):
out = ext_q4_matmul(x, self.q4, self.width)
if self.bias is not None:
out.add_(self.bias)
return out

View File

@ -0,0 +1,267 @@
# Adapted from turboderp exllama: https://github.com/turboderp/exllamav2
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
from loguru import logger
from text_generation_server.layers.exl2 import Exl2Weight
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.utils.log import log_master
try:
from exllamav2.ext import exllamav2_ext
make_q_matrix = exllamav2_ext.make_q_matrix
gemm_half_q_half = exllamav2_ext.gemm_half_q_half
except ImportError:
log_master(logger.warning, "exllamav2_kernels not installed.")
raise
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
@dataclass
class _ExtraTensors:
"""Additional generated quantizer tensors."""
q_group_map: Optional[torch.Tensor] = None
q_invperm: Optional[torch.Tensor] = None
q_perm: Optional[torch.Tensor] = None
def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
"""Matrix multiplication, returns x @ q4"""
output_shape = x.shape[:-1] + (q4_width,)
x = x.view(-1, x.shape[-1])
output = torch.empty((x.shape[0], q4_width), dtype=torch.half, device=x.device)
gemm_half_q_half(x, q_handle, output, force_cuda)
return output.view(output_shape)
def make_group_map(q_groups: torch.Tensor, num_qrows: int):
gr = q_groups.tolist()
group_map = []
num_groups = len(gr) // 2
for i in range(num_groups):
bits = gr[i * 2]
if i < num_groups - 1:
qrows = gr[i * 2 + 3] - gr[i * 2 + 1]
else:
qrows = num_qrows - gr[i * 2 + 1]
rows = qrows * 32 // bits
for j in range(rows):
group_map += [i]
group_map += [rows - j]
return torch.tensor(group_map, dtype=torch.short, device=q_groups.device)
# Create Q matrix
def ext_make_q_matrix(
w: Exl2Weight | GPTQWeight,
extra: _ExtraTensors,
temp_dq,
key: Optional[str] = None,
):
"""
Create Q matrix
"""
# max_dq_size = 512*(1024**2)
# max_dq_rows = max_dq_size // out_features[0]
max_dq_rows = 0
# EXL2
if isinstance(w, Exl2Weight):
extra.q_group_map = make_group_map(w.q_groups, w.q_weight.shape[0])
extra.q_perm = torch.argsort(w.q_invperm).short()
return make_q_matrix(
w.q_weight,
extra.q_perm,
w.q_invperm,
w.q_scale,
w.q_scale_max,
w.q_groups,
extra.q_group_map,
none_tensor, # zeros
none_tensor, # scales
none_tensor, # g_idx
none_tensor, # bias
temp_dq,
max_dq_rows,
)
# GPTQ
elif isinstance(w, GPTQWeight):
if w.scales.dtype == torch.float:
w.scales = w.scales.half()
# GPTQ with g_idx (act_order)
if w.g_idx is not None and not (w.g_idx == 0).all().item():
extra.q_perm = torch.empty(
(w.qweight.shape[0] * 8,),
dtype=torch.short,
device=w.qweight.device,
)
extra.q_invperm = torch.empty_like(extra.q_perm)
# make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx.
return make_q_matrix(
w.qweight,
extra.q_perm,
extra.q_invperm,
none_tensor, # q_scale
none_tensor, # q_scale_max
none_tensor, # q_groups
none_tensor, # q_group_map
w.qzeros,
w.scales,
w.g_idx.cpu(),
none_tensor, # bias
temp_dq,
max_dq_rows,
)
# GPTQ without g_idx
else:
return make_q_matrix(
w.qweight,
none_tensor, # q_perm
none_tensor, # q_invperm
none_tensor, # q_scale
none_tensor, # q_scale_max
none_tensor, # q_groups
none_tensor, # q_group_map
w.qzeros,
w.scales,
none_tensor, # g_idx
none_tensor, # bias
temp_dq,
max_dq_rows,
)
else:
RuntimeError("Cannot create handle")
DEVICE = None
LAYERS = []
def set_device(device):
global DEVICE
DEVICE = device
def create_exllama_buffers(max_total_tokens: int):
global LAYERS, DEVICE
# No need to initialize scratch space if there are no layers
# that use ExLLamav2.
if len(LAYERS) == 0:
return
# Find the size of the scratch space.
scratch_bytes = max(
layer.scratch_space_fixed(max_input_len=max_total_tokens, max_batch_size=1)
for layer in LAYERS
)
temp_dq = ExLlamaV2DeviceTensors(DEVICE, scratch_bytes)
for layer in LAYERS:
layer.post_init(temp_dq)
class QuantLinear(nn.Module):
QUANT_TYPE = "exllamav2"
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
def __init__(
self,
weight: Exl2Weight | GPTQWeight,
bias: torch.Tensor,
):
super().__init__()
self.q_handle = None
self.q_tensors = weight
self.extra_tensors = _ExtraTensors()
if isinstance(weight, Exl2Weight):
self.infeatures = weight.q_invperm.shape[0]
self.outfeatures = weight.q_weight.shape[1]
elif isinstance(weight, GPTQWeight):
if weight.bits != 4:
raise ValueError(
f"Exllamav2 kernel supports only bits=4, requested bits={weight.bits}. Something is wrong in the model initialization."
)
self.infeatures = weight.qweight.shape[0] // weight.bits * 32
self.outfeatures = weight.qweight.shape[1]
self.padding = -self.outfeatures % 32
self.outfeatures = self.outfeatures + self.padding
self.device = weight.device
self.bias = bias if bias is not None else None
global LAYERS
LAYERS.append(self)
def post_init(self, temp_dq):
device = self.q_tensors.device
assert device.type == "cuda"
assert device.index is not None
temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
# We NEED to keep a pointer on Python side, otherwise the garbage collector will mess with us,
# and `Memory access fault by GPU node-2` will EAT you.
self.temp_dq = temp_dq
self.q_handle = ext_make_q_matrix(self.q_tensors, self.extra_tensors, temp_dq)
def forward(self, x, force_cuda=False):
output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda)
if self.bias is not None:
output.add_(self.bias)
return output
def temp_dq_size(self):
return self.infeatures * self.outfeatures * 2 + 128
def temp_fwd_size(self, max_input_len, max_batch_size):
return self.outfeatures * max_input_len * max_batch_size * 4 + 128
def scratch_space_fixed(self, max_input_len, max_batch_size):
return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size)
class ExLlamaV2DeviceTensors:
device_idx: int
scratch_bytes: int
scratch_idx: int
scratch: torch.tensor = None
def __init__(self, device, scratch_bytes):
self.device = device
self.scratch_bytes = scratch_bytes
def prepare(self):
self.scratch = torch.empty(
(self.scratch_bytes // 2,), dtype=torch.half, device=self.device
)
def get_scratch_slice(self, size_bytes):
if self.scratch is None:
self.prepare()
size_bytes = ((size_bytes + 127) // 128) * 128
size_half = size_bytes // 2
scratch_slice = self.scratch.narrow(0, 0, size_half)
return scratch_slice

View File

@ -1,186 +0,0 @@
import math
import numpy as np
import torch
import torch.nn as nn
try:
convert_from_uint4 = torch.ops.hpu.convert_from_uint4
except Exception as e:
hpu_import_exception = e
def error_raiser_hpu(*args, **kwargs):
raise ValueError(
f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}"
)
convert_from_uint4 = error_raiser_hpu
def pack_tensor(input, bits=4):
normal = input.to(torch.int32)
q = torch.zeros((normal.shape[0], normal.shape[1] // 32 * bits), dtype=torch.int32)
i = 0
col = 0
while col < q.shape[1]:
for j in range(i, i + (32 // bits)):
q[:, col] |= normal[:, j] << (bits * (j - i))
i += 32 // bits
col += 1
q = q.to(torch.int32)
return q
class QuantLinear(nn.Module):
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
super().__init__()
self.register_buffer("qweight", qweight)
self.register_buffer("qzeros", qzeros)
self.register_buffer("scales", scales)
self.register_buffer("g_idx", g_idx)
if bias is not None:
self.register_buffer("bias", bias)
else:
self.bias = None
if bits not in [4]:
raise NotImplementedError("Only 4 bits are supported.")
self.bits = bits
self.maxq = 2**self.bits - 1
self.groupsize = groupsize
self.outfeatures = qweight.shape[1]
self.infeatures = qweight.shape[0] * 32 // bits
self.wf = torch.tensor(
list(range(0, 32, self.bits)), dtype=torch.int32
).unsqueeze(0)
self._preprocessing()
def unpack_zeros_from_cuda_old_format(self):
zeros = torch.bitwise_right_shift(
torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits),
self.wf.unsqueeze(0),
).to(torch.int16 if self.bits == 8 else torch.int8)
zeros = zeros + 1
zeros = torch.bitwise_and(zeros, (2**self.bits) - 1).to(
self.scales.dtype
) # NOTE: It appears that casting here after the `zeros = zeros + 1` is important.
zeros = zeros.reshape(-1, zeros.shape[1] * zeros.shape[2])
return zeros
def unpack_weight_from_cuda_old_format(self):
weight = torch.bitwise_right_shift(
torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1),
self.wf.unsqueeze(-1),
).to(torch.int16 if self.bits == 8 else torch.int8)
weight = torch.bitwise_and(weight, (2**self.bits) - 1)
weight = weight.reshape((weight.shape[0] * weight.shape[1], weight.shape[2]))
return weight
def _preprocessing(self):
orig_device = self.qweight.device
self.qweight = self.qweight.cpu()
weight = self.unpack_weight_from_cuda_old_format()
new_qweight = pack_tensor(weight)
self.qweight = new_qweight.to(orig_device)
# TODO: Support group indexing and remove the check
columns = self.qweight.shape[0]
g_idx_trivial = [i // self.groupsize for i in range(columns)]
g_idx_trivial = torch.tensor(
g_idx_trivial, dtype=torch.int32, device=self.g_idx.device
)
assert torch.equal(
self.g_idx, g_idx_trivial
), "Non-trivial tensor g_idx is not supported"
self.qzeros = self.qzeros.cpu()
zeros = self.unpack_zeros_from_cuda_old_format()
new_qzeros = pack_tensor(zeros)
self.qzeros = new_qzeros.to(orig_device)
@classmethod
def new(cls, bits, groupsize, infeatures, outfeatures, bias):
if bits not in [4]:
raise NotImplementedError("Only 4 bits are supported.")
qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
qzeros = torch.zeros(
(math.ceil(infeatures / groupsize), outfeatures // 32 * bits),
dtype=torch.int32,
)
scales = torch.zeros(
(math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16
)
g_idx = torch.tensor(
[i // groupsize for i in range(infeatures)], dtype=torch.int32
)
if bias:
bias = torch.zeros((outfeatures), dtype=torch.float16)
else:
bias = None
return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
def pack(self, linear, scales, zeros, g_idx=None):
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
self.scales = scales.clone().half()
if linear.bias is not None:
self.bias = linear.bias.clone().half()
intweight = []
for idx in range(self.infeatures):
intweight.append(
torch.round(
(linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
/ self.scales[self.g_idx[idx]]
).to(torch.int)[:, None]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)
qweight = np.zeros(
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
)
i = 0
row = 0
while row < qweight.shape[0]:
if self.bits in [4]:
for j in range(i, i + (32 // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
i += 32 // self.bits
row += 1
else:
raise NotImplementedError("Only 4 bits are supported.")
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)
zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros(
(zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
)
i = 0
col = 0
while col < qzeros.shape[1]:
if self.bits in [4]:
for j in range(i, i + (32 // self.bits)):
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
i += 32 // self.bits
col += 1
else:
raise NotImplementedError("Only 4 bits are supported.")
qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)
def forward(self, x):
out_shape = x.shape[:-1] + (self.outfeatures,)
x = x.reshape(-1, x.shape[-1])
weight = convert_from_uint4(self.qweight, self.scales, self.qzeros, x.dtype)
out = torch.matmul(x, weight)
out = out.reshape(out_shape)
out = out + self.bias if self.bias is not None else out
return out

View File

@ -0,0 +1,359 @@
import math
import numpy as np
import torch
import torch.nn as nn
from torch.cuda.amp import custom_fwd
import triton
import triton.language as tl
from . import custom_autotune
# code based https://github.com/fpgaminer/GPTQ-triton
@custom_autotune.autotune(
configs=[
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=2,
num_warps=8,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
},
num_stages=3,
num_warps=8,
),
triton.Config(
{
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
},
num_stages=2,
num_warps=4,
),
],
key=["M", "N", "K"],
nearest_power_of_two=True,
prune_configs_by={
"early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
"perf_model": None,
"top_k": None,
},
)
@triton.jit
def matmul_248_kernel(
a_ptr,
b_ptr,
c_ptr,
scales_ptr,
zeros_ptr,
g_ptr,
M,
N,
K,
bits,
maxq,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_scales,
stride_zeros,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""
Compute the matrix multiplication C = A x B.
A is of shape (M, K) float16
B is of shape (K//8, N) int32
C is of shape (M, N) float16
scales is of shape (G, N) float16
zeros is of shape (G, N) float16
g_ptr is of shape (K) int32
"""
infearure_per_bits = 32 // bits
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (
offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
a_mask = offs_am[:, None] < M
# b_ptrs is set up such that it repeats elements along the K axis 8 times
b_ptrs = b_ptr + (
(offs_k[:, None] // infearure_per_bits) * stride_bk
+ offs_bn[None, :] * stride_bn
) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
g_ptrs = g_ptr + offs_k
# shifter is used to extract the N bits of each element in the 32-bit word from B
scales_ptrs = scales_ptr + offs_bn[None, :]
zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)
shifter = (offs_k % infearure_per_bits) * bits
zeros_shifter = (offs_bn % infearure_per_bits) * bits
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, num_pid_k):
g_idx = tl.load(g_ptrs)
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
scales = tl.load(
scales_ptrs + g_idx[:, None] * stride_scales
) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = tl.load(
zeros_ptrs + g_idx[:, None] * stride_zeros
) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = (zeros + 1) & maxq # eventually avoid overflow
a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
# Now we need to unpack b (which is N-bit values) into 32-bit values
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
b = (b - zeros) * scales # Scale and shift
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K
b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
g_ptrs += BLOCK_SIZE_K
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
with torch.cuda.device(input.device):
output = torch.empty(
(input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16
)
def grid(META):
return (
triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"])
* triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]),
)
matmul_248_kernel[grid](
input,
qweight,
output,
scales,
qzeros,
g_idx,
input.shape[0],
qweight.shape[1],
input.shape[1],
bits,
maxq,
input.stride(0),
input.stride(1),
qweight.stride(0),
qweight.stride(1),
output.stride(0),
output.stride(1),
scales.stride(0),
qzeros.stride(0),
)
return output
class QuantLinearFunction(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq)
return output
class QuantLinear(nn.Module):
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
super().__init__()
self.register_buffer("qweight", qweight)
self.register_buffer("qzeros", qzeros)
self.register_buffer("scales", scales)
self.register_buffer("g_idx", g_idx)
if bias is not None:
self.register_buffer("bias", bias)
else:
self.bias = None
if bits not in [2, 4, 8]:
raise NotImplementedError("Only 2,4,8 bits are supported.")
self.bits = bits
self.maxq = 2**self.bits - 1
self.groupsize = groupsize
self.outfeatures = qweight.shape[1]
self.infeatures = qweight.shape[0] * 32 // bits
@classmethod
def new(cls, bits, groupsize, infeatures, outfeatures, bias):
if bits not in [2, 4, 8]:
raise NotImplementedError("Only 2,4,8 bits are supported.")
qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
qzeros = torch.zeros(
(math.ceil(infeatures / groupsize), outfeatures // 32 * bits),
dtype=torch.int32,
)
scales = torch.zeros(
(math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16
)
g_idx = torch.tensor(
[i // groupsize for i in range(infeatures)], dtype=torch.int32
)
if bias:
bias = torch.zeros((outfeatures), dtype=torch.float16)
else:
bias = None
return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
def pack(self, linear, scales, zeros, g_idx=None):
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
self.scales = scales.clone().half()
if linear.bias is not None:
self.bias = linear.bias.clone().half()
intweight = []
for idx in range(self.infeatures):
intweight.append(
torch.round(
(linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
/ self.scales[self.g_idx[idx]]
).to(torch.int)[:, None]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)
qweight = np.zeros(
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
)
i = 0
row = 0
while row < qweight.shape[0]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
i += 32 // self.bits
row += 1
else:
raise NotImplementedError("Only 2,4,8 bits are supported.")
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)
zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros(
(zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
)
i = 0
col = 0
while col < qzeros.shape[1]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
i += 32 // self.bits
col += 1
else:
raise NotImplementedError("Only 2,4,8 bits are supported.")
qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)
def forward(self, x):
out_shape = x.shape[:-1] + (self.outfeatures,)
out = QuantLinearFunction.apply(
x.reshape(-1, x.shape[-1]),
self.qweight,
self.scales,
self.qzeros,
self.g_idx,
self.bits,
self.maxq,
)
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)

View File

@ -12,7 +12,7 @@ from huggingface_hub import HfApi
from accelerate import init_empty_weights
from text_generation_server.utils import initialize_torch_distributed, Weights
from text_generation_server.utils.hub import weight_files
from text_generation_server.layers.gptq import QuantLinear
from text_generation_server.layers.gptq.quant_linear import QuantLinear
from loguru import logger
from typing import Optional
from text_generation_server.layers.gptq.utils import torch_snr_error
@ -956,24 +956,15 @@ def quantize(
pack(model, quantizers, bits, groupsize)
from safetensors.torch import save_file
from huggingface_hub import split_torch_state_dict_into_shards
from transformers.modeling_utils import shard_checkpoint
state_dict = model.state_dict()
state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
max_shard_size = "10GB"
state_dict_split = split_torch_state_dict_into_shards(
state_dict,
filename_pattern="model.safetensors",
max_shard_size=max_shard_size,
shards, index = shard_checkpoint(
state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors"
)
index = None
if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
shards = state_dict_split.filename_to_tensors
os.makedirs(output_dir, exist_ok=True)
for shard_file, shard in shards.items():
save_file(

View File

@ -1,6 +1,9 @@
import torch
from torch import nn
from accelerate import init_empty_weights
from text_generation_server.utils.import_utils import (
SYSTEM,
)
# Monkey patching
@ -30,14 +33,69 @@ def load_layer_norm_no_bias(cls, prefix, weights, eps):
torch.nn.LayerNorm.load = load_layer_norm
torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias
if SYSTEM == "cuda":
import dropout_layer_norm
class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None):
if residual is not None:
hidden_states += residual
residual = hidden_states
class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None):
if hidden_states.shape[-1] > 8192:
if residual is not None:
hidden_states += residual
residual = hidden_states
return super().forward(hidden_states), residual
return super(FastLayerNorm, self).forward(hidden_states), residual
else:
(
normed_hidden_states,
residual,
*rest,
) = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
self.bias,
None,
None,
None,
None,
0.0,
self.eps,
1.0,
0,
None,
False,
False,
)
if residual is None:
residual = hidden_states
return normed_hidden_states, residual
elif SYSTEM == "rocm":
from vllm._C import ops
class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None):
if residual is not None:
hidden_states += residual
residual = hidden_states
return super().forward(hidden_states), residual
elif SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex
class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None):
out = ipex.llm.functional.add_layer_norm(
residual,
hidden_states,
self.weight,
self.bias,
self.eps,
residual is not None,
)
return out, residual if residual is not None else hidden_states
class FastRMSNorm(nn.Module):
@ -53,15 +111,74 @@ class FastRMSNorm(nn.Module):
return cls(weight, eps)
def forward(self, hidden_states, residual=None):
from vllm_hpu_extension.kernels import rms_norm
orig_shape = hidden_states.shape
if residual is not None:
residual += hidden_states.view(residual.shape)
else:
if SYSTEM == "ipex":
out = ipex.llm.functional.add_rms_norm(
residual,
hidden_states,
self.weight,
None,
self.variance_epsilon,
residual is not None,
)
return out, residual if residual is not None else hidden_states
elif hidden_states.shape[-1] > 8192:
if residual is not None:
hidden_states += residual
residual = hidden_states
# Note: HPUFusedRMSNorm requires 3D tensors as inputs
if len(orig_shape) == 2:
residual = residual.unsqueeze(0)
x = rms_norm().apply(residual, self.weight, self.variance_epsilon)
return x.view(orig_shape), residual.view(orig_shape)
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(
variance + self.variance_epsilon
)
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states, residual
elif SYSTEM == "cuda":
# faster post attention rms norm
(
normed_hidden_states,
res,
*rest,
) = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
None,
None,
None,
None,
None,
0.0,
self.variance_epsilon,
1.0,
0,
None,
False,
True, # Activate RMSNorm
)
if res is None:
res = hidden_states
return normed_hidden_states, res
elif SYSTEM == "rocm":
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
if residual is not None:
hidden_states += residual
residual = hidden_states
out = torch.empty_like(hidden_states)
ops.rms_norm(
out,
hidden_states,
self.weight.data,
self.variance_epsilon,
)
return out, residual
else:
raise ValueError(
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
)

View File

@ -1,5 +1,21 @@
import torch
from text_generation_server.utils.import_utils import SYSTEM
from torch.nn import functional as F
import os
if SYSTEM == "rocm":
ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in (
"true",
"1",
)
if ROCM_USE_SKINNY_GEMM:
try:
from vllm import _custom_C
except Exception as e:
raise ImportError(
f"Could not load `vllm._custom_C` for ROCm skinny gemm. Full error: {e}"
)
class FastLinear(torch.nn.Module):
@ -28,11 +44,83 @@ class FastLinear(torch.nn.Module):
return F.linear(input, self.weight, self.bias)
class FastLinearROCm(torch.nn.Module):
def __init__(
self,
weight,
bias,
) -> None:
super().__init__()
self.weight = torch.nn.Parameter(weight)
if bias is not None:
self.bias = torch.nn.Parameter(bias)
else:
self.bias = None
self.cu_count = torch.cuda.get_device_properties(
device="cuda"
).multi_processor_count
self.use_skinny_gemm = (
ROCM_USE_SKINNY_GEMM
and "gfx1" not in torch.cuda.get_device_properties("cuda").gcnArchName
)
@classmethod
def load(cls, config, prefix: str, weights, bias: bool):
weight = weights.get_tensor(f"{prefix}.weight")
if bias:
bias = weights.get_tensor(f"{prefix}.bias")
else:
bias = None
return cls(weight, bias)
def forward(self, inp: torch.Tensor) -> torch.Tensor:
weight = self.weight
bias = self.bias
if (
self.use_skinny_gemm
and inp.dtype == torch.float16
and inp.shape[-1] % 8 == 0
):
batched = False
inp_shape = inp.shape
if inp.dim() == 3:
inp = inp.view(-1, inp_shape[-1])
batched = True
m, n, k = weight.shape[0], inp_shape[0], inp_shape[1]
if m > 8 and n <= 4:
out = torch.empty(
inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
)
_custom_C.wvSpltK(weight, inp, out, n, self.cu_count)
elif m % 4 == 0 and n == 1 and k <= 8192:
out = torch.empty(
inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
)
_custom_C.LLMM1(weight, inp, out, 4)
else:
out = F.linear(inp, weight)
if batched:
out.view(*inp_shape[:-1], out.shape[-1])
if bias is not None:
out = out + bias
return out
return F.linear(inp, self.weight, self.bias)
def get_linear(weight, bias):
# Weights that are loaded through methods that are not
# quantization-aware are still bare tensors. We may want
# to change this in the future.
if isinstance(weight, torch.Tensor):
return FastLinear(weight, bias)
if SYSTEM == "rocm":
return FastLinearROCm(weight, bias)
else:
return FastLinear(weight, bias)
return weight.get_linear(bias)

View File

@ -0,0 +1,15 @@
from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear
from text_generation_server.layers.marlin.gptq import (
GPTQMarlinWeightsLoader,
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
from text_generation_server.layers.marlin.marlin import MarlinWeightsLoader
__all__ = [
"GPTQMarlinFP8Linear",
"GPTQMarlinWeightsLoader",
"MarlinWeightsLoader",
"can_use_gptq_marlin",
"repack_gptq_for_marlin",
]

View File

@ -0,0 +1,140 @@
from typing import Optional
import torch
import torch.nn as nn
from loguru import logger
from text_generation_server.layers.fp8 import fp8_quantize
from text_generation_server.layers.marlin.gptq import _check_valid_shape
from text_generation_server.layers.marlin.util import (
_check_marlin_kernels,
permute_scales,
)
from text_generation_server.utils.log import log_once
try:
import marlin_kernels
except ImportError:
marlin_kernels = None
MARLIN_TILE_SIZE = 16
class GPTQMarlinFP8Linear(nn.Module):
"""
FP8 GPTQ-Marlin linear layer.
"""
def __init__(
self,
qweight: torch.Tensor,
scales: torch.Tensor,
bias: Optional[torch.Tensor],
) -> None:
super().__init__()
_check_marlin_kernels()
assert marlin_kernels is not None
log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
scales = scales.unsqueeze(0)
if scales.shape[1] == 1:
out_features, in_features = qweight.shape
scales = scales.repeat(1, out_features)
qweight, scales = repack_fp8_for_marlin(qweight, scales)
in_features = qweight.shape[0] * MARLIN_TILE_SIZE
out_features = scales.shape[1]
_check_valid_shape(in_features=in_features, out_features=out_features)
self.qweight = qweight
self.scales = scales
self.bias = bias if bias is not None else None
self.workspace = torch.zeros(
out_features // 64 * 16, dtype=torch.int, device=qweight.device
)
@classmethod
def from_unquant(cls, weight, bias, dtype):
qweight, scales = fp8_quantize(weight)
return cls(qweight=qweight, scales=scales.to(dtype), bias=bias)
@classmethod
def from_fp8(cls, weight, scale, _input_scale, bias, dtype):
return cls(qweight=weight, scales=scale.to(dtype), bias=bias)
def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None
A_flat = A.view(-1, A.shape[-1])
C = marlin_kernels.fp8_marlin_gemm(
A_flat,
self.qweight,
self.scales,
self.workspace,
8,
A_flat.shape[0],
self.scales.shape[1],
A_flat.shape[1],
)
C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))
if self.bias is not None:
C += self.bias
return C
def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
"""
Repack FP8 weights to gptq format (packed int32 elements).
"""
assert fp8_tensor.dtype == torch.float8_e4m3fn
if fp8_tensor.shape[0] % 4 != 0:
raise ValueError(
f"Leading tensor dimension is not divisable by 4: {fp8_tensor.shape[0]}"
)
# Reshape to prepare for packing
reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
# Convert fp8 to uint8 (byte) representation
byte_tensor = reshaped.view(torch.uint8)
# Pack 4 uint8 values into one int32
packed = torch.zeros(
fp8_tensor.shape[0] // 4,
fp8_tensor.shape[1],
dtype=torch.int32,
device=fp8_tensor.device,
)
for i in range(4):
packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8)
return packed
def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor):
"""
Repack FP8 tensor for GPTQ-Marlin.
"""
out_features, in_features = weight.shape
# Torch linear layers weights with shape [out_features, in_features],
# GPTQ-quantized weights use [in_feateres/pack_factor, in_features],
# so transpose before packing.
qweight = pack_fp8_as_int32(weight.t())
perm = torch.empty(0, dtype=torch.int, device=qweight.device)
repacked = marlin_kernels.gptq_marlin_repack(
qweight, perm, in_features, out_features, 8
)
scales = permute_scales(scales)
return repacked, scales

View File

@ -0,0 +1,464 @@
from dataclasses import dataclass
from typing import List, Optional, Union
import numpy
import torch
import torch.nn as nn
from loguru import logger
from text_generation_server.layers.marlin.util import (
_check_marlin_kernels,
marlin_zero_points,
permute_scales,
unpack_cols,
)
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
try:
import marlin_kernels
except ImportError:
marlin_kernels = None
try:
major, _minor = torch.cuda.get_device_capability()
has_sm_8_0 = major >= 8
except Exception:
has_sm_8_0 = False
GPTQ_MARLIN_BITS = [4, 8]
GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
MARLIN_TILE_SIZE = 16
def can_use_gptq_marlin(
*, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool
) -> bool:
return (
SYSTEM == "cuda"
and marlin_kernels is not None
and has_sm_8_0
and quantize in {"awq", "gptq"}
and quant_method in {"awq", "gptq"}
and bits in GPTQ_MARLIN_BITS
and groupsize in GPTQ_MARLIN_GROUP_SIZES
# We only suppord asymmetric quantization for AWQ.
and (sym or quant_method == "awq")
)
class GPTQMarlinWeightsLoader(WeightsLoader):
"""
Loader for using GPTQ- and AWQ-quantized weights with Marlin kernels.
"""
def __init__(
self,
*,
bits: int,
desc_act: bool,
groupsize: int,
quant_method: str,
quantize: str,
sym: bool,
):
self.bits = bits
self.desc_act = desc_act
self.groupsize = groupsize
self.quant_method = quant_method
self.quantize = quantize
self.sym = sym
def get_weights(self, weights: Weights, prefix: str):
log_once(logger.info, "Using GPTQ-Marlin kernels")
try:
qweight = weights.get_tensor(f"{prefix}.qweight")
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
if not self.sym:
qzeros = weights.get_tensor(f"{prefix}.qzeros")
else:
qzeros = None
if self.quant_method == "awq":
g_idx = None
else:
g_idx = weights.get_tensor(f"{prefix}.g_idx")
scales = weights.get_tensor(f"{prefix}.scales")
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
qzeros=qzeros,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method=self.quant_method,
sym=self.sym,
sharded_infeatures=False,
)
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
try:
qweight = weights.get_packed_sharded(
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized."
)
scales = weights.get_packed_sharded(
f"{prefix}.scales", dim=1, block_sizes=block_sizes
)
scales = scales.to(dtype=weights.dtype)
if not self.sym:
qzeros = weights.get_packed_sharded(
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
)
else:
qzeros = None
if self.quant_method == "awq":
g_idx = None
else:
g_idx = weights.get_tensor(f"{prefix}.g_idx")
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
qzeros=qzeros,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method=self.quant_method,
sym=self.sym,
sharded_infeatures=False,
)
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
try:
qweight = torch.cat(
[weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
)
scales = torch.cat(
[weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
)
if not self.sym:
qzeros = torch.cat(
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
)
else:
qzeros = None
if self.quant_method == "awq":
g_idx = None
else:
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
qzeros=qzeros,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method=self.quant_method,
sym=self.sym,
sharded_infeatures=False,
)
def get_weights_row(self, weights: Weights, prefix: str):
log_once(logger.info, "Using GPTQ-Marlin kernels")
try:
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
if not self.sym:
if self.desc_act or self.groupsize == -1:
qzeros = weights.get_tensor(f"{prefix}.qzeros")
else:
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
else:
qzeros = None
if self.quant_method == "awq":
g_idx = None
else:
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
if self.desc_act or self.groupsize == -1:
scales = weights.get_tensor(f"{prefix}.scales")
else:
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
sharded_in_features = weights.process_group.size() > 1
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
qzeros=qzeros,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method=self.quant_method,
sym=self.sym,
sharded_infeatures=sharded_in_features,
)
def _get_gptq_params(self, weights: Weights):
if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"):
self.bits = weights.get_tensor("gptq_bits").item()
self.groupsize = weights.get_tensor("gptq_groupsize").item()
self.desc_act = False
# `server quantize` used asymmetric quantization unconditionally
# before the `gptq_sym` setting tensor was added.
self.sym = (
weights.get_tensor("gptq_sym").item()
if weights._has_tensor("gptq_sym")
else False
)
self.quant_method = "gptq"
@dataclass
class GPTQMarlinWeight(Weight):
"""
Repacked GPTQ Marlin weights.
"""
qweight: torch.Tensor
qzeros: torch.Tensor
scales: torch.Tensor
g_idx: torch.Tensor
perm: torch.Tensor
bits: int
is_full_k: bool
def __post_init__(self):
assert self.qweight.dtype == torch.int32
assert self.scales.dtype == torch.float16
assert self.g_idx.dtype == torch.int32
assert self.perm.dtype == torch.int32
def get_linear(self, bias: torch.Tensor):
return GPTQMarlinLinear(
weight=self,
bias=bias,
)
def repack_gptq_for_marlin(
*,
qweight: torch.Tensor,
qzeros: Optional[torch.Tensor],
scales: torch.Tensor,
g_idx: Optional[torch.Tensor],
bits: int,
desc_act: bool,
groupsize: int,
quant_method: str,
sym: bool,
sharded_infeatures: bool,
) -> GPTQMarlinWeight:
"""Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels."""
_check_marlin_kernels()
assert marlin_kernels is not None
if bits not in GPTQ_MARLIN_BITS:
supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS)
raise RuntimeError(
f"Repacking {bits}-bit GPTQ weights as Marlin is not supported, must be one of: {supported_bits}"
)
if groupsize not in GPTQ_MARLIN_GROUP_SIZES:
supported_sizes = ", ".join(str(b) for b in GPTQ_MARLIN_GROUP_SIZES)
raise RuntimeError(
f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}"
)
if not (sym or quant_method == "awq"):
raise RuntimeError(
"Repacking GPTQ weights with asymmetric quantization as Marlin is not supported."
)
log_once(logger.info, f"Converting {quant_method} model to Marlin packing format.")
weights_per_int = 32 // bits
in_features = qweight.shape[0]
out_features = qweight.shape[1]
# AWQ uses column packing, GPTQ uses row packing
if quant_method == "awq":
out_features *= weights_per_int
else:
in_features *= weights_per_int
if in_features % groupsize != 0:
raise ValueError(
f"Number of input features ({in_features}) not divisible by group size ({groupsize})"
)
if g_idx is not None and desc_act and groupsize != -1:
perm = torch.argsort(g_idx).to(torch.int)
g_idx = g_idx[perm]
else:
perm = torch.empty(0, dtype=torch.int, device=qweight.device)
g_idx = torch.empty(0, dtype=torch.int, device=qweight.device)
if quant_method == "awq":
repacked = marlin_kernels.awq_marlin_repack(
qweight, in_features, out_features, bits
)
if qzeros is not None:
qzeros = awq_to_marlin_zero_points(
qzeros,
in_features // groupsize,
out_features,
bits,
)
else:
repacked = marlin_kernels.gptq_marlin_repack(
qweight, perm, in_features, out_features, bits
)
if qzeros is None:
qzeros = torch.empty(0, dtype=torch.int, device=qweight.device)
scales = permute_scales(scales)
is_full_k = not (desc_act and groupsize != -1 and sharded_infeatures)
return GPTQMarlinWeight(
qweight=repacked,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
perm=perm,
bits=bits,
is_full_k=is_full_k,
)
class GPTQMarlinLinear(nn.Module):
"""
Linear layer for GPTQ weights that were converted for the GPTQ-Marlin
kernels.
"""
def __init__(
self,
*,
weight: GPTQMarlinWeight,
bias: Optional[torch.Tensor],
):
super().__init__()
_check_marlin_kernels()
assert marlin_kernels is not None
in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE
out_features = weight.scales.shape[1]
_check_valid_shape(in_features=in_features, out_features=out_features)
self.bits = weight.bits
self.is_full_k = weight.is_full_k
self.qweight = weight.qweight
self.qzeros = weight.qzeros
self.scales = weight.scales
self.g_idx = weight.g_idx
self.perm = weight.perm
if bias is not None:
self.bias = bias
else:
self.bias = None
self.workspace = torch.zeros(
out_features // 64 * 16, dtype=torch.int, device=weight.qweight.device
)
def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None
A_flat = A.view(-1, A.shape[-1])
C = marlin_kernels.gptq_marlin_gemm(
A_flat,
self.qweight,
self.scales,
self.qzeros,
self.g_idx,
self.perm,
self.workspace,
self.bits,
A_flat.shape[0],
self.scales.shape[1],
A_flat.shape[1],
self.is_full_k,
self.qzeros.numel() > 0,
True,
)
C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))
if self.bias is not None:
C += self.bias
return C
def awq_to_marlin_zero_points(
q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
# AWQ zero-points are quantized and packed on the column dim.
# In addition, the values are permuted based on dequantizer.
# Here we undo both of these, and then apply marlin permutation
# and pack it back.
q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)
# Undo interleaving (use argsort(..) to get inverse perm)
if num_bits == 4:
undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
elif num_bits == 8:
undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
q_zp = q_zp.reshape((-1, size_n)).contiguous()
marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
return marlin_zp
def _check_valid_shape(in_features: int, out_features: int):
if (in_features % 128 != 0 or out_features % 64 != 0) and (
in_features % 64 != 0 or out_features % 128 != 0
):
raise ValueError(
f"The GPTQ Marlin kernel does not have a valid thread configuration for weight matrix with shape ({out_features}, {in_features})."
" The shape elements must be divisible by (128, 64) or (64, 128)."
)

View File

@ -0,0 +1,346 @@
from dataclasses import dataclass
from typing import List, Optional, Union
import torch
import torch.nn as nn
from text_generation_server.layers.marlin.util import _check_marlin_kernels
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
try:
import marlin_kernels
except ImportError:
marlin_kernels = None
class MarlinWeightsLoader(WeightsLoader):
"""Loader for Marlin-quantized weights."""
def __init__(self, *, bits: int, is_marlin_24: bool):
self.bits = bits
self.is_marlin_24 = is_marlin_24
def get_weights(self, weights: "Weights", prefix: str):
"""
Get weights at the given prefix and apply without tensor paralllism.
"""
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
if is_marlin_24:
try:
B = weights.get_tensor(f"{prefix}.B_24")
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
)
B_meta = weights.get_tensor(f"{prefix}.B_meta")
s = weights.get_tensor(f"{prefix}.s")
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
else:
try:
B = weights.get_tensor(f"{prefix}.B")
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` weight, make sure the model is already quantized."
)
s = weights.get_tensor(f"{prefix}.s")
weight = MarlinWeight(B=B, s=s)
return weight
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
if self.is_marlin_24:
B = weights.get_packed_sharded(
f"{prefix}.B_24", dim=1, block_sizes=block_sizes
)
B_meta = weights.get_packed_sharded(
f"{prefix}.B_meta", dim=1, block_sizes=block_sizes
)
s = weights.get_packed_sharded(
f"{prefix}.s", dim=1, block_sizes=block_sizes
)
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
else:
B = weights.get_packed_sharded(
f"{prefix}.B", dim=1, block_sizes=block_sizes
)
s = weights.get_packed_sharded(
f"{prefix}.s", dim=1, block_sizes=block_sizes
)
weight = MarlinWeight(B=B, s=s)
return weight
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
if self.is_marlin_24:
try:
B = torch.cat(
[weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` weight, make sure the model is already quantized"
)
B_meta = torch.cat(
[weights.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1
)
s = torch.cat(
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
)
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
else:
try:
B = torch.cat(
[weights.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` weight, make sure the model is already quantized"
)
s = torch.cat(
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
)
weight = MarlinWeight(B=B, s=s)
return weight
def get_weights_row(self, weights: Weights, prefix: str):
if self.is_marlin_24:
try:
B = weights.get_sharded(f"{prefix}.B_24", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
)
B_meta = weights.get_sharded(f"{prefix}.B_meta", dim=0)
num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0]
if num_groups == 1:
# The number of groups is 1 when groupsize == -1. share
# scales between all shards in this case.
s = weights.get_tensor(f"{prefix}.s")
else:
s = weights.get_sharded(f"{prefix}.s", dim=0)
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
else:
try:
B = weights.get_sharded(f"{prefix}.B", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` weight, make sure the model is already quantized."
)
num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0]
if num_groups == 1:
# The number of groups is 1 when groupsize == -1. share
# scales between all shards in this case.
s = weights.get_tensor(f"{prefix}.s")
else:
s = weights.get_sharded(f"{prefix}.s", dim=0)
weight = MarlinWeight(B=B, s=s)
return weight
@dataclass
class MarlinWeight(Weight):
"""
Marlin weights.
Attributes:
B (torch.Tensor): int4-quantized weights packed into int32.
s (torch.Tensor): bfloat16/float16 scales.
"""
B: torch.Tensor
s: torch.Tensor
def __post_init__(self):
assert self.B.dtype == torch.int32
assert self.s.dtype in [torch.float16, torch.bfloat16]
def get_linear(self, bias: torch.Tensor):
return MarlinLinear(weight=self, bias=bias)
class MarlinLinear(nn.Module):
def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]):
super().__init__()
_check_marlin_kernels()
assert marlin_kernels is not None
in_features = weight.B.shape[0] * MARLIN_TILE_SIZE
out_features = weight.s.shape[1]
assert (
in_features % 128 == 0
), f"Number of input features ({in_features}) not divisable by 128"
assert (
out_features % 256 == 0
), f"Number of output features ({out_features}) not divisable by 256"
groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0]
assert groupsize in {
-1,
128,
}, f"Group size must be -1 or 128, was {groupsize}"
self.B = weight.B
self.s = weight.s
if bias is not None:
self.bias = bias
else:
self.bias = None
self.workspace = torch.zeros(
out_features // 64 * 16, dtype=torch.int, device=weight.B.device
)
def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None
C = marlin_kernels.marlin_gemm(
A.view(-1, A.shape[-1]),
self.B,
self.s,
self.workspace,
A.shape[0],
self.s.shape[1],
A.shape[1],
)
C = C.reshape(A.shape[:-1] + (self.s.shape[1],))
if self.bias is not None:
C += self.bias
return C
GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MIN_THREAD_K = 128
GPTQ_MARLIN_24_MAX_PARALLEL = 64
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
MARLIN_TILE_SIZE = 16
@dataclass
class GPTQMarlin24Weight:
"""
GPTQ-Marlin 2:4 weights.
Attributes:
B (torch.Tensor): int4-quantized weights packed into int32.
B_meta (torch.Tensor): metadata for 2:4 sparsity.
s (torch.Tensor): float16 scales.
bits: quantized weight size.
"""
B: torch.Tensor
B_meta: torch.Tensor
s: torch.Tensor
bits: int
def __post_init__(self):
assert self.B.dtype == torch.int32
assert self.B_meta.dtype == torch.int16
assert self.s.dtype == torch.float16
def get_linear(self, bias: torch.Tensor):
return GPTQMarlin24Linear(
weight=self,
bias=bias,
)
class GPTQMarlin24Linear(nn.Module):
def __init__(self, *, weight: GPTQMarlin24Weight, bias: Optional[torch.Tensor]):
super().__init__()
_check_marlin_kernels()
assert marlin_kernels is not None
if weight.bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS:
supported_bits = ", ".join(
str(b) for b in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
)
raise RuntimeError(
f"{weight.bits}-bit GPTQ Sparse 2:4 Marlin is not supported, must be one of: {supported_bits}"
)
in_features = weight.B.shape[0] * MARLIN_TILE_SIZE * 2
out_features = weight.s.shape[1]
groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0]
if groupsize not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
supported_sizes = ", ".join(
str(b) for b in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
)
raise RuntimeError(
f"Group size {groupsize} is not supported, must be one of: {supported_sizes}"
)
self.bits = weight.bits
weights_per_int32 = 32 // self.bits
assert (
out_features % GPTQ_MARLIN_24_MIN_THREAD_N == 0
), f"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_N} threads"
assert (
out_features % weights_per_int32 == 0
), f"Number of output features ({out_features}) not divisable by weights per int32 ({weights_per_int32})"
assert (
in_features % GPTQ_MARLIN_24_MIN_THREAD_K == 0
), f"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_K} threads"
if groupsize != -1 and in_features % groupsize != 0:
raise ValueError(
f"Number of input features ({in_features}) not divisable by group size ({groupsize})"
)
self.B = weight.B
self.B_meta = weight.B_meta
self.s = weight.s
if bias is not None:
self.bias = bias
else:
self.bias = None
self.workspace = torch.zeros(
(out_features // GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL,
dtype=torch.int,
device=weight.B.device,
)
def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None
C = marlin_kernels.gptq_marlin_24_gemm(
A.view(-1, A.shape[-1]),
self.B,
self.B_meta,
self.s,
self.workspace,
self.bits,
A.shape[0],
self.s.shape[1],
A.shape[1],
)
C = C.reshape(A.shape[:-1] + (self.s.shape[1],))
if self.bias is not None:
C += self.bias
return C

View File

@ -0,0 +1,141 @@
import functools
from typing import List, Tuple
import numpy
import torch
from text_generation_server.utils.import_utils import SYSTEM
try:
import marlin_kernels
except ImportError:
marlin_kernels = None
try:
major, _minor = torch.cuda.get_device_capability()
has_sm_8_0 = major >= 8
except Exception:
has_sm_8_0 = False
def _check_marlin_kernels():
if not (SYSTEM == "cuda" and has_sm_8_0):
raise NotImplementedError(
"Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later."
)
if marlin_kernels is None:
raise NotImplementedError(
"marlin is not installed, install it with: pip install server/marlin"
)
# https://github.com/IST-DASLab/marlin/blob/2f6d7c10e124b3c5fa29ff8d77d568bd7af3274c/marlin/__init__.py#L40C1-L68C54
@functools.cache
def get_perms() -> Tuple[List[int], List[int]]:
scale_perm = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single = []
for i in range(4):
scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return scale_perm, scale_perm_single
def permute_scales(scales: torch.Tensor):
scale_perm, scale_perm_single = get_perms()
out_features = scales.shape[1]
if scales.shape[0] == 1:
scales = scales.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
else:
scales = scales.reshape((-1, len(scale_perm)))[:, scale_perm]
return scales.reshape((-1, out_features)).contiguous()
# Functions below are from vLLM
def get_pack_factor(bits: int) -> int:
if 32 % bits != 0:
raise ValueError(f"Cannot {bits} bit values into uint32")
return 32 // bits
def pack_cols(
q_w: torch.Tensor,
num_bits: int,
size_k: int,
size_n: int,
):
assert q_w.shape == (size_k, size_n)
pack_factor = get_pack_factor(num_bits)
assert size_n % pack_factor == 0
orig_device = q_w.device
q_w = q_w.cpu().numpy().astype(numpy.uint32)
q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
for i in range(pack_factor):
q_res |= q_w[:, i::pack_factor] << num_bits * i
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
q_res = q_res.contiguous()
return q_res
def unpack_cols(
packed_q_w: torch.Tensor,
num_bits: int,
size_k: int,
size_n: int,
):
pack_factor = get_pack_factor(num_bits)
assert size_n % pack_factor == 0
assert packed_q_w.shape == (
size_k,
size_n // pack_factor,
), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
packed_q_w.shape, size_k, size_n, pack_factor
)
orig_device = packed_q_w.device
packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
mask = (1 << num_bits) - 1
for i in range(pack_factor):
vals = packed_q_w_cpu & mask
packed_q_w_cpu >>= num_bits
q_res[:, i::pack_factor] = vals
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
q_res = q_res.contiguous()
return q_res
def marlin_zero_points(
zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
scale_perm, _ = get_perms()
# Permute zero-points in a similar way to scales, but do not use the
# "single" permutation, since zero-points are applied on every MMA
zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]
# Interleave column dim (for the dequantize code) and pack it to int32
if num_bits == 4:
interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
elif num_bits == 8:
interleave = numpy.array([0, 2, 1, 3])
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
zp = zp.reshape((-1, size_n)).contiguous()
zp = pack_cols(zp, num_bits, size_k, size_n)
return zp

View File

@ -10,8 +10,13 @@ from text_generation_server.layers import (
TensorParallelRowLinear,
)
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
from text_generation_server.layers.marlin import GPTQMarlinWeightsLoader
from text_generation_server.layers.moe.gptq_marlin import (
GPTQMarlinSparseMoELayer,
can_use_marlin_moe_gemm,
)
from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
from text_generation_server.layers.moe.fp8 import FP8SparseMoELayer
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
@ -19,7 +24,12 @@ from text_generation_server.utils.weights import (
UnquantizedWeight,
)
from .fused_moe import fused_topk, grouped_topk
if SYSTEM == "rocm":
from .fused_moe_rocm import grouped_topk
from vllm.model_executor.layers.fused_moe import fused_topk
elif SYSTEM != "ipex":
from moe_kernels.fused_moe import fused_topk, grouped_topk
# NOTE: we are using a protocol here, because multiple inherance is not nice.
# We need `Module`, and `Module` -> some abstract class -> some concrete
@ -42,8 +52,6 @@ class MoELayer(Protocol):
up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj",
hidden_act: str = "silu",
scoring_func: Optional[str] = None,
e_score_correction_bias: Optional[float] = None,
): ...
def forward(
@ -73,14 +81,9 @@ class DenseMoELayer(nn.Module):
up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj",
hidden_act: str = "silu",
scoring_func: Optional[str] = None,
e_score_correction_bias: Optional[float] = None,
):
super().__init__()
assert scoring_func is None, "scoring func is not handled"
assert e_score_correction_bias is None, "scoring correction bias is not handled"
log_once(
logger.info,
"No fused layers are available for this model type, using (slower) dense MoE layer",
@ -196,27 +199,22 @@ class SparseMoELayer(nn.Module):
topk: int,
topk_group: Optional[int],
weights: Weights,
scoring_func: Optional[str] = "softmax",
e_score_correction_bias: Optional[float] = None,
gate_proj_name: str = "gate_proj",
up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj",
):
super().__init__()
if (
isinstance(weights.loader, DefaultWeightsLoader)
and isinstance(weights.loader.weight_class, UnquantizedWeight)
) or isinstance(weights.loader, HybridFP8UnquantLoader):
if (
isinstance(weights.loader, HybridFP8UnquantLoader)
and weights.loader.to_fp8
):
cls = FP8SparseMoELayer
else:
cls = UnquantizedSparseMoELayer
cls = UnquantizedSparseMoELayer
elif isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym:
cls = GPTQMarlinSparseMoELayer
else:
raise ValueError(
f"Unsupported weights loader: {type(weights.loader)}, sparse MoE is only supported for unquantized, AWQ, and GPTQ weights"
f"Unsupported weights loader: {weights.loader}, sparse MoE is only supported for unquantized and GPTQ weights"
)
log_once(
@ -232,8 +230,6 @@ class SparseMoELayer(nn.Module):
topk=topk,
topk_group=topk_group,
weights=weights,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
gate_proj_name=gate_proj_name,
up_proj_name=up_proj_name,
down_proj_name=down_proj_name,
@ -245,6 +241,17 @@ class SparseMoELayer(nn.Module):
@staticmethod
def is_supported(weights: Weights) -> bool:
return (
isinstance(weights.loader, DefaultWeightsLoader)
and isinstance(weights.loader.weight_class, UnquantizedWeight)
) or isinstance(weights.loader, HybridFP8UnquantLoader)
(
isinstance(weights.loader, DefaultWeightsLoader)
and isinstance(weights.loader.weight_class, UnquantizedWeight)
)
or isinstance(weights.loader, HybridFP8UnquantLoader)
or (
isinstance(weights.loader, GPTQMarlinWeightsLoader)
and can_use_marlin_moe_gemm(
quant_method=weights.loader.quant_method,
quantize=weights.loader.quantize,
sym=weights.loader.sym,
)
)
)

View File

@ -1,173 +0,0 @@
from typing import Optional
import torch
import torch.nn as nn
from text_generation_server.utils.weights import Weights
from text_generation_server.layers.fp8 import (
Fp8Weight,
fp8_quantize,
quant_dtype,
normalize_e4m3fn_to_native_float8,
)
try:
from .unquantized import fused_moe
except Exception:
fused_moe = None
class FP8SparseMoELayer(nn.Module):
def __init__(
self,
*,
n_expert_group: Optional[int],
n_experts: int,
prefix: str,
renormalize: bool,
topk: int,
topk_group: Optional[int],
weights: Weights,
scoring_func: Optional[str] = "softmax",
e_score_correction_bias: Optional[float] = None,
gate_proj_name: str = "gate_proj",
up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj",
):
super().__init__()
assert (n_expert_group is None) == (
topk_group is None
), "n_expert_group and topk_group must both be None or have some value"
self.n_expert_group = n_expert_group
self.topk = topk
self.topk_group = topk_group
self.renormalize = renormalize
self.weight_block_size = weights.weights_loader.weight_block_size
self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias
(
self.gate_up_proj,
self.gate_up_proj_weight_scale,
self.gate_up_proj_input_scale,
) = _load_expert_multi_weights_col(
prefix=prefix,
n_experts=n_experts,
gate_proj_name=gate_proj_name,
up_proj_name=up_proj_name,
weights=weights,
)
self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = (
_load_expert_weights_row(
prefix=prefix,
n_experts=n_experts,
name=down_proj_name,
weights=weights,
)
)
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
return fused_moe(
x,
w1=self.gate_up_proj,
w2=self.down_proj,
gating_output=gating_output,
topk=self.topk,
renormalize=self.renormalize,
inplace=True,
use_grouped_topk=self.n_expert_group is not None,
num_expert_group=self.n_expert_group,
topk_group=self.topk_group,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
use_fp8_w8a8=True,
w1_scale=self.gate_up_proj_weight_scale,
w2_scale=self.down_proj_weight_scale,
a1_scale=self.gate_up_proj_input_scale,
a2_scale=self.down_proj_input_scale,
)
def _load_expert_weights(
get_weight_fn,
*,
prefix: str,
n_experts: int,
name: str,
weights: Weights,
) -> torch.Tensor:
all_weight = None
all_weight_scales = None
max_input_scale = None
for i in range(n_experts):
weight = get_weight_fn(prefix, i, name, weights)
assert isinstance(weight, Fp8Weight)
if all_weight is None:
all_weight = torch.empty(
(n_experts,) + weight.weight.shape,
dtype=quant_dtype,
device=weight.weight.device,
)
if all_weight_scales is None:
all_weight_scales = torch.empty(
(n_experts,) + weight.weight_scale.shape,
dtype=torch.float32,
device=weight.weight.device,
)
if weight.weight.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz}:
all_weight[i], all_weight_scales[i], current_input_scale = (
normalize_e4m3fn_to_native_float8(
weight.weight, weight.weight_scale, weight.input_scale
)
)
if current_input_scale is not None:
if max_input_scale is None or current_input_scale > max_input_scale:
max_input_scale = current_input_scale
else:
all_weight[i], all_weight_scales[i] = fp8_quantize(
weight.weight, scalar=True
)
assert all_weight is not None
return all_weight, all_weight_scales, max_input_scale
def _load_expert_multi_weights_col(
*,
prefix: str,
n_experts: int,
gate_proj_name: str,
up_proj_name: str,
weights: Weights,
) -> torch.Tensor:
def get_weight_fn(prefix, i, name, weights):
return weights.get_multi_weights_col(
[f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
)
return _load_expert_weights(
get_weight_fn, prefix=prefix, n_experts=n_experts, name=None, weights=weights
)
def _load_expert_weights_row(
*,
prefix: str,
n_experts: int,
name: str,
weights: Weights,
) -> torch.Tensor:
def get_weight_fn(prefix, i, name, weights):
return weights.get_weights_row(f"{prefix}.{i}.{name}")
return _load_expert_weights(
get_weight_fn, prefix=prefix, n_experts=n_experts, name=name, weights=weights
)

View File

@ -16,8 +16,10 @@
from typing import Tuple
import torch
import torch.distributed
# TODO: Remove the functions once moe_kernel are built for ROCM
def grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
@ -48,18 +50,3 @@ def grouped_topk(
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
def fused_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
topk_weights = torch.nn.functional.softmax(
gating_output, dim=1, dtype=torch.float32
)
topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
if renormalize:
topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids

View File

@ -0,0 +1,215 @@
from dataclasses import dataclass
from typing import List, Optional
import torch
import torch.nn as nn
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import Weights
from text_generation_server.layers.marlin.gptq import (
GPTQMarlinWeight,
GPTQMarlinWeightsLoader,
)
if SYSTEM == "cuda":
from moe_kernels.fused_marlin_moe import fused_marlin_moe
else:
fused_marlin_moe = None
try:
major, _minor = torch.cuda.get_device_capability()
has_sm_8_0 = major >= 8
except Exception:
has_sm_8_0 = False
def can_use_marlin_moe_gemm(
*,
quant_method: str,
quantize: str,
sym: bool,
):
return (
SYSTEM == "cuda"
and fused_marlin_moe is not None
and has_sm_8_0
and quantize == "gptq"
and quant_method == "gptq"
and sym
)
@dataclass
class GPTQMarlinMoEWeight:
qweight: torch.Tensor
qzeros: torch.Tensor
scales: torch.Tensor
g_idx: torch.Tensor
perm: torch.Tensor
is_full_k: bool
class GPTQMarlinSparseMoELayer(nn.Module):
"""
MoE layer that uses a fused GPTQ-Marlin kernel.
"""
def __init__(
self,
*,
n_expert_group: Optional[int],
n_experts: int,
prefix: str,
renormalize: bool,
topk: int,
topk_group: Optional[int],
weights: Weights,
gate_proj_name: str = "gate_proj",
up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj",
):
super().__init__()
if not (
isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym
):
raise ValueError(
f"Unsupported weights loader: {weights.loader}, only GPTQMarlinWeightsLoader with symmetric quantization is supported"
)
assert (n_expert_group is None) == (
topk_group is None
), "n_expert_group and topk_group must both be None or have some value"
self.n_expert_group = n_expert_group
self.topk = topk
self.topk_group = topk_group
self.renormalize = renormalize
self.gate_up_proj = _load_expert_multi_weights_col(
prefix=prefix,
n_experts=n_experts,
names=[gate_proj_name, up_proj_name],
weights=weights,
)
self.down_proj = _load_expert_weights_row(
prefix=prefix, n_experts=n_experts, name=down_proj_name, weights=weights
)
self.bits = weights.loader.bits
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
return fused_marlin_moe(
x,
w1=self.gate_up_proj.qweight,
w2=self.down_proj.qweight,
g_idx1=self.gate_up_proj.g_idx,
g_idx2=self.down_proj.g_idx,
perm1=self.gate_up_proj.perm,
perm2=self.down_proj.perm,
w1_scale=self.gate_up_proj.scales,
w2_scale=self.down_proj.scales,
is_full_k1=self.gate_up_proj.is_full_k,
is_full_k2=self.down_proj.is_full_k,
gating_output=gating_output,
topk=self.topk,
renormalize=self.renormalize,
use_grouped_topk=self.n_expert_group is not None,
num_expert_group=self.n_expert_group,
topk_group=self.topk_group,
num_bits=self.bits,
)
def _load_expert_multi_weights_col(
*,
prefix: str,
n_experts: int,
names: List[str],
weights: Weights,
) -> GPTQMarlinMoEWeight:
moe_weight = None
for i in range(n_experts):
weight = weights.get_multi_weights_col(
[f"{prefix}.{i}.{name}" for name in names], 0
)
assert isinstance(weight, GPTQMarlinWeight)
moe_weight = _pack_weight(
n_experts=n_experts, expert=i, weight=weight, moe_weight=moe_weight
)
assert moe_weight is not None
return moe_weight
def _load_expert_weights_row(
*,
prefix: str,
n_experts: int,
name: str,
weights: Weights,
) -> GPTQMarlinMoEWeight:
moe_weight = None
for i in range(n_experts):
weight = weights.get_weights_row(
f"{prefix}.{i}.{name}",
)
assert isinstance(weight, GPTQMarlinWeight)
moe_weight = _pack_weight(
n_experts=n_experts, expert=i, weight=weight, moe_weight=moe_weight
)
assert moe_weight is not None
return moe_weight
def _pack_weight(
*,
n_experts: int,
expert: int,
moe_weight: Optional[GPTQMarlinMoEWeight],
weight: GPTQMarlinWeight,
) -> GPTQMarlinMoEWeight:
if moe_weight is None:
qweight = torch.empty(
(n_experts,) + weight.qweight.shape,
dtype=weight.qweight.dtype,
device=weight.qweight.device,
)
qzeros = torch.empty(
(n_experts,) + weight.qzeros.shape,
dtype=weight.qzeros.dtype,
device=weight.qzeros.device,
)
scales = torch.empty(
(n_experts,) + weight.scales.shape,
dtype=weight.scales.dtype,
device=weight.scales.device,
)
g_idx = torch.empty(
(n_experts,) + weight.g_idx.shape,
dtype=weight.g_idx.dtype,
device=weight.g_idx.device,
)
perm = torch.empty(
(n_experts,) + weight.perm.shape,
dtype=weight.perm.dtype,
device=weight.perm.device,
)
moe_weight = GPTQMarlinMoEWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
perm=perm,
is_full_k=weight.is_full_k,
)
moe_weight.qweight[expert] = weight.qweight
moe_weight.qzeros[expert] = weight.qzeros
moe_weight.scales[expert] = weight.scales
moe_weight.g_idx[expert] = weight.g_idx
moe_weight.perm[expert] = weight.perm
return moe_weight

View File

@ -3,8 +3,13 @@ from typing import Optional
import torch
import torch.nn as nn
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import UnquantizedWeight, Weights
from vllm_hpu_extension.ops import DynamicFusedMOE
if SYSTEM == "rocm":
from vllm.model_executor.layers.fused_moe import fused_moe
elif SYSTEM != "ipex":
from moe_kernels.fused_moe import fused_moe
class UnquantizedSparseMoELayer(nn.Module):
@ -18,8 +23,6 @@ class UnquantizedSparseMoELayer(nn.Module):
topk: int,
topk_group: Optional[int],
weights: Weights,
scoring_func: Optional[str] = "softmax",
e_score_correction_bias: Optional[float] = None,
gate_proj_name: str = "gate_proj",
up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj",
@ -34,9 +37,6 @@ class UnquantizedSparseMoELayer(nn.Module):
self.topk = topk
self.topk_group = topk_group
self.renormalize = renormalize
self.weight_block_size = weights.weights_loader.weight_block_size
self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias
self.gate_up_proj = _load_expert_multi_weights_col(
prefix=prefix,
@ -53,13 +53,30 @@ class UnquantizedSparseMoELayer(nn.Module):
weights=weights,
)
self.hpu_fused_moe = DynamicFusedMOE(n_experts)
for i in range(n_experts):
self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.down_proj[i])
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
return self.hpu_fused_moe(x, gating_output, self.topk)
if SYSTEM == "rocm":
return fused_moe(
x,
self.gate_up_proj,
self.down_proj,
gating_output,
self.topk,
renormalize=self.renormalize,
inplace=True,
)
return fused_moe(
x,
w1=self.gate_up_proj,
w2=self.down_proj,
gating_output=gating_output,
topk=self.topk,
renormalize=self.renormalize,
inplace=True,
use_grouped_topk=self.n_expert_group is not None,
num_expert_group=self.n_expert_group,
topk_group=self.topk_group,
)
def _load_expert_multi_weights_col(

View File

@ -2,10 +2,14 @@ import os
import math
import torch
from torch import nn
from habana_frameworks.torch.hpex.kernels import (
RotaryPosEmbeddingMode,
apply_rotary_pos_emb,
)
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "cuda":
import rotary_emb
elif SYSTEM == "rocm":
from vllm._C import ops
elif SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex
def _create_inv_freq(dim, base, device):
@ -26,7 +30,7 @@ def _get_rope_config(config):
class PositionRotaryEmbedding(nn.Module):
def __init__(self, inv_freq, scaling_factor, max_position_embeddings):
def __init__(self, inv_freq, scaling_factor):
super().__init__()
self.inv_freq = inv_freq
self._seq_len_cached = 0
@ -36,9 +40,6 @@ class PositionRotaryEmbedding(nn.Module):
self._sin_k_cached = None
self.scaling_factor = scaling_factor
self.dynamic_args = None
self._update_cos_sin_cache(
torch.float32, inv_freq.device, max_position_embeddings
)
def forward(
self,
@ -47,41 +48,40 @@ class PositionRotaryEmbedding(nn.Module):
cos: torch.Tensor,
sin: torch.Tensor,
):
num_tokens = query.shape[0]
head_size = query.shape[-1]
# HPU RoPE kernel requires hidden dimension for cos and sin to be equal
# to query hidden dimension, so the original tensors need to be
# expanded
# GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
# and expansion of cos/sin tensors via concatenation
rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
cos = torch.cat((cos, cos), dim=-1)
sin = torch.cat((sin, sin), dim=-1)
rotary_dim = cos.shape[-1]
query_shape = query.shape
query = query.view(num_tokens, -1, head_size)
query_rot = query[..., :rotary_dim]
query_pass = query[..., rotary_dim:]
query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape))
# Such controlflows may add some overhead.
if SYSTEM == "cuda":
rotary_dim = cos.shape[-1]
q1 = query[..., :rotary_dim]
q2 = query[..., rotary_dim : 2 * rotary_dim]
key_shape = key.shape
key = key.view(num_tokens, -1, head_size)
key_rot = key[..., :rotary_dim]
key_pass = key[..., rotary_dim:]
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape))
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
k1 = key[..., :rotary_dim]
k2 = key[..., rotary_dim : 2 * rotary_dim]
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
elif SYSTEM == "rocm":
# NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
# Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773
head_size = query.shape[-1]
# Inplace operation, updating query and key.
ops.rotary_embedding(query, key, head_size, cos, sin, True)
elif SYSTEM == "ipex":
ipex.llm.functional.rotary_embedding(
query, key, sin, cos, query.size(-1), True
)
else:
raise ValueError(
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
)
@classmethod
def static(cls, config, dim, base, device):
inv_freq = _create_inv_freq(dim, base, device)
scaling_factor = None
rope_scaling = _get_rope_config(config)
if not hasattr(config, "max_position_embeddings") and hasattr(
config, "max_seq_len"
):
# handling for dbrx
config.max_position_embeddings = config.max_seq_len
if rope_scaling is not None:
# `rope_type` is now standard in transformers, but some existing models
# have `type` instead.
@ -89,17 +89,6 @@ class PositionRotaryEmbedding(nn.Module):
if rope_type == "linear":
pass
elif rope_type == "default":
pass
elif rope_type == "mrope":
mrope_section = rope_scaling["mrope_section"]
if mrope_section is not None:
return RotaryPositionEmbeddingMultimodalSections(
inv_freq,
scaling_factor,
mrope_section,
config.max_position_embeddings,
)
elif rope_type == "dynamic":
scaling_factor = rope_scaling["factor"]
return DynamicPositionRotaryEmbedding(
@ -120,7 +109,7 @@ class PositionRotaryEmbedding(nn.Module):
],
)
return cls(inv_freq, scaling_factor, config.max_position_embeddings)
return cls(inv_freq, scaling_factor)
elif rope_type == "yarn":
scaling_factor = rope_scaling["factor"]
@ -196,13 +185,12 @@ class PositionRotaryEmbedding(nn.Module):
long_inv_freq=long_inv_freq,
scaling_factor=scaling_factor,
original_max_position_embeddings=original_max_position_embeddings,
max_position_embeddings=config.max_position_embeddings,
)
else:
raise NotImplementedError(
f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
)
return cls(inv_freq, scaling_factor, config.max_position_embeddings)
return cls(inv_freq, scaling_factor)
@classmethod
def load(cls, config, prefix, weights):
@ -248,7 +236,7 @@ class PositionRotaryEmbedding(nn.Module):
raise NotImplementedError(
f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
)
return cls(inv_freq, scaling_factor, config.max_position_embeddings)
return cls(inv_freq, scaling_factor)
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
@ -269,7 +257,17 @@ class PositionRotaryEmbedding(nn.Module):
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
def get_cos_sin(self, position_ids: torch.Tensor):
def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype):
"""
Return cos and sin for the asked position ids
"""
if SYSTEM == "rocm":
# For RoCm, we always use float cos/sin to avoid a cast.
# For NVIDIA, for some reason, the flash-attn rotary kernel requires cos/sin and query/key to be of same dtype: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary.cpp#L26
# But later on goes and cast cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal.
dtype = torch.float32
self._update_cos_sin_cache(dtype, position_ids.device, max_s)
cos = torch.index_select(self._cos_cached, 0, position_ids)
sin = torch.index_select(self._sin_cached, 0, position_ids)
@ -285,7 +283,6 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
long_inv_freq,
scaling_factor,
original_max_position_embeddings,
max_position_embeddings,
):
super(PositionRotaryEmbedding, self).__init__()
self.short_inv_freq = short_inv_freq
@ -298,9 +295,6 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
self._cos_k_cached = None
self._sin_k_cached = None
self.dynamic_args = None
self._update_cos_sin_cache(
torch.float32, short_inv_freq.device, max_position_embeddings
)
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
@ -354,9 +348,6 @@ class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
self._cos_k_cached = None
self._sin_k_cached = None
self.dynamic_args = None
self._update_cos_sin_cache(
torch.float32, short_inv_freq.device, max_position_embeddings
)
def _update_cos_sin_cache(self, dtype, device, seqlen):
if (
@ -392,7 +383,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
inv_freq = _create_inv_freq(dim, base, device)
super().__init__(inv_freq, scaling_factor, max_position_embeddings)
super().__init__(inv_freq, scaling_factor)
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
@ -470,9 +461,7 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
mscale_all_dim: float,
):
inv_freq = _create_inv_freq(dim, base, device)
super().__init__(
inv_freq, scaling_factor, max_position_embeddings * self.scaling_factor
)
super().__init__(inv_freq, scaling_factor)
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
@ -557,50 +546,3 @@ def apply_llama3_scaling(
new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq)
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
def __init__(
self,
inv_freq: torch.Tensor,
scaling_factor: float,
sections: list,
max_position_embeddings,
):
self.sections = sections
self._cos_cached = None
self._sin_cached = None
self.section_indices = (
torch.arange(len(self.sections))
.repeat_interleave(torch.tensor(self.sections))
.view(1, 1, -1)
.to(inv_freq.device)
)
super().__init__(inv_freq, scaling_factor, max_position_embeddings)
def _update_cos_sin_cache(
self, dtype: torch.dtype, device: torch.device, seqlen: int
):
# always cache the cos/sin for the full sequence length to avoid
# recomputing if the sequence length is smaller than the cached one
if (
seqlen > self._seq_len_cached
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
self._sections = self.section_indices.expand(seqlen, -1, -1)
def get_cos_sin(
self,
position_ids: torch.Tensor,
):
slen = position_ids.shape[0]
cos = self._cos_cached[position_ids].gather(1, self._sections[:slen])
sin = self._sin_cached[position_ids].gather(1, self._sections[:slen])
return cos, sin

View File

@ -2,8 +2,10 @@ import torch
from torch.nn import functional as F
from typing import Iterable, List
from text_generation_server.layers.linear import get_linear, FastLinear
from text_generation_server.utils.import_utils import SYSTEM
import habana_frameworks.torch as htorch
if SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex
class LayerConcat(torch.nn.Module):
@ -88,10 +90,14 @@ class TensorParallelHead(SuperLayer):
local_out = gather_input.T
torch.mm(input, self.linear.weight.T, out=local_out)
htorch.core.mark_step()
torch.distributed.all_gather_into_tensor(
world_out, gather_input, group=self.process_group
)
if SYSTEM == "ipex":
ipex.distributed.all_gather_into_tensor(
world_out, gather_input, group=self.process_group
)
else:
torch.distributed.all_gather_into_tensor(
world_out, gather_input, group=self.process_group
)
if input.shape[0] == 1:
return world_out
@ -101,9 +107,10 @@ class TensorParallelHead(SuperLayer):
world_output = [
torch.empty_like(output) for _ in range(self.process_group.size())
]
htorch.core.mark_step()
torch.distributed.all_gather(world_output, output, group=self.process_group)
if SYSTEM == "ipex":
ipex.distributed.all_gather(world_output, output, group=self.process_group)
else:
torch.distributed.all_gather(world_output, output, group=self.process_group)
world_output = torch.cat(world_output, dim=-1)
return world_output
@ -195,11 +202,10 @@ class TensorParallelRowLinear(SuperLayer):
def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
out = super().forward(input)
if self.process_group.size() > 1 and reduce:
# FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
# occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
# (which is required for tensor parallel HPUGraph inference)
htorch.core.mark_step()
torch.distributed.all_reduce(out, group=self.process_group)
if SYSTEM == "ipex":
ipex.distributed.all_reduce(out, group=self.process_group)
else:
torch.distributed.all_reduce(out, group=self.process_group)
return out
@ -236,9 +242,8 @@ class TensorParallelEmbedding(torch.nn.Module):
)
out = torch.nn.functional.embedding(input, self.weight)
if self.reduce and self.process_group.size() > 1:
# FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
# occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
# (which is required for tensor parallel HPUGraph inference)
htorch.core.mark_step()
torch.distributed.all_reduce(out, group=self.process_group)
if SYSTEM == "ipex":
ipex.distributed.all_reduce(out, group=self.process_group)
else:
torch.distributed.all_reduce(out, group=self.process_group)
return out

View File

@ -1,5 +1,3 @@
# ruff: noqa: F821
# the above line disables the `undefined-name` rule for the model type variables
import torch
import os
@ -10,7 +8,6 @@ from huggingface_hub import hf_hub_download, HfApi
from typing import Optional
from pathlib import Path
from typing import List, Dict
import enum
# Needed to properly setup habana_frameworks
@ -19,10 +16,17 @@ from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.bloom import BLOOM
from text_generation_server.models.starcoder import StarCoder
from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
PhiMoEConfig,
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
# from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
from text_generation_server.models.custom_modeling.llava_next import (
LlavaNextForConditionalGeneration,
)
# from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
# from text_generation_server.models.custom_modeling.mllama import (
# MllamaForConditionalGeneration,
# )
from text_generation_server.utils.adapter import (
AdapterParameters,
build_layer_weight_lookup,
@ -31,287 +35,10 @@ from text_generation_server.utils.adapter import (
)
from text_generation_server.adapters.lora import LoraWeights
from text_generation_server.utils.log import log_master
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
__all__ = [
"Model",
"CausalLM",
"Seq2SeqLM",
"get_model_with_lora_adapters",
]
from text_generation_server.models.globals import ATTENTION
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
FLASH_ATTENTION = False
if ATTENTION == "paged":
FLASH_ATTENTION = True
try:
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.flash_vlm_causal_lm import FlashVlmCausalLM
from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLM
from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
FlashDeepseekV2ForCausalLM,
DeepseekV2Config,
)
from text_generation_server.models.custom_modeling.flash_deepseek_v3_modeling import (
FlashDeepseekV3ForCausalLM,
DeepseekV3Config,
)
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
FlashCohereForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
FlashGemmaForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
FlashGemma2ForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
FlashDbrxForCausalLM,
DbrxConfig,
)
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
RWConfig,
FlashRWForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
FlashGPTNeoXForCausalLM,
)
from text_generation_server.models.pali_gemma import (
PaliGemmaBatch,
)
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
PaliGemmaForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
FlashPhiForCausalLM,
)
from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
from text_generation_server.models.custom_modeling.flash_mllama import (
FlashMllamaForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.flash_llava_next import (
FlashLlavaNextForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
FlashSantacoderForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
FlashStarcoder2ForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
Qwen2ForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
FlashMixtralForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
FlashGPT2ForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_gptj_modeling import (
FlashGPTJForCausalLM,
)
from text_generation_server.models.custom_modeling.idefics2 import (
Idefics2ForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.idefics3 import (
Idefics3ForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.qwen2_vl import (
Qwen2VLForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.qwen2_5_vl import (
Qwen2_5VLForConditionalGeneration,
Qwen2_5_VLConfig,
Qwen2_5_VLProcessor,
)
from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
SUPPORTS_WINDOWING = False
FLASH_ATTENTION = False
if FLASH_ATTENTION:
__all__.append(FlashCausalLM)
class ModelType(enum.Enum):
DEEPSEEK_V2 = {
"type": "deepseek_v2",
"name": "Deepseek V2",
"url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
}
DEEPSEEK_V3 = {
"type": "deepseek_v3",
"name": "Deepseek V3",
"url": "https://huggingface.co/deepseek-ai/DeepSeek-V3",
}
IDEFICS2 = {
"type": "idefics2",
"name": "Idefics 2",
"url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
"multimodal": True,
}
IDEFICS3 = {
"type": "idefics3",
"name": "Idefics 3",
"url": "https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3",
"multimodal": True,
}
LLAVA_NEXT = {
"type": "llava_next",
"name": "Llava Next (1.6)",
"url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
"multimodal": True,
}
LLAMA = {
"type": "llama",
"name": "Llama",
"url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
}
PHI3 = {
"type": "phi3",
"name": "Phi 3",
"url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
}
GRANITE = {
"type": "granite",
"name": "Granite",
"url": "https://huggingface.co/ibm-granite/granite-3.0-8b-instruct",
}
GEMMA = {
"type": "gemma",
"name": "Gemma",
"url": "https://huggingface.co/google/gemma-7b",
}
PALIGEMMA = {
"type": "paligemma",
"name": "PaliGemma",
"url": "https://huggingface.co/google/paligemma-3b-pt-224",
}
GEMMA2 = {
"type": "gemma2",
"name": "Gemma2",
"url": "https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315",
}
COHERE = {
"type": "cohere",
"name": "Cohere",
"url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
}
DBRX = {
"type": "dbrx",
"name": "Dbrx",
"url": "https://huggingface.co/databricks/dbrx-instruct",
}
MAMBA = {
"type": "mamba",
"name": "Mamba",
"url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
}
MISTRAL = {
"type": "mistral",
"name": "Mistral",
"url": "https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407",
}
MIXTRAL = {
"type": "mixtral",
"name": "Mixtral",
"url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
}
GPT_BIGCODE = {
"type": "gpt_bigcode",
"name": "Gpt Bigcode",
"url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
}
PHI = {
"type": "phi",
"name": "Phi",
"url": "https://huggingface.co/microsoft/phi-1_5",
}
PHI_MOE = {
"type": "phimoe",
"name": "PhiMoe",
"url": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct",
}
BAICHUAN = {
"type": "baichuan",
"name": "Baichuan",
"url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
}
FALCON = {
"type": "falcon",
"name": "Falcon",
"url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
}
STARCODER2 = {
"type": "starcoder2",
"name": "StarCoder 2",
"url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
}
QWEN2 = {
"type": "qwen2",
"name": "Qwen 2",
"url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
}
QWEN2_VL = {
"type": "qwen2_vl",
"name": "Qwen 2 VL",
"url": "https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d",
}
QWEN2_5_VL = {
"type": "qwen2_5_vl",
"name": "Qwen 2.5 VL",
"url": "https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e",
}
GALACTICA = {
"type": "galactica",
"name": "Galactica",
"url": "https://huggingface.co/facebook/galactica-120b",
}
SANTACODER = {
"type": "santacoder",
"name": "SantaCoder",
"url": "https://huggingface.co/bigcode/santacoder",
}
GPT2 = {
"type": "gpt2",
"name": "Gpt2",
"url": "https://huggingface.co/openai-community/gpt2",
}
GPT_NEOX = {
"type": "gpt_neox",
"name": "Gpt Neox",
"url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
}
GPTJ = {
"type": "gptj",
"name": "Gptj",
"url": "https://huggingface.co/EleutherAI/gpt-j-6b",
}
MLLAMA = {
"type": "mllama",
"name": "Mllama",
"url": "https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct",
"multimodal": True,
}
__GLOBALS = locals()
for data in ModelType:
__GLOBALS[data.name] = data.value["type"]
SDP_ON_BF16 = int(os.environ.get("SDP_ON_BF16", 0))
# Disable gradients
torch.set_grad_enabled(False)
@ -327,7 +54,7 @@ def get_model(
trust_remote_code: bool,
max_input_tokens: int,
) -> Model:
global FLASH_ATTENTION
adapt_transformers_to_gaudi()
if speculate is not None:
set_speculate(speculate)
@ -449,393 +176,9 @@ def get_model(
model_type = config_dict["model_type"]
kv_cache_dtype = dtype
if FLASH_ATTENTION:
if model_type == DEEPSEEK_V2:
head_size = max(
config_dict.get("qk_nope_dim", 128)
+ config_dict.get("qk_rope_dim", 64),
config_dict.get("v_head_dim", 128),
)
return FlashCausalLM(
model_id=model_id,
model_class=FlashDeepseekV2ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
default_dtype=torch.bfloat16,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=DeepseekV2Config,
head_size=head_size,
)
elif model_type == DEEPSEEK_V3:
head_size = max(
config_dict.get("qk_nope_dim", 128)
+ config_dict.get("qk_rope_dim", 64),
config_dict.get("v_head_dim", 128),
)
return FlashCausalLM(
model_id=model_id,
model_class=FlashDeepseekV3ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
default_dtype=torch.bfloat16,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=DeepseekV3Config,
head_size=head_size,
)
elif (
model_type == GPT_BIGCODE
or model_type == GPT2
and model_id.startswith("bigcode/")
):
return FlashCausalLM(
model_id=model_id,
model_class=FlashSantacoderForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
aliases={"transformer.wte.weight": ["lm_head.weight"]},
num_kv_heads=1,
)
elif model_type == GPT2:
return FlashCausalLM(
model_id=model_id,
model_class=FlashGPT2ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == GPTJ:
return FlashCausalLM(
model_id=model_id,
model_class=FlashGPTJForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == GPT_NEOX:
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
GPTNeoXConfig,
)
return FlashCausalLM(
model_id=model_id,
model_class=FlashGPTNeoXForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=GPTNeoXConfig,
)
elif model_type == PHI:
return FlashCausalLM(
model_id=model_id,
model_class=FlashPhiForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == PHI_MOE:
return FlashCausalLM(
model_id=model_id,
model_class=FlashLlamaForCausalLM,
config_class=PhiMoEConfig,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == LLAMA or model_type == PHI3 or model_type == GRANITE:
return FlashCausalLM(
model_id=model_id,
model_class=FlashLlamaForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == BAICHUAN:
return FlashCausalLM(
model_id=model_id,
model_class=FlashLlamaForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == GEMMA:
return FlashCausalLM(
model_id=model_id,
model_class=FlashGemmaForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == GEMMA2:
return FlashCausalLM(
model_id=model_id,
model_class=FlashGemma2ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == COHERE:
return FlashCausalLM(
model_id=model_id,
model_class=FlashCohereForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == DBRX:
return FlashCausalLM(
model_id=model_id,
model_class=FlashDbrxForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# Dbrx works better in bfloat16.
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=DbrxConfig,
)
elif (
model_type in ["RefinedWeb", "RefinedWebModel", FALCON]
and not sharded
and not config_dict.get("alibi", False)
):
return FlashCausalLM(
model_id=model_id,
model_class=FlashRWForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
aliases={
"lm_head.weight": ["transformer.word_embeddings.weight"],
"transformer.word_embeddings.weight": ["lm_head.weight"],
},
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=RWConfig,
)
elif model_type == MISTRAL:
return FlashCausalLM(
model_id=model_id,
model_class=FlashMistralForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == MIXTRAL:
return FlashCausalLM(
model_id=model_id,
model_class=FlashMixtralForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == STARCODER2:
return FlashCausalLM(
model_id=model_id,
model_class=FlashStarcoder2ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == QWEN2:
return FlashCausalLM(
model_id=model_id,
model_class=Qwen2ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == QWEN2_VL:
return FlashVlmCausalLM(
model_id=model_id,
model_class=Qwen2VLForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
default_dtype=torch.bfloat16,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == QWEN2_5_VL:
return FlashVlmCausalLM(
model_id=model_id,
model_class=Qwen2_5VLForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
default_dtype=torch.bfloat16,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=Qwen2_5_VLConfig,
processor_class=Qwen2_5_VLProcessor,
)
elif model_type == MLLAMA:
return FlashMllamaCausalLM(
model_id=model_id,
model_class=FlashMllamaForConditionalGeneration,
batch_class=FlashMllamaCausalLMBatch,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == IDEFICS2:
return FlashVlmCausalLM(
model_id=model_id,
model_class=Idefics2ForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
# XXX: Extremely important to cap resolution in order to limit
# VRAM usage.
processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
)
elif model_type == IDEFICS3:
return FlashVlmCausalLM(
model_id=model_id,
model_class=Idefics3ForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
# XXX: Extremely important to cap resolution in order to limit
# VRAM usage.
processor_kwargs={"size": {"longest_edge": 1456}},
)
elif model_type == PALIGEMMA:
return FlashVlmCausalLM(
model_id=model_id,
model_class=PaliGemmaForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
batch_class=PaliGemmaBatch,
)
elif model_type == LLAVA_NEXT:
return FlashVlmCausalLM(
model_class=FlashLlavaNextForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
)
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
from text_generation_server.models.custom_modeling.mllama import (
MllamaForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.llava_next import (
LlavaNextForConditionalGeneration,
)
adapt_transformers_to_gaudi()
if SDP_ON_BF16 == 1:
torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
if model_type == "gpt_bigcode":
return StarCoder(model_id=model_id, revision=revision, dtype=dtype)
if model_type == "bloom":
return BLOOM(
model_id=model_id,
@ -856,17 +199,6 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_type == "mllama":
return VlmCausalLM(
model_class=MllamaForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=None,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM(
model_id,

View File

@ -57,7 +57,8 @@ MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", 2048))
PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 256))
CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1))
BATCH_SIZE_EXPONENT_BASE = int(os.environ.get("BATCH_SIZE_EXPONENT_BASE", 2))
BATCH_BUCKET_SIZE = int(os.environ.get("BATCH_BUCKET_SIZE", 8))
PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get("PREFILL_BATCH_BUCKET_SIZE", 2))
MAX_BATCH_SIZE = (
int(os.environ.get("MAX_BATCH_SIZE"))
if os.environ.get("MAX_BATCH_SIZE") is not None
@ -73,16 +74,10 @@ def torch_compile_for_eager(func):
)
def round_up_seq(number, k):
def round_up(number, k):
return (number + k - 1) // k * k
def round_up_batch(number):
return BATCH_SIZE_EXPONENT_BASE ** (
math.ceil(math.log(number, BATCH_SIZE_EXPONENT_BASE))
)
def to_tensor_indices(indices, device):
return torch.tensor(indices, dtype=torch.long, device=device)
@ -404,7 +399,7 @@ class CausalLMBatch(Batch):
total_requests = sum(len(b) for b in batches)
new_bs = total_requests
new_bs = round_up_batch(total_requests)
new_bs = round_up(total_requests, BATCH_BUCKET_SIZE)
batch_id = batches[0].batch_id
device = batches[0].input_ids.device
@ -545,7 +540,7 @@ class CausalLMBatch(Batch):
# TODO: by tokenizing all inputs at once we loose information on actual input lengths
# this means that we cannot shift inputs to the left after a long input sequence
# was filtered out
new_bs = round_up_batch(len(requests))
new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE)
missing_inputs = new_bs - len(inputs)
dummy_inputs = ["?"] * missing_inputs
parameters = [r.parameters for r in pb.requests]
@ -577,7 +572,7 @@ class CausalLMBatch(Batch):
assert (
PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length
), "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
rounded_seq_len = round_up_seq(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
if rounded_seq_len <= max_input_length:
bucket_size = rounded_seq_len - 1
else:
@ -709,9 +704,6 @@ class CausalLM(Model):
htorch.core.hpu_set_env()
if world_size > 1:
os.environ.setdefault(
"DEEPSPEED_USE_HABANA_FRAMEWORKS_DETERMINISTIC_API", "1"
)
model = self.get_deepspeed_model(model_id, dtype, revision)
model = hq_env.prepare_model_for_quantization(model)
else:
@ -1073,10 +1065,10 @@ class CausalLM(Model):
if (
self.enable_hpu_graph
and self.limit_hpu_graph
and round_up_batch(batch.batch_size) != self.prev_bs
and round_up(batch.batch_size, BATCH_BUCKET_SIZE) != self.prev_bs
):
self.model.clear_cache()
self.prev_bs = round_up_batch(batch.batch_size)
self.prev_bs = round_up(batch.batch_size, BATCH_BUCKET_SIZE)
dbg_trace(
scenario,
f"bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}",
@ -1330,14 +1322,15 @@ class CausalLM(Model):
# Warmup prefill batch_size
max_input_tokens = request.max_input_tokens
max_exp = math.ceil(math.log(max_prefill_batch_size, BATCH_SIZE_EXPONENT_BASE))
prefill_batch_size_list = [
BATCH_SIZE_EXPONENT_BASE**exp
for exp in range(
0,
max_exp + 1,
batch
for batch in range(
PREFILL_BATCH_BUCKET_SIZE,
max_prefill_batch_size,
PREFILL_BATCH_BUCKET_SIZE,
)
]
prefill_batch_size_list.append(max_prefill_batch_size)
prefill_seqlen_list = [
seq
for seq in range(
@ -1374,10 +1367,12 @@ class CausalLM(Model):
)
max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
max_exp = math.ceil(math.log(max_decode_batch_size, BATCH_SIZE_EXPONENT_BASE))
max_decode_batch_size = round_up(max_decode_batch_size, BATCH_BUCKET_SIZE)
decode_batch_size_list = [
BATCH_SIZE_EXPONENT_BASE**exp for exp in range(0, max_exp + 1)
i
for i in range(BATCH_BUCKET_SIZE, max_decode_batch_size, BATCH_BUCKET_SIZE)
]
decode_batch_size_list.append(max_decode_batch_size)
decode_batch_size_list.sort(reverse=True)
try:

View File

@ -377,7 +377,7 @@ class BloomAttention(nn.Module):
past_value.view(-1, *past_value.shape[-2:]),
)
if CUSTOM_KERNELS_ENABLED and attention_mask.shape[-1] < 4096:
if CUSTOM_KERNELS_ENABLED:
assert self.training is False, "Only foward pass was implemented"
assert (
attention_mask.shape[-1] < 4096
@ -580,7 +580,7 @@ class BloomPreTrainedModel(PreTrainedModel):
@staticmethod
def _convert_to_bloom_cache(
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]],
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
"""
Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))

View File

@ -28,10 +28,10 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@ -39,6 +39,7 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
@ -46,10 +47,11 @@ from text_generation_server.layers.rotary import (
PositionRotaryEmbedding,
)
from text_generation_server.utils.weights import UnquantizedWeight
from habana_frameworks.torch.hpex.kernels import (
RotaryPosEmbeddingMode,
apply_rotary_pos_emb,
)
if SYSTEM == "cuda":
import dropout_layer_norm
else:
dropout_layer_norm = None
class CohereRotary(PositionRotaryEmbedding):
@ -61,25 +63,38 @@ class CohereRotary(PositionRotaryEmbedding):
sin: torch.Tensor,
):
# Such controlflows may add some overhead.
num_tokens = query.shape[0]
head_size = query.shape[-1]
rope_mode = RotaryPosEmbeddingMode.PAIRWISE
sin = torch.repeat_interleave(sin, 2, dim=-1)
cos = torch.repeat_interleave(cos, 2, dim=-1)
rotary_dim = cos.shape[-1]
query_shape = query.shape
query = query.view(num_tokens, -1, head_size)
query_rot = query[..., :rotary_dim]
query_pass = query[..., rotary_dim:]
query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape))
if SYSTEM == "cuda":
import rotary_emb
key_shape = key.shape
key = key.view(num_tokens, -1, head_size)
key_rot = key[..., :rotary_dim]
key_pass = key[..., rotary_dim:]
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape))
q1 = query[..., ::2]
q2 = query[..., 1::2]
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
k1 = key[..., ::2]
k2 = key[..., 1::2]
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
elif SYSTEM == "rocm":
from vllm._C import ops
# NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
# Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773
head_size = query.shape[-1]
# Inplace operation, updating query and key.
ops.rotary_embedding(query, key, head_size, cos, sin, False)
elif SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex
ipex.llm.functional.rotary_embedding(
query, key, sin, cos, query.size(-1), False
)
else:
raise ValueError(
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
)
class CohereLayerNorm(nn.Module):
@ -92,18 +107,49 @@ class CohereLayerNorm(nn.Module):
self.eps = eps
def forward(self, hidden_states):
hidden_states = hidden_states.reshape(
if hidden_states.shape[-1] > 8192 or SYSTEM != "cuda":
hidden_states = hidden_states.reshape(
-1, self.weight.shape[0], self.weight.shape[1]
)
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
mean = hidden_states.mean(-1, keepdim=True)
hidden_states_minus_mean = hidden_states - mean
variance = hidden_states_minus_mean.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states_minus_mean * torch.rsqrt(variance + self.eps)
hidden_states = self.weight.to(torch.float32) * hidden_states
hidden_states = hidden_states.view(-1, self.weight.shape[1])
return hidden_states.to(input_dtype)
(
hidden_states,
*rest,
) = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
None,
self.ones,
None,
None,
None,
None,
None,
0.0,
self.eps,
1.0,
0,
None,
False,
False,
)
# Required to apply one weight matrix per head
hidden_states = hidden_states.view(
-1, self.weight.shape[0], self.weight.shape[1]
)
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
mean = hidden_states.mean(-1, keepdim=True)
hidden_states_minus_mean = hidden_states - mean
variance = hidden_states_minus_mean.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states_minus_mean * torch.rsqrt(variance + self.eps)
hidden_states = self.weight.to(torch.float32) * hidden_states
hidden_states = self.weight * hidden_states
hidden_states = hidden_states.view(-1, self.weight.shape[1])
return hidden_states.to(input_dtype)
return hidden_states
def load_attention(config, prefix, weights):
@ -183,7 +229,6 @@ class FlashCohereAttention(torch.nn.Module):
)
self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.use_qk_norm = config.use_qk_norm
if self.use_qk_norm:
@ -219,9 +264,10 @@ class FlashCohereAttention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
qkv = self.query_key_value(hidden_states)
query, key, value = qkv.split(
@ -245,35 +291,30 @@ class FlashCohereAttention(torch.nn.Module):
self.rotary_emb(query, key, cos, sin)
kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
# Prefill
if cu_seqlen_prefill is not None:
# sdpa
# flash attention
attn_output = attention(
query=query,
key=key,
value=value,
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else key,
kv_cache[1] if PREFILL_IN_KV_CACHE else value,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
max_s,
)
return self.o_proj(
@ -345,9 +386,10 @@ class FlashCohereLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -358,9 +400,10 @@ class FlashCohereLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
mlp_output = self.mlp(normed_hidden_states)
@ -409,15 +452,18 @@ class FlashCohereModel(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: torch.Tensor,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
@ -429,9 +475,10 @@ class FlashCohereModel(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -469,9 +516,11 @@ class FlashCohereForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -480,9 +529,10 @@ class FlashCohereForCausalLM(torch.nn.Module):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -20,14 +20,17 @@ from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple, Any
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM != "ipex":
from vllm.model_executor.layers.fused_moe import fused_moe
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
HPUPagedAttentionMetadata,
PREFILL_IN_KV_CACHE,
)
from text_generation_server.layers import (
FastLinear,
@ -43,7 +46,6 @@ from text_generation_server.layers.rotary import (
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
from vllm_hpu_extension.ops import DynamicFusedMOE
class DbrxAttentionConfig(PretrainedConfig):
@ -288,7 +290,6 @@ class DbrxAttention(torch.nn.Module):
)
self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load(
config,
@ -308,9 +309,10 @@ class DbrxAttention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
qkv = self.query_key_value(hidden_states)
if self.clip_qkv is not None:
@ -328,35 +330,30 @@ class DbrxAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# Prefill
if cu_seqlen_prefill is not None:
# sdpa
# flash attention
attn_output = attention(
query=query,
key=kv[:, 0],
value=kv[:, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
max_s,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -390,9 +387,10 @@ class DbrxNormAttentionNorm(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
normed_hidden_states, res = self.norm_1(hidden_states, residual)
@ -403,9 +401,10 @@ class DbrxNormAttentionNorm(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
# faster post attention rms norm
@ -483,15 +482,18 @@ class BlockSparseMoE(nn.Module):
self.process_group = weights.process_group
self.hpu_fused_moe = DynamicFusedMOE(self.num_experts)
for i in range(self.num_experts):
self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.wv1[i])
self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.w2[i])
def forward(self, x: torch.Tensor) -> torch.Tensor:
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(x)
out = self.hpu_fused_moe(x, router_logits, self.top_k)
out = fused_moe(
x,
self.wv1,
self.w2,
router_logits,
self.top_k,
renormalize=self.moe_normalize_expert_weights,
inplace=True,
)
# Reduce sum
if self.process_group.size() > 1:
@ -618,9 +620,10 @@ class DbrxLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
# Self Attention
attn_output, attn_res = self.attn(
@ -630,9 +633,10 @@ class DbrxLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
moe_output = self.moe(attn_output)
@ -673,15 +677,18 @@ class DbrxModel(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(position_ids)
cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.layers):
@ -692,9 +699,10 @@ class DbrxModel(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -724,9 +732,11 @@ class FlashDbrxForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -735,9 +745,10 @@ class FlashDbrxForCausalLM(torch.nn.Module):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -33,14 +33,21 @@ from text_generation_server.layers.attention import (
Seqlen,
attention,
paged_attention,
HPUPagedAttentionMetadata,
reshape_and_cache,
)
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import Weights
if SYSTEM == "rocm":
try:
from vllm import _custom_C
except Exception as e:
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
class DeepseekV2Config(PretrainedConfig):
def __init__(
@ -225,8 +232,6 @@ class DeepseekV2Attention(torch.nn.Module):
),
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.kv_a_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps
)
@ -255,10 +260,11 @@ class DeepseekV2Attention(torch.nn.Module):
cos: torch.Tensor,
sin: torch.Tensor,
cu_seqlen_prefill: torch.Tensor,
kv_cache: KVCache,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
):
if self.q_lora_rank is None:
query = self.q_proj(hidden_states)
@ -315,35 +321,30 @@ class DeepseekV2Attention(torch.nn.Module):
value, (0, self.head_pad_size - self.value_head_size), value=0
)
kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query=query,
key=key,
value=value,
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else key,
kv_cache[1] if PREFILL_IN_KV_CACHE else value,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
max_s,
)
# Remove padding.
@ -386,11 +387,27 @@ class DeepseekV2MLP(nn.Module):
self.quantize = config.quantize
def forward(self, hidden_states: torch.Tensor, reduce: bool = True):
gate_up_states = self.gate_up_proj(hidden_states)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce
)
if (
SYSTEM == "rocm"
and self.hidden_act == "silu"
and hidden_states.dtype == torch.float16
and hidden_states.shape[0] == 1
and not self.quantize
):
out = torch.empty(
hidden_states.shape[0],
self.intermediate_size,
dtype=hidden_states.dtype,
device="cuda",
)
_custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
return self.down_proj(out, reduce=reduce)
else:
gate_up_states = self.gate_up_proj(hidden_states)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce
)
class DeepseekV2MoE(nn.Module):
@ -503,9 +520,10 @@ class DeepseekV2Layer(nn.Module):
sin: torch.Tensor,
cu_seqlen_prefill: torch.Tensor,
kv_cache,
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
):
normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)
@ -516,9 +534,10 @@ class DeepseekV2Layer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
# faster post attention rms norm
@ -564,15 +583,18 @@ class DeepseekV2Model(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.layers):
@ -583,9 +605,10 @@ class DeepseekV2Model(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -612,9 +635,11 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -623,9 +648,10 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -1,642 +0,0 @@
# coding=utf-8
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Type
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from text_generation_server.layers import (
FastLinear,
SpeculativeHead,
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
get_linear,
)
from text_generation_server.layers.attention import (
Seqlen,
attention,
paged_attention,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
from text_generation_server.utils.weights import Weights
class DeepseekV3Config(PretrainedConfig):
def __init__(
self,
vocab_size=102400,
hidden_size=4096,
intermediate_size=11008,
moe_intermediate_size=1407,
num_hidden_layers=30,
num_attention_heads=32,
num_key_value_heads=32,
n_shared_experts=2,
n_routed_experts=160,
ep_size=1,
routed_scaling_factor=1.0,
kv_lora_rank=512,
q_lora_rank=1536,
qk_rope_head_dim=64,
v_head_dim=128,
qk_nope_head_dim=128,
topk_method="gready",
n_group=8,
topk_group=3,
num_experts_per_tok=6,
moe_layer_freq=1,
first_k_dense_replace=0,
norm_topk_prob=False,
scoring_func="softmax",
aux_loss_alpha=0.001,
seq_aux=True,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=100000,
eos_token_id=100001,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.moe_intermediate_size = moe_intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.n_shared_experts = n_shared_experts
self.n_routed_experts = n_routed_experts
self.ep_size = ep_size
self.routed_scaling_factor = routed_scaling_factor
self.kv_lora_rank = kv_lora_rank
self.q_lora_rank = q_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.qk_nope_head_dim = qk_nope_head_dim
self.topk_method = topk_method
self.n_group = n_group
self.topk_group = topk_group
self.num_experts_per_tok = num_experts_per_tok
self.moe_layer_freq = moe_layer_freq
self.first_k_dense_replace = first_k_dense_replace
self.norm_topk_prob = norm_topk_prob
self.scoring_func = scoring_func
self.aux_loss_alpha = aux_loss_alpha
self.seq_aux = seq_aux
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
if tie_word_embeddings:
raise ValueError(
"tie_word_embeddings is not supported for Deepseek V2 models."
)
if ep_size != 1:
raise ValueError(
f"Currently only ep_size == 1 is supported for Deepseek V2 models, was {ep_size}"
)
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
class DeepseekV3Attention(torch.nn.Module):
def __init__(
self,
prefix: str,
config,
weights: Weights,
):
super().__init__()
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.kv_lora_rank = config.kv_lora_rank
self.q_lora_rank = config.q_lora_rank
self.qk_nope_head_dim = config.qk_nope_head_dim
self.qk_rope_head_dim = config.qk_rope_head_dim
self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim
self.value_head_size = config.v_head_dim
self.head_pad_size = max(self.head_size, self.value_head_size)
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.qk_rope_head_dim,
base=config.rope_theta,
device=weights.device,
)
mscale = get_mscale(
self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim
)
self.softmax_scale = self.head_size**-0.5 * mscale * mscale
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
config.num_key_value_heads // weights.process_group.size()
)
if self.q_lora_rank is None:
self.q_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.q_proj",
weights=weights,
bias=config.attention_bias,
)
else:
self.q_a_proj = get_linear(
weight=weights.get_weights(f"{prefix}.q_a_proj"),
bias=(
weights.get_tensor(f"{prefix}.q_a_proj.bias")
if config.attention_bias
else None
),
)
self.q_a_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.q_a_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
self.q_b_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.q_b_proj",
weights=weights,
bias=config.attention_bias,
)
self.kv_a_proj_with_mqa = get_linear(
weight=weights.get_weights(f"{prefix}.kv_a_proj_with_mqa"),
bias=(
weights.get_tensor(f"{prefix}.kv_a_proj_with_mqa.bias")
if config.attention_bias
else None
),
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.kv_a_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.kv_b_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.kv_b_proj",
weights=weights,
bias=config.attention_bias,
)
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
def forward(
self,
hidden_states: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
cu_seqlen_prefill: torch.Tensor,
kv_cache: KVCache,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
):
if self.q_lora_rank is None:
query = self.q_proj(hidden_states)
else:
query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0])
query = query.view(-1, self.num_heads, self.head_size)
_, query_pe = torch.split(
query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, key_pe = torch.split(
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view(
-1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size
)
key_nope, value = torch.split(
kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
)
batch_size, heads, head_dim = query_pe.shape
query_pe = (
query_pe.view(batch_size, heads, head_dim // 2, 2)
.transpose(2, 3)
.reshape(batch_size, heads, head_dim)
)
batch_size, heads, head_dim = key_pe.shape
key_pe = (
key_pe.view(batch_size, heads, head_dim // 2, 2)
.transpose(2, 3)
.reshape(batch_size, heads, head_dim)
)
self.rotary_emb(query_pe, key_pe, cos, sin)
query[..., self.qk_nope_head_dim :] = query_pe
key = torch.empty_like(query)
key[..., : self.qk_nope_head_dim] = key_nope
key[..., self.qk_nope_head_dim :] = key_pe
# We need to pad the heads because Flash Attention does not support
# qk and v with different head sizes.
query = torch.nn.functional.pad(
query, (0, self.head_pad_size - self.head_size), value=0
)
key = torch.nn.functional.pad(
key, (0, self.head_pad_size - self.head_size), value=0
)
value = torch.nn.functional.pad(
value, (0, self.head_pad_size - self.value_head_size), value=0
)
kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query=query,
key=key,
value=value,
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
)
# Remove padding.
attn_output = attn_output[..., : self.value_head_size]
return self.o_proj(
attn_output.reshape(-1, self.num_heads * self.value_head_size)
)
class DeepseekV3MLP(nn.Module):
def __init__(self, prefix: str, config, weights, intermediate_size: int):
super().__init__()
self.hidden_act = config.hidden_act
if self.hidden_act != "silu":
# Bail out because MoE only supports silu.
raise NotImplementedError(
"Currently only `silu` is supported as an activation for Deepseek V2."
)
self.act = ACT2FN[self.hidden_act]
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
self.down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
)
self.intermediate_size = intermediate_size // weights.process_group.size()
# TODO: This is a hotfix to be removed & properly refactored.
self.quantize = config.quantize
def forward(self, hidden_states: torch.Tensor, reduce: bool = True):
gate_up_states = self.gate_up_proj(hidden_states)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce
)
class DeepseekV3MoE(nn.Module):
def __init__(
self,
prefix,
config: DeepseekV3Config,
moe_layer_cls: Type[MoELayer],
weights,
):
super().__init__()
self.hidden_dim = config.hidden_size
self.moe_intermediate_size = (
config.moe_intermediate_size // weights.process_group.size()
)
self.routed_scaling_factor = config.routed_scaling_factor
# Gating
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
if config.topk_method == "noaux_tc":
self.gate.e_score_correction_bias = torch.zeros(
config.n_routed_experts, device=weights.device
)
else:
self.gate.e_score_correction_bias = None
self.moe_layer = moe_layer_cls(
prefix=f"{prefix}.experts",
n_experts=config.n_routed_experts,
n_expert_group=config.n_group,
renormalize=config.norm_topk_prob,
topk=config.num_experts_per_tok,
topk_group=config.topk_group,
weights=weights,
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
)
assert isinstance(self.moe_layer, MoELayer)
if config.n_shared_experts is not None:
self.shared_experts = DeepseekV3MLP(
prefix=f"{prefix}.shared_experts",
config=config,
weights=weights,
intermediate_size=config.moe_intermediate_size
* config.n_shared_experts,
)
else:
self.shared_experts = None
self.process_group = weights.process_group
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.shared_experts is not None:
shared_output = self.shared_experts(x, reduce=False)
else:
shared_output = None
router_logits = self.gate(x)
out = self.moe_layer(x, gating_output=router_logits)
if shared_output is not None:
out = out + shared_output
# Reduce sum
if self.process_group.size() > 1:
torch.distributed.all_reduce(out, group=self.process_group)
return out.view(*x.shape)
class DeepseekV3Layer(nn.Module):
def __init__(self, prefix, layer_id, config, weights):
super().__init__()
prefix = f"{prefix}.layers.{layer_id}"
self.self_attn = DeepseekV3Attention(
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
)
if (
config.n_routed_experts is not None
and layer_id >= config.first_k_dense_replace
and layer_id % config.moe_layer_freq == 0
):
moe_layer_cls = (
SparseMoELayer
if SparseMoELayer.is_supported(weights)
else DenseMoELayer
)
self.mlp = DeepseekV3MoE(f"{prefix}.mlp", config, moe_layer_cls, weights)
else:
self.mlp = DeepseekV3MLP(
prefix=f"{prefix}.mlp",
config=config,
weights=weights,
intermediate_size=config.intermediate_size,
)
self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.post_attention_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
def forward(
self,
hidden_states: torch.Tensor,
residual: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
cu_seqlen_prefill: torch.Tensor,
kv_cache,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
):
normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)
# Self Attention
attn_output = self.self_attn(
normed_hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
slots,
seqlen,
hpu_attention_meta,
)
# faster post attention rms norm
normed_attn_res_output, residual = self.post_attention_layernorm(
attn_output, residual
)
output = self.mlp(normed_attn_res_output)
return output, residual
class DeepseekV3Model(torch.nn.Module):
def __init__(self, prefix: str, config, weights: Weights):
super().__init__()
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens", weights=weights
)
self.layers = nn.ModuleList(
[
DeepseekV3Layer(
prefix,
layer_id,
config,
weights,
)
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = FastRMSNorm.load(
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
)
self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache[i],
slots,
seqlen,
hpu_attention_meta,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class FlashDeepseekV3ForCausalLM(torch.nn.Module):
def __init__(self, prefix: str, config, weights: Weights):
super().__init__()
self.model = DeepseekV3Model(
"model" if not prefix else f"{prefix}.model", config, weights
)
self.lm_head = SpeculativeHead.load(
config,
prefix="lm_head" if not prefix else f"{prefix}.lm_head",
weights=weights,
)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden_states = self.model(
input_ids,
position_ids,
cu_seqlen_prefill,
kv_cache,
slots,
seqlen,
hpu_attention_meta,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states)
return logits, speculative_logits

View File

@ -28,8 +28,8 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -40,7 +40,7 @@ from text_generation_server.layers import (
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
@ -208,7 +208,6 @@ class FlashGemma2Attention(torch.nn.Module):
],
process_group=weights.process_group,
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
o_proj = TensorParallelRowLinear.load(
config,
@ -235,10 +234,11 @@ class FlashGemma2Attention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
adapter_data,
hpu_attention_meta,
):
qkv = self.query_key_value(hidden_states, adapter_data)
query, kv = qkv.split(
@ -253,24 +253,19 @@ class FlashGemma2Attention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# Prefill
if cu_seqlen_prefill is not None:
# sdpa
# flash attention
attn_output = attention(
query=query,
key=kv[:, 0],
value=kv[:, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
causal=self.causal,
window_size_left=self.window_size,
softcap=self.softcap,
)
@ -278,13 +273,14 @@ class FlashGemma2Attention(torch.nn.Module):
else:
attn_output = paged_attention(
query,
kv_cache,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
max_s,
softcap=self.softcap,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
)
return self.o_proj(
@ -394,10 +390,11 @@ class FlashGemma2Layer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
adapter_data,
hpu_attention_meta,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -408,10 +405,11 @@ class FlashGemma2Layer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
adapter_data,
hpu_attention_meta,
)
# faster post attention rms norm
@ -460,16 +458,19 @@ class FlashGemma2Model(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
adapter_data: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.layers):
@ -480,10 +481,11 @@ class FlashGemma2Model(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
max_s,
adapter_data,
hpu_attention_meta,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -527,9 +529,11 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -539,10 +543,11 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
adapter_data,
hpu_attention_meta,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -28,8 +28,9 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
HPUPagedAttentionMetadata,
PREFILL_IN_KV_CACHE,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -38,7 +39,6 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
@ -187,7 +187,6 @@ class FlashGemmaAttention(torch.nn.Module):
)
self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load(
config,
@ -207,9 +206,10 @@ class FlashGemmaAttention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
qkv = self.query_key_value(hidden_states)
query, kv = qkv.split(
@ -224,36 +224,31 @@ class FlashGemmaAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# Prefill
if cu_seqlen_prefill is not None:
# sdpa
# flash attention
attn_output = attention(
query=query,
key=kv[:, 0],
value=kv[:, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
causal=self.causal,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
max_s,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -322,9 +317,10 @@ class FlashGemmaLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -335,9 +331,10 @@ class FlashGemmaLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
# faster post attention rms norm
@ -382,16 +379,18 @@ class FlashGemmaModel(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
adapter_data: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
) -> torch.Tensor:
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.layers):
@ -402,9 +401,10 @@ class FlashGemmaModel(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -446,9 +446,11 @@ class FlashGemmaForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -458,10 +460,10 @@ class FlashGemmaForCausalLM(torch.nn.Module):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
adapter_data,
hpu_attention_meta,
max_s,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -24,11 +24,12 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -37,7 +38,6 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
def load_qkv(config, prefix: str, weights, head_size, num_heads):
@ -47,6 +47,10 @@ def load_qkv(config, prefix: str, weights, head_size, num_heads):
prefix,
weights,
)
elif config.quantize == "marlin":
raise RuntimeError(
"GPT-2 models with marlin quantization are not yet supported"
)
else:
return _load_qkv(config, prefix, weights, head_size, num_heads)
@ -191,7 +195,6 @@ class FlashGPT2Attention(torch.nn.Module):
head_size=self.head_size,
num_heads=self.num_heads,
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = load_row(
config,
@ -209,9 +212,10 @@ class FlashGPT2Attention(torch.nn.Module):
hidden_states,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
query, key, value = self.query_key_value(hidden_states).split(
self.head_size * self.num_heads, dim=1
@ -220,35 +224,30 @@ class FlashGPT2Attention(torch.nn.Module):
key = key.view(-1, self.num_heads, self.head_size)
value = value.view(-1, self.num_heads, self.head_size)
kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
# Prefill
if cu_seqlen_prefill is not None:
# sdpa
# flash attention
attn_output = attention(
query=query,
key=key,
value=value,
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else key,
kv_cache[1] if PREFILL_IN_KV_CACHE else value,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
max_s,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -314,9 +313,10 @@ class FlashGPT2Layer(nn.Module):
residual,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
@ -326,9 +326,10 @@ class FlashGPT2Layer(nn.Module):
hidden_states,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
hidden_states = attn_output + residual
@ -378,9 +379,12 @@ class FlashGPT2Model(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor:
hidden_states = inputs_embeds
@ -391,9 +395,10 @@ class FlashGPT2Model(torch.nn.Module):
residual,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
hidden_states = self.norm(hidden_states)
@ -427,9 +432,11 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -441,9 +448,12 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta=hpu_attention_meta,
max_s,
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -24,12 +24,12 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -38,16 +38,13 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.rotary import (
PositionRotaryEmbedding,
)
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
from habana_frameworks.torch.hpex.kernels import (
RotaryPosEmbeddingMode,
apply_rotary_pos_emb,
)
def load_attention(config, prefix: str, weights):
@ -81,25 +78,39 @@ class GPTJRotary(PositionRotaryEmbedding):
cos: torch.Tensor,
sin: torch.Tensor,
):
num_tokens = query.shape[0]
head_size = query.shape[-1]
rope_mode = RotaryPosEmbeddingMode.PAIRWISE
sin = torch.repeat_interleave(sin, 2, dim=-1)
cos = torch.repeat_interleave(cos, 2, dim=-1)
rotary_dim = cos.shape[-1]
query_shape = query.shape
query = query.view(num_tokens, -1, head_size)
query_rot = query[..., :rotary_dim]
query_pass = query[..., rotary_dim:]
query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape))
# Such controlflows may add some overhead.
if SYSTEM == "cuda":
import rotary_emb
key_shape = key.shape
key = key.view(num_tokens, -1, head_size)
key_rot = key[..., :rotary_dim]
key_pass = key[..., rotary_dim:]
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape))
q1 = query[..., ::2]
q2 = query[..., 1::2]
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
k1 = key[..., ::2]
k2 = key[..., 1::2]
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
elif SYSTEM == "rocm":
from vllm._C import ops
# NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
# Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773
head_size = query.shape[-1]
# Inplace operation, updating query and key.
ops.rotary_embedding(query, key, head_size, cos, sin, False)
elif SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex
ipex.llm.functional.rotary_embedding(
query, key, sin, cos, query.size(-1), False
)
else:
raise ValueError(
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
)
class FlashGPTJAttention(torch.nn.Module):
@ -129,7 +140,6 @@ class FlashGPTJAttention(torch.nn.Module):
prefix=prefix,
weights=weights,
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = load_row(
config,
@ -156,9 +166,10 @@ class FlashGPTJAttention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
query, key, value = self.query_key_value(hidden_states).split(
self.head_size * self.num_heads, dim=1
@ -175,35 +186,30 @@ class FlashGPTJAttention(torch.nn.Module):
else:
self.rotary_emb(query, key, cos, sin)
kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
# Prefill
if cu_seqlen_prefill is not None:
# sdpa
# flash attention
attn_output = attention(
query=query,
key=key,
value=value,
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else key,
kv_cache[1] if PREFILL_IN_KV_CACHE else value,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
max_s,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -261,9 +267,10 @@ class FlashGPTJLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
hidden_states, residual = self.input_layernorm(hidden_states, residual)
# Self Attention
@ -273,9 +280,10 @@ class FlashGPTJLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
feed_forward_hidden_states = self.mlp(hidden_states)
@ -319,15 +327,19 @@ class FlashGPTJModel(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor:
hidden_states = self.wte(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.layers):
@ -338,9 +350,10 @@ class FlashGPTJModel(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
hidden_states, _ = self.ln_f(hidden_states, residual)
@ -368,9 +381,11 @@ class FlashGPTJForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -379,9 +394,11 @@ class FlashGPTJForCausalLM(torch.nn.Module):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta=hpu_attention_meta,
max_s,
prefill_cache_indices=prefill_cache_indices,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -27,16 +27,14 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from text_generation_server.layers.attention import (
KVCache,
get_kv_scales,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -59,6 +57,15 @@ from text_generation_server.utils.weights import (
)
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
if SYSTEM != "ipex":
pass
if SYSTEM == "rocm":
try:
from vllm import _custom_C
except Exception as e:
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
def load_attention(config, prefix: str, weights, layer_id):
# Only defined in granite.
@ -150,10 +157,7 @@ class FlashLlamaAttention(torch.nn.Module):
device=weights.device,
)
# `config.attention_multiplier` is used in Granite
self.softmax_scale = getattr(
config, "attention_multiplier", self.head_size**-0.5
)
self.softmax_scale = self.head_size**-0.5
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
@ -173,13 +177,11 @@ class FlashLlamaAttention(torch.nn.Module):
self.query_key_value = load_attention(config, prefix, weights, index)
self.index = index
self.kv_scales = get_kv_scales(weights, f"{prefix}")
o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=getattr(config, "attention_bias", False),
bias=False,
)
self.o_proj = TensorParallelAdapterRowLinear.load(
@ -200,11 +202,12 @@ class FlashLlamaAttention(torch.nn.Module):
cos,
sin,
cu_seqlen_prefill,
kv_cache: KVCache,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
adapter_data,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
):
qkv = self.query_key_value(hidden_states, adapter_data)
query, kv = qkv.split(
@ -219,35 +222,30 @@ class FlashLlamaAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# Prefill
if cu_seqlen_prefill is not None:
# sdpa
# flash attention
attn_output = attention(
query=query,
key=kv[:, 0],
value=kv[:, 1],
kv_scales=self.kv_scales,
kv_cache=kv_cache,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
max_s,
)
return self.o_proj(
@ -365,11 +363,31 @@ class LlamaMLP(nn.Module):
self.hidden_size = config.hidden_size
def forward(self, hidden_states, adapter_data):
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
)
if (
SYSTEM == "rocm"
and self.hidden_act == "silu"
and hidden_states.dtype == torch.float16
and hidden_states.shape[0] == 1
and not self.quantize
and self.hidden_size
!= 16384 # TODO: Temporary workaround for `LLMM_Silu` kernel not working with LLama3.1 405B; needs refactoring once fixed.
):
out = torch.empty(
hidden_states.shape[0],
self.intermediate_size,
dtype=hidden_states.dtype,
device="cuda",
)
_custom_C.LLMM_Silu(
self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8
)
return self.down_proj(out, adapter_data)
else:
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
)
class FlashLlamaLayer(nn.Module):
@ -390,7 +408,7 @@ class FlashLlamaLayer(nn.Module):
if SparseMoELayer.is_supported(weights)
else DenseMoELayer
)
self.mlp = Phi3MoE(
self.dense = Phi3MoE(
f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights
)
# with moe the layernorms are are not rmsnorms and they have bias
@ -405,7 +423,7 @@ class FlashLlamaLayer(nn.Module):
eps=config.rms_norm_eps,
)
else:
self.mlp = LlamaMLP(
self.dense = LlamaMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
)
self.input_layernorm = FastRMSNorm.load(
@ -419,11 +437,6 @@ class FlashLlamaLayer(nn.Module):
eps=config.rms_norm_eps,
)
# Used in Granite
# This could eventually be baked into the weights like we do for the embeddings/lm_head
# but this would mean modifying the lora code
self.residual_multiplier = getattr(config, "residual_multiplier", None)
def forward(
self,
hidden_states,
@ -432,11 +445,12 @@ class FlashLlamaLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
adapter_data,
cross_attention_states,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -447,21 +461,19 @@ class FlashLlamaLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
adapter_data,
hpu_attention_meta=hpu_attention_meta,
)
if self.residual_multiplier is not None:
attn_output *= self.residual_multiplier
# faster post attention rms norm
normed_attn_res_output, attn_res = self.post_attention_layernorm(
attn_output, res
)
mlp_output = self.mlp(normed_attn_res_output, adapter_data)
if self.residual_multiplier is not None:
mlp_output *= self.residual_multiplier
mlp_output = self.dense(normed_attn_res_output, adapter_data)
return mlp_output, attn_res
@ -481,7 +493,9 @@ class FlashLlamaModel(torch.nn.Module):
self.layers.append(
FlashLlamaLayer(
index=0,
prefix=f"{prefix}.layers.0",
prefix=(
"model.layers.0" if not prefix else f"{prefix}.model.layers.0"
),
config=config,
weights=weights,
)
@ -490,14 +504,18 @@ class FlashLlamaModel(torch.nn.Module):
# Skip first and last layers
for layer_id in range(1, config.num_hidden_layers - 1):
if layer_id in self.cross_attention_layers:
from text_generation_server.models.custom_modeling.flash_mllama import (
from text_generation_server.models.custom_modeling.mllama import (
FlashLlamaCrossLayer,
)
self.layers.append(
FlashLlamaCrossLayer(
index=layer_id,
prefix=(f"{prefix}.layers.{layer_id}"),
prefix=(
f"model.layers.{layer_id}"
if not prefix
else f"{prefix}.model.layers.{layer_id}"
),
config=config,
weights=weights,
)
@ -506,7 +524,11 @@ class FlashLlamaModel(torch.nn.Module):
self.layers.append(
FlashLlamaLayer(
index=layer_id,
prefix=(f"{prefix}.layers.{layer_id}"),
prefix=(
f"model.layers.{layer_id}"
if not prefix
else f"{prefix}.model.layers.{layer_id}"
),
config=config,
weights=weights,
)
@ -517,14 +539,18 @@ class FlashLlamaModel(torch.nn.Module):
self.layers.append(
FlashLlamaLayer(
index=last_layer_id,
prefix=(f"{prefix}.layers.{last_layer_id}"),
prefix=(
f"model.layers.{last_layer_id}"
if not prefix
else f"{prefix}.model.layers.{last_layer_id}"
),
config=config,
weights=weights,
)
)
self.norm = FastRMSNorm.load(
prefix=f"{prefix}.norm",
prefix="model.norm" if not prefix else f"{prefix}.model.norm",
weights=weights,
eps=config.rms_norm_eps,
)
@ -541,17 +567,22 @@ class FlashLlamaModel(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
adapter_data,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
cross_attention_states=None,
) -> torch.Tensor:
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.layers):
@ -562,11 +593,12 @@ class FlashLlamaModel(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
max_s,
adapter_data,
cross_attention_states,
hpu_attention_meta=hpu_attention_meta,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -575,60 +607,42 @@ class FlashLlamaModel(torch.nn.Module):
class FlashLlamaForCausalLM(torch.nn.Module):
def __init__(self, prefix: str, config, weights, name=None):
if name is None:
name = "model"
def __init__(self, prefix: str, config, weights):
super().__init__()
with no_fp8(weights):
self.embed_tokens = TensorParallelEmbedding(
prefix=(
f"{name}.embed_tokens"
"model.embed_tokens"
if not prefix
else f"{prefix}.{name}.embed_tokens"
else f"{prefix}.model.embed_tokens"
),
weights=weights,
)
self.model = FlashLlamaModel(
prefix=name if not prefix else f"{prefix}.{name}",
config=config,
weights=weights,
)
self.model = FlashLlamaModel(prefix, config, weights)
if config.tie_word_embeddings:
suffix = "model.embed_tokens"
else:
suffix = "lm_head"
# Used in Granite
embedding_multiplier = getattr(config, "embedding_multiplier", None)
if embedding_multiplier is not None:
self.embed_tokens.weight.data *= embedding_multiplier
prefix = suffix if not prefix or name != "model" else f"{prefix}.{suffix}"
with no_fp8(weights):
self.lm_head = SpeculativeHead.load(
config,
prefix,
weights,
prefix=suffix if not prefix else f"{prefix}.{suffix}",
weights=weights,
)
# Used in Granite
self.logits_scaling = getattr(config, "logits_scaling", None)
if self.logits_scaling is not None and self.lm_head.head is not None:
try:
# Scale the weights directly
self.lm_head.head.linear.weight.data /= self.logits_scaling
self.logits_scaled = True
except Exception:
self.logits_scaled = False
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
cross_attention_states=None,
@ -639,20 +653,16 @@ class FlashLlamaForCausalLM(torch.nn.Module):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
adapter_data=adapter_data,
cross_attention_states=cross_attention_states,
hpu_attention_meta=hpu_attention_meta,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states)
# Used in Granite
if self.logits_scaling is not None and not self.logits_scaled:
logits /= self.logits_scaling
if speculative_logits is not None:
speculative_logits /= self.logits_scaling
return logits, speculative_logits

View File

@ -1,285 +0,0 @@
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Llava-NeXT model."""
from typing import List, Optional, Tuple
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
)
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelRowLinear,
)
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
"""
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
Args:
image_size (`tuple`):
The size of the input image in the format (height, width).
grid_pinpoints (`List`):
A list containing possible resolutions. Each item in the list should be a tuple or list
of the form `(height, width)`.
patch_size (`int`):
The size of each image patch.
Returns:
tuple: The shape of the image patch grid in the format (height, width).
"""
if not isinstance(grid_pinpoints, list):
raise ValueError("grid_pinpoints should be a list of tuples or lists")
height, width = select_best_resolution(image_size, grid_pinpoints)
return height // patch_size, width // patch_size
def unpad_image(tensor, original_size):
"""
Unpads a PyTorch tensor of a padded and resized image.
Args:
tensor (`torch.Tensor`):
The image tensor, assumed to be of shape (num_channels, height, width).
original_size (`tuple`):
The original size of the image (height, width).
Returns:
`torch.Tensor`: The unpadded image tensor.
"""
original_height, original_width = original_size
current_height, current_width = tensor.shape[1:]
original_aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = int(original_height * scale_factor)
padding = (current_height - new_height) // 2
unpadded_tensor = tensor[:, padding : current_height - padding, :]
else:
scale_factor = current_height / original_height
new_width = int(original_width * scale_factor)
padding = (current_width - new_width) // 2
unpadded_tensor = tensor[:, :, padding : current_width - padding]
return unpadded_tensor
# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
class LlavaNextMultiModalProjector(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.linear_1 = TensorParallelColumnLinear.load(
prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True
)
self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = TensorParallelRowLinear.load(
prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True
)
def forward(self, image_features):
hidden_states = self.linear_1(image_features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class FlashLlavaNextForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
config.vision_config.quantize = config.quantize
vision_config = config.vision_config
# Instead of selecting in hidden_states[-2].
# Instead compute only the n -2 + 1 layers and don't pool
if config.vision_feature_layer < 0:
vision_config.num_hidden_layers += config.vision_feature_layer + 1
else:
vision_config.num_hidden_layers = config.vision_feature_layer + 1
self.vision_tower = load_vision_model(
prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
config=config.vision_config,
weights=weights,
)
self.multi_modal_projector = LlavaNextMultiModalProjector(
prefix="multi_modal_projector", config=config, weights=weights
)
self.image_newline = weights.get_tensor("image_newline")
self.vocab_size = config.text_config.vocab_size
self.config = config
config.text_config.quantize = config.quantize
config.text_config.speculator = config.speculator
self.text_model = load_text_model(
prefix="language_model" if not prefix else f"{prefix}.language_model",
config=config.text_config,
weights=weights,
)
self.pad_token_id = (
config.pad_token_id if config.pad_token_id is not None else -1
)
def _merge_input_ids_with_image_features(
self,
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
image_features: torch.Tensor,
):
"""In place merges in vision_embeddings with inputs_embeds."""
mask = torch.where(input_ids == self.config.image_token_index)
# Let's pray we have enabled enough slots !
try:
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
except Exception as e:
raise RuntimeError(
f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
# Unused for this model
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None and len(pixel_values) > 0:
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
# 1. Extract the input embeddings
# 2. Merge text and images
num_images, num_patches, channels, height, width = pixel_values.shape
pixel_values = pixel_values.view(
num_images * num_patches, channels, height, width
)
image_features = self.vision_tower(pixel_values)
# selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
# Already done within the clip model
selected_image_feature = image_features.last_hidden_state
if self.config.vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif self.config.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise RuntimeError(
f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
)
image_features = self.multi_modal_projector(selected_image_feature)
# split up image_features for each of the individual images
# hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
# if we assume each image has 5 image features (base image + 4 patches)
split_sizes = [num_patches] * num_images
image_features = torch.split(image_features, split_sizes, dim=0)
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
height = width = (
self.config.vision_config.image_size
// self.config.vision_config.patch_size
)
new_image_features = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
if height * width != base_image_feature.shape[0]:
raise ValueError(
"The number of patches is not consistent with the image size."
)
# Dimensions are intentionally swapped to be bug-compatible with
# upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
image_feature = image_feature.view(
num_patch_height, num_patch_width, height, width, -1
)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature, image_sizes[image_idx])
image_feature = torch.cat(
(
image_feature,
self.image_newline[:, None, None].expand(
*image_feature.shape[:-1], 1
),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat(
(base_image_feature, image_feature), dim=0
)
else:
image_feature = image_feature[0]
image_feature = torch.cat(
(image_feature, self.image_newline[None]), dim=0
)
new_image_features.append(image_feature)
image_features = torch.stack(new_image_features, dim=0)
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_features
)
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
slots=slots,
seqlen=seqlen,
hpu_attention_meta=hpu_attention_meta,
adapter_data=adapter_data,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.text_model.lm_head(hidden_states)
return logits, speculative_logits

View File

@ -26,12 +26,12 @@ from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -41,12 +41,20 @@ from text_generation_server.layers import (
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
if SYSTEM == "rocm":
try:
from vllm import _custom_C
except Exception as e:
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
class MistralConfig(PretrainedConfig):
model_type = "mistral"
@ -152,7 +160,6 @@ class MistralAttention(torch.nn.Module):
],
process_group=weights.process_group,
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
o_proj = TensorParallelRowLinear.load(
config,
@ -178,10 +185,12 @@ class MistralAttention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
adapter_data,
hpu_attention_meta,
):
qkv = self.query_key_value(hidden_states, adapter_data)
query, kv = qkv.split(
@ -196,36 +205,38 @@ class MistralAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
reshape_and_cache(
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
)
# Prefill
if cu_seqlen_prefill is not None:
# sdpa
# flash attention
attn_output = attention(
query=query,
key=kv[:, 0],
value=kv[:, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
seqlen,
block_tables,
self.softmax_scale,
window_size_left=self.max_past,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
max_s,
)
return self.o_proj(
@ -289,11 +300,29 @@ class MistralMLP(nn.Module):
self.quantize = config.quantize
def forward(self, hidden_states, adapter_data):
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
)
if (
SYSTEM == "rocm"
and self.hidden_act == "silu"
and hidden_states.dtype == torch.float16
and hidden_states.shape[0] == 1
and not self.quantize
):
out = torch.empty(
hidden_states.shape[0],
self.intermediate_size,
dtype=hidden_states.dtype,
device="cuda",
)
_custom_C.LLMM_Silu(
self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8
)
return self.down_proj(out, adapter_data)
else:
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
)
class MistralLayer(nn.Module):
@ -326,10 +355,12 @@ class MistralLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
adapter_data,
hpu_attention_meta,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -340,10 +371,12 @@ class MistralLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
adapter_data,
hpu_attention_meta,
)
# faster post attention rms norm
@ -390,15 +423,20 @@ class MistralModel(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
adapter_data: Optional[torch.Tensor] = None,
):
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, true_max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.layers):
@ -409,10 +447,12 @@ class MistralModel(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
adapter_data,
hpu_attention_meta,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -458,21 +498,35 @@ class FlashMistralForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
true_max_s = max_s
if prefill_cache_indices is not None:
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
seqlen = seqlen.clamp(max=self.max_past_tensor)
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(
inputs_embeds,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
true_max_s,
prefill_cache_indices,
adapter_data,
)
if lm_head_indices is not None:

View File

@ -37,9 +37,9 @@ from text_generation_server.layers.attention import (
Seqlen,
attention,
paged_attention,
HPUPagedAttentionMetadata,
reshape_and_cache,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding
@ -215,7 +215,6 @@ class MixtralAttention(torch.nn.Module):
)
self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load(
config,
@ -235,9 +234,11 @@ class MixtralAttention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
prefill_cache_indices,
):
qkv = self.query_key_value(hidden_states)
query, kv = qkv.split(
@ -252,36 +253,38 @@ class MixtralAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
reshape_and_cache(
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
)
# Prefill
if cu_seqlen_prefill is not None:
# sdpa
# flash attention
attn_output = attention(
query=query,
key=kv[:, 0],
value=kv[:, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
seqlen,
block_tables,
self.softmax_scale,
window_size_left=self.max_past,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
max_s,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -375,9 +378,11 @@ class MixtralLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
prefill_cache_indices,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -388,9 +393,11 @@ class MixtralLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
prefill_cache_indices,
)
# faster post attention rms norm
@ -441,15 +448,20 @@ class MixtralModel(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, true_max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.layers):
@ -460,9 +472,11 @@ class MixtralModel(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
prefill_cache_indices,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -493,21 +507,34 @@ class FlashMixtralForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
true_max_s = max_s
if prefill_cache_indices is not None:
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
seqlen = seqlen.clamp(max=self.max_past_tensor)
hidden_states = self.model(
input_ids,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
true_max_s,
prefill_cache_indices,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -1,986 +0,0 @@
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Mllama model."""
from typing import Optional, Tuple, List
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
import torch.nn.functional as F
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
FastLinear,
)
from text_generation_server.layers.attention import (
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
from habana_frameworks.torch.hpex.kernels import FusedSDPA
from vllm_hpu_extension.utils import ModuleFusedSDPA
def _prepare_aspect_ratio_attention_mask(
aspect_ratio_mask: torch.Tensor,
num_patches: int,
target_length: int,
dtype: torch.dtype,
) -> torch.Tensor:
# Expand aspect ratio mask to target_length
batch_size, max_num_tiles = aspect_ratio_mask.shape
attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype)
attention_mask = attention_mask.repeat(1, 1, target_length, 1)
# Mask padding patches
pad_patches = target_length - num_patches
attention_mask[:, :, -pad_patches:] = 0
# Invert the mask (0 -> 1, 1 -> 0)
attention_mask = 1 - attention_mask
# Reshape to 2D and create 4D attention mask
# (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length)
attention_mask = attention_mask.reshape(
batch_size, max_num_tiles * target_length, 1
)
attention_mask = (
attention_mask @ attention_mask.transpose(-1, -2) * torch.finfo(dtype).min
)
attention_mask = attention_mask.unsqueeze(1)
return attention_mask
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full(
(sequence_length, target_length),
fill_value=min_dtype,
dtype=dtype,
device=device,
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(
target_length, device=device
) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = (
causal_mask.clone()
) # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = (
causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[
:, :, :, :mask_length
].masked_fill(padding_mask, min_dtype)
return causal_mask
def _prepare_cross_attention_mask(
cross_attention_mask: torch.Tensor,
num_vision_tokens: int,
dtype: str,
) -> Tuple[torch.Tensor, torch.Tensor]:
# reshape so it can be used by attn module
batch_size, text_total_length, *_ = cross_attention_mask.shape
cross_attention_mask = cross_attention_mask.repeat_interleave(
num_vision_tokens, dim=3
)
cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1)
cross_attention_mask = cross_attention_mask.unsqueeze(1)
# invert the mask
inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype)
cross_attention_mask = inverted_cross_attn_mask.masked_fill(
inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min
)
# apply full-row bias, which return 4D tensor of shape [B, H, S1, 1] where value is 0 if the a full row in cross attn mask's
# last dimension contains negative infinity values, otherwise it's 1
negative_inf_value = torch.finfo(dtype).min
full_text_row_masked_out_mask = (
(cross_attention_mask != negative_inf_value)
.any(dim=-1)
.type_as(cross_attention_mask)[..., None]
)
cross_attention_mask *= full_text_row_masked_out_mask
return cross_attention_mask, full_text_row_masked_out_mask
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MllamaVision
class MllamaVisionMLP(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = TensorParallelColumnLinear.load(
prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True
)
self.fc2 = TensorParallelRowLinear.load(
prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class MllamaVisionSdpaAttention(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.embed_dim = config.hidden_size
self.head_dim = config.hidden_size // config.attention_heads
self.num_heads = config.attention_heads // weights.process_group.size()
self.qkv_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=False,
)
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
def forward(
self,
hidden_state: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qkv = self.qkv_proj(hidden_state)
query, key, value = qkv.split(
[
self.head_dim * self.num_heads,
self.head_dim * self.num_heads,
self.head_dim * self.num_heads,
],
dim=2,
)
batch_size, q_seq_len, _ = query.shape
_, kv_seq_len, _ = key.shape
query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim)
key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)
value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
attn_output = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_seq_len, -1)
output = self.o_proj(attn_output)
return output
class MllamaVisionEncoderLayer(nn.Module):
def __init__(self, *, prefix, config, weights, is_gated: bool):
super().__init__()
self.hidden_size = config.hidden_size
self.num_attention_heads = config.attention_heads
self.is_gated = is_gated
self.intermediate_size = config.intermediate_size
self.self_attn = MllamaVisionSdpaAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
self.mlp = MllamaVisionMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights
)
self.input_layernorm = nn.LayerNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=1e-05
)
self.post_attention_layernorm = nn.LayerNorm.load(
prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=1e-05
)
# there used to be an if else here, no code path
if is_gated:
self.gate_attn = nn.Parameter(
weights.get_tensor(f"{prefix}.gate_attn"), requires_grad=False
)
self.gate_ffn = nn.Parameter(
weights.get_tensor(f"{prefix}.gate_ffn"), requires_grad=False
)
def forward(
self,
hidden_state: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
):
# Self Attention
residual = hidden_state
hidden_state = self.input_layernorm(hidden_state)
hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask)
gate_attn = 1 if not self.is_gated else self.gate_attn.tanh()
hidden_state = residual + gate_attn * hidden_state
# Feed forward
residual = hidden_state
hidden_state = self.post_attention_layernorm(hidden_state)
hidden_state = self.mlp(hidden_state)
gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh()
hidden_state = residual + gate_ffn * hidden_state
return hidden_state
class MllamaVisionEncoder(nn.Module):
def __init__(self, *, prefix, config, weights, is_gated: bool, num_layers: int):
super().__init__()
self.config = config
self.layers = [
MllamaVisionEncoderLayer(
prefix=f"{prefix}.layers.{i}",
config=config,
weights=weights,
is_gated=is_gated,
)
for i in range(num_layers)
]
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
):
encoder_states = [hidden_states]
for encoder_layer in self.layers:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
)
hidden_states = layer_outputs
encoder_states.append(hidden_states)
return hidden_states, encoder_states
class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.max_num_tiles = config.max_num_tiles
self.hidden_size = config.hidden_size
self.max_aspect_ratio_id = config.max_aspect_ratio_id
self.embedding = TensorParallelEmbedding(
prefix=f"{prefix}.embedding", weights=weights
)
self.gate = nn.Parameter(
weights.get_tensor(f"{prefix}.gate"), requires_grad=False
)
def forward(
self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
) -> torch.Tensor:
embeddings = self.embedding(aspect_ratio_ids)
embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size)
# Always gated.
embeddings = embeddings * self.gate.tanh()
hidden_state = hidden_state + embeddings
return hidden_state
class MllamaPrecomputedPositionEmbedding(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.max_num_tiles = config.max_num_tiles
self.max_aspect_ratio_id = config.max_aspect_ratio_id
self.num_patches = (config.image_size // config.patch_size) ** 2 + 1
self.hidden_size = config.hidden_size
self.scale = config.hidden_size**-0.5
self.gate = nn.Parameter(
weights.get_tensor(f"{prefix}.gate"), requires_grad=False
)
# position embedding
embedding = nn.Parameter(
weights.get_tensor(f"{prefix}.embedding"), requires_grad=False
)
self.gated_position_embedding = (1 - self.gate.tanh()) * embedding
self.tile_embedding = TensorParallelEmbedding(
prefix=f"{prefix}.tile_embedding", weights=weights
)
def forward(
self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
) -> torch.Tensor:
# position embeddings
hidden_state = hidden_state + self.gated_position_embedding.view(
1, 1, self.num_patches, self.hidden_size
)
# precomputed tile position embeddings
tile_position_embedding = self.tile_embedding(aspect_ratio_ids)
batch_size = hidden_state.shape[0]
tile_position_embedding = tile_position_embedding.reshape(
batch_size, self.max_num_tiles, self.num_patches, self.hidden_size
)
gated_tile_position_embedding = self.gate.tanh() * tile_position_embedding
hidden_state = hidden_state + gated_tile_position_embedding
return hidden_state
class MllamaVisionModel(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.image_size = config.image_size
self.patch_size = config.patch_size
self.max_num_tiles = config.max_num_tiles
self.hidden_size = config.hidden_size
self.num_channels = config.num_channels
self.intermediate_layers_indices = config.intermediate_layers_indices
self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
self.scale = config.hidden_size**-0.5
self.dtype = weights.dtype
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.hidden_size,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid",
bias=False,
)
self.patch_embedding.weight = nn.Parameter(
weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
)
self.class_embedding = nn.Parameter(
weights.get_tensor(f"{prefix}.class_embedding"), requires_grad=False
)
self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(
prefix=f"{prefix}.gated_positional_embedding",
config=config,
weights=weights,
)
self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
prefix=f"{prefix}.pre_tile_positional_embedding",
config=config,
weights=weights,
)
self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
prefix=f"{prefix}.post_tile_positional_embedding",
config=config,
weights=weights,
)
## layer norms
self.layernorm_pre = nn.LayerNorm.load(
prefix=f"{prefix}.layernorm_pre",
weights=weights,
# torch default
eps=1e-05,
)
self.layernorm_post = nn.LayerNorm.load(
prefix=f"{prefix}.layernorm_post",
weights=weights,
# torch default
eps=1e-05,
)
## encoders
self.transformer = MllamaVisionEncoder(
prefix=f"{prefix}.transformer",
config=config,
weights=weights,
is_gated=False,
num_layers=config.num_hidden_layers,
)
self.global_transformer = MllamaVisionEncoder(
prefix=f"{prefix}.global_transformer",
config=config,
weights=weights,
is_gated=True,
num_layers=config.num_global_layers,
)
def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
batch_size, _, hidden_size = hidden_state.shape
class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
return hidden_state
def forward(
self,
pixel_values: torch.Tensor,
aspect_ratio_ids: torch.Tensor,
attention_mask: torch.Tensor,
) -> torch.Tensor:
(
batch_size,
num_concurrent_media,
num_tiles,
num_channels,
height,
width,
) = pixel_values.shape
pixel_values = pixel_values.reshape(
batch_size * num_concurrent_media * num_tiles, num_channels, height, width
)
aspect_ratio_ids = aspect_ratio_ids.reshape(
batch_size * num_concurrent_media, -1
)
# patch embedding
patch_embeds = self.patch_embedding(pixel_values)
hidden_state = patch_embeds.flatten(2).transpose(1, 2)
# tile embeddings
_, num_patches, dim = hidden_state.shape
hidden_state = hidden_state.reshape(
batch_size * num_concurrent_media, num_tiles, -1, dim
)
hidden_state = self.pre_tile_positional_embedding(
hidden_state, aspect_ratio_ids
)
# apply cls token
hidden_state = hidden_state.reshape(
batch_size * num_concurrent_media * num_tiles, num_patches, dim
)
hidden_state = self.apply_class_embedding(hidden_state)
num_patches += 1
# apply position embeddings
hidden_state = hidden_state.reshape(
batch_size * num_concurrent_media, num_tiles, num_patches, dim
)
hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids)
# apply encoder
hidden_state = self.layernorm_pre(hidden_state)
# Compute the number of tokens to pad
num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
# Compute padding tuple for pad function
padding = (
0,
0,
0,
num_padding_patches,
) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
# Pad the tensor
hidden_state = F.pad(hidden_state, padding, mode="constant", value=0)
slice_index = -num_padding_patches if num_padding_patches > 0 else None
if attention_mask is not None:
attention_mask = attention_mask.reshape(
batch_size * num_concurrent_media, -1
)
attention_mask = _prepare_aspect_ratio_attention_mask(
aspect_ratio_mask=attention_mask,
num_patches=self.num_patches,
target_length=hidden_state.shape[2],
dtype=self.dtype,
)
hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim)
hidden_state, all_intermediate_hidden_states = self.transformer(
hidden_state,
attention_mask=attention_mask,
)
intermediate_hidden_states = [
hidden_state
for idx, hidden_state in enumerate(all_intermediate_hidden_states)
if idx in self.intermediate_layers_indices
]
intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1)
# apply global encoder
hidden_state = self.layernorm_post(hidden_state)
hidden_state = hidden_state.reshape(
batch_size * num_concurrent_media,
num_tiles,
num_patches + num_padding_patches,
dim,
)
hidden_state = self.post_tile_positional_embedding(
hidden_state, aspect_ratio_ids
)
hidden_state = hidden_state.reshape(
batch_size * num_concurrent_media,
num_tiles * (num_patches + num_padding_patches),
dim,
)
hidden_state, _ = self.global_transformer(
hidden_state, attention_mask=attention_mask
)
hidden_state = hidden_state.reshape(
batch_size * num_concurrent_media,
num_tiles,
num_patches + num_padding_patches,
dim,
)
hidden_state = hidden_state[:, :, :slice_index]
# adding intermediate layer outputs
hidden_state = hidden_state.reshape(
batch_size, num_concurrent_media, num_tiles, num_patches, dim
)
intermediate_hidden_states = intermediate_hidden_states.reshape(
batch_size * num_concurrent_media,
num_tiles,
num_patches + num_padding_patches,
-1,
)
intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index]
intermediate_hidden_states = intermediate_hidden_states.reshape(
batch_size, num_concurrent_media, num_tiles, num_patches, -1
)
hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1)
return hidden_state
class MllamaTextCrossAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, *, prefix, config, weights, layer_idx):
super().__init__()
self.config = config
self.num_heads = self.config.num_attention_heads
self.num_key_value_heads = self.config.num_key_value_heads
self.dropout = config.dropout
self.hidden_size = config.hidden_size
self.head_size = config.hidden_size // self.num_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.layer_idx = layer_idx
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
self.num_key_value_heads // weights.process_group.size()
)
self.q_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.q_proj",
weights=weights,
bias=False,
)
self.k_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.k_proj",
weights=weights,
bias=False,
)
self.v_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.v_proj",
weights=weights,
bias=False,
)
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
self.q_norm = MllamaTextRMSNorm.load(
prefix=f"{prefix}.q_norm", weights=weights, eps=config.rms_norm_eps
)
self.k_norm = MllamaTextRMSNorm.load(
prefix=f"{prefix}.k_norm", weights=weights, eps=config.rms_norm_eps
)
self.softmax_scale = self.head_size**-0.5
def forward(
self,
hidden_states: torch.Tensor,
cross_attention_states: Optional[torch.Tensor] = None,
# past_key_value=None,
# attention_mask: Optional[torch.Tensor] = None,
# cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# hidden_states = hidden_states.unsqueeze(0)
# bsz, q_len, _ = hidden_states.size()
(
cross_attention_states,
cu_seqlen_q,
cu_seqlen_k,
indices,
) = cross_attention_states
bs = cu_seqlen_q.size(0) - 1
query_states = self.q_proj(hidden_states)
query_states = query_states.view(bs, -1, self.num_heads, self.head_size)
query_states = self.q_norm(query_states)
key_states = self.k_proj(cross_attention_states)
value_states = self.v_proj(cross_attention_states)
key_states = key_states.view(bs, -1, self.num_key_value_heads, self.head_size)
value_states = value_states.view(
bs, -1, self.num_key_value_heads, self.head_size
)
key_states = self.k_norm(key_states)
# key_states = key_states.repeat(1, self.num_key_value_groups, 1)
# value_states = value_states.repeat(1, self.num_key_value_groups, 1)
causal = False
# logger.info(
# f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}"
# )
# execute sdpa
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
fsdpa_op = ModuleFusedSDPA(FusedSDPA)
attn_output = fsdpa_op(
query_states,
key_states,
value_states,
attn_mask=None,
dropout_p=0.0,
is_causal=causal,
scale=None,
softmax_mode="None",
recompute_mode=None,
valid_sequence_lengths=None,
)
attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
attn_output = self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
return attn_output
# Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText
class MllamaTextMLP(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
)
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
self.down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
shape = x.shape
gate_up_states = self.gate_up_proj(x)
gate_up_states = gate_up_states.view(*shape[:-1], 2, self.intermediate_size)
result = self.down_proj(
self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1]
)
return result
class FlashLlamaCrossLayer(torch.nn.Module):
"""Cross-attention transformer block with tanh-gated attention and feedforward."""
def __init__(self, *, prefix, config, weights, index) -> None:
layer_idx = index
super().__init__()
self.cross_attn = MllamaTextCrossAttention(
prefix=f"{prefix}.cross_attn",
config=config,
weights=weights,
layer_idx=layer_idx,
)
self.input_layernorm = MllamaTextRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.cross_attn_attn_gate = torch.nn.Parameter(
weights.get_tensor(f"{prefix}.cross_attn_attn_gate"), requires_grad=False
)
self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.post_attention_layernorm = MllamaTextRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
self.cross_attn_mlp_gate = torch.nn.Parameter(
weights.get_tensor(f"{prefix}.cross_attn_mlp_gate"), requires_grad=False
)
self.layer_idx = layer_idx
def forward(
self,
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
slots,
seqlen,
adapter_data,
cross_attention_states, # [ IB, ...]
hpu_attention_meta,
) -> Tuple[torch.Tensor, torch.Tensor]:
if cross_attention_states is None:
return hidden_states, residual
if residual is not None:
hidden_states += residual
indices = cross_attention_states[-1]
out_hidden_states = hidden_states[:]
if len(indices) > 0:
assert max(indices) < hidden_states.shape[0]
hidden_states = hidden_states[indices]
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.cross_attn(
hidden_states=hidden_states,
# attention_mask=cross_attention_mask,
cross_attention_states=cross_attention_states,
)
hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states
out_hidden_states[indices] = hidden_states
hidden_states = out_hidden_states
return hidden_states, None
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText
class MllamaTextRMSNorm(nn.Module):
def __init__(self, weight, eps):
super().__init__()
self.weight = weight
self.variance_epsilon = eps
@classmethod
def load(cls, *, prefix, weights, eps):
weight = nn.Parameter(
weights.get_tensor(f"{prefix}.weight"), requires_grad=False
)
return cls(weight=weight, eps=eps)
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class FlashMllamaForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
config.vision_config.quantize = None
config.vision_config.speculator = config.speculator
config.text_config.quantize = config.quantize
config.text_config.speculator = config.speculator
config.text_config._attn_implementation = "sdpa"
self.hidden_size = config.text_config.hidden_size
self.vision_model = MllamaVisionModel(
prefix="vision_model", config=config.vision_config, weights=weights
)
self.multi_modal_projector = FastLinear.load(
prefix="multi_modal_projector", config=config, weights=weights, bias=True
)
self.text_model = FlashLlamaForCausalLM(
prefix="language_model", config=config.text_config, weights=weights
)
self.config = config
self.dtype = weights.dtype
self.device = weights.device
def vision_forward(self, pixel_values, aspect_ratio_ids, aspect_ratio_mask):
if aspect_ratio_ids is None:
raise ValueError(
"`aspect_ratio_ids` must be provided if `pixel_values` is provided"
)
# logger.info(f"PIxel values {pixel_values.shape}")
batch_size = pixel_values.shape[0]
vision_states = self.vision_model(
pixel_values, aspect_ratio_ids, aspect_ratio_mask
)
cross_attention_states = self.multi_modal_projector(vision_states).reshape(
-1, vision_states.shape[-2], self.hidden_size
)
_, _, h = cross_attention_states.shape
cross_attention_states = cross_attention_states.view(batch_size, -1, h)
# logger.info(f"cross {cross_attention_states.shape}")
return cross_attention_states
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor],
adapter_data: Optional[torch.Tensor] = None,
# XXX: Putting these as optional so that the cuda warmup calls can go through.
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
):
if cross_attention_states is not None:
seqlen_q = len(image_indices)
n_images = cross_attention_states.shape[0]
seqlen_k = cross_attention_states.shape[1]
device = cross_attention_states.device
if cu_seqlen_prefill is not None:
offset = 0
cu_q = []
indices = []
for index in image_indices:
cu_q.append(offset)
length = seqlen.input_lengths[index].item()
assert index < seqlen.cu_seqlen_q.shape[0]
input_ids_offset = seqlen.cu_seqlen_q[index]
indices.extend(range(input_ids_offset, input_ids_offset + length))
offset += length
cu_q.append(offset)
cu_seqlen_q = torch.Tensor(cu_q).to(device=device, dtype=torch.int32)
assert max(indices) < input_ids.shape[0]
cu_seqlen_k = (
torch.arange(
n_images + 1,
device=device,
dtype=torch.int32,
)
* seqlen_k
)
else:
cu_seqlen_q = torch.arange(
seqlen_q + 1, device=device, dtype=torch.int32
)
seqlen_k = cross_attention_states.shape[1]
n_images = cross_attention_states.shape[0]
cu_seqlen_k = (
torch.arange(
n_images + 1,
device=device,
dtype=torch.int32,
)
* seqlen_k
)
indices = image_indices[:]
cross_attention_states = (
cross_attention_states,
cu_seqlen_q,
cu_seqlen_k,
indices,
)
outputs = self.text_model(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
slots=slots,
seqlen=seqlen,
hpu_attention_meta=hpu_attention_meta,
lm_head_indices=lm_head_indices,
adapter_data=adapter_data,
cross_attention_states=cross_attention_states,
)
return outputs

View File

@ -29,8 +29,8 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -39,7 +39,7 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
@ -132,7 +132,6 @@ class FlashNeoxAttention(torch.nn.Module):
head_size=self.head_size,
hidden_size=self.hidden_size,
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=True
)
@ -147,9 +146,10 @@ class FlashNeoxAttention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
@ -165,35 +165,30 @@ class FlashNeoxAttention(torch.nn.Module):
qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1)
qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1)
kv_cache.store(
key=qkv[:, 1],
value=qkv[:, 2],
slots=slots,
kv_scales=self.kv_scales,
)
reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots)
# Prefill
if cu_seqlen_prefill is not None:
# sdpa
# flash attention
attn_output = attention(
query=qkv[:, 0],
key=qkv[:, 1],
value=qkv[:, 2],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
qkv[:, 0],
kv_cache[0] if PREFILL_IN_KV_CACHE else qkv[:, 1],
kv_cache[1] if PREFILL_IN_KV_CACHE else qkv[:, 2],
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
qkv[:, 0],
kv_cache,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
max_s,
)
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@ -260,9 +255,10 @@ class FlashNeoXLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
if self.use_parallel_residual:
ln1_hidden_states, _ = self.input_layernorm(hidden_states)
@ -273,9 +269,10 @@ class FlashNeoXLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)
@ -296,9 +293,10 @@ class FlashNeoXLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
hidden_states, residual = self.post_attention_layernorm(
@ -349,15 +347,18 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_in(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(position_ids)
cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.layers):
@ -368,9 +369,10 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
hidden_states, _ = self.final_layer_norm(hidden_states, residual)
@ -399,9 +401,11 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
@ -410,9 +414,10 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -19,7 +19,7 @@ from torch import nn
from typing import Optional, List, Tuple
from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
from text_generation_server.layers.attention import Seqlen
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
@ -69,20 +69,22 @@ class PaliGemmaForConditionalGeneration(nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
# Unused here
pixel_attention_mask: Optional[torch.BoolTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.text_model.embed_tokens(input_ids)
# TODO This is odd but apparently pali gemma position ids start at 1.
if cu_seqlen_prefill is not None:
max_s += 1
position_ids += 1
if pixel_values is not None:
@ -104,10 +106,10 @@ class PaliGemmaForConditionalGeneration(nn.Module):
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
hpu_attention_meta=hpu_attention_meta,
adapter_data=adapter_data,
max_s=max_s,
)
if lm_head_indices is not None:

View File

@ -9,8 +9,8 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -19,7 +19,7 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
@ -90,7 +90,7 @@ def _load_gqa(config, prefix: str, weights):
dim=0,
)
if config.quantize not in ["gptq", "awq"]:
if config.quantize not in ["gptq", "awq", "marlin"]:
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
head_size = config.hidden_size // config.num_attention_heads
@ -139,7 +139,6 @@ class FlashPhiAttention(torch.nn.Module):
)
self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
# in llama the dense layer is called "o_proj" and has bias=False
self.dense = TensorParallelRowLinear.load(
@ -160,9 +159,10 @@ class FlashPhiAttention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
# Compute query, key, value and split
qkv = self.query_key_value(hidden_states)
@ -188,34 +188,29 @@ class FlashPhiAttention(torch.nn.Module):
)
# Reshape key and value and cache
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# Prefill
if cu_seqlen_prefill is not None:
attn_output = attention(
query=query,
key=kv[:, 0],
value=kv[:, 1],
kv_scales=self.kv_scales,
kv_cache=kv_cache,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
max_s,
)
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@ -279,9 +274,10 @@ class FlashPhiLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
hidden_states, res = self.input_layernorm(hidden_states, residual)
# Self Attention
@ -291,9 +287,10 @@ class FlashPhiLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
hidden_states = self.resid_dropout(attn_output).add(
@ -342,15 +339,18 @@ class FlashPhiModel(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.layers):
@ -361,9 +361,10 @@ class FlashPhiModel(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -393,9 +394,11 @@ class FlashPhiForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
@ -404,9 +407,10 @@ class FlashPhiForCausalLM(torch.nn.Module):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -8,8 +8,8 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -17,7 +17,7 @@ from text_generation_server.layers import (
TensorParallelEmbedding,
SpeculativeHead,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
@ -86,8 +86,6 @@ class Qwen2Attention(torch.nn.Module):
self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
@ -106,9 +104,11 @@ class Qwen2Attention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
prefill_cache_indices,
):
qkv = self.query_key_value(hidden_states)
query, kv = qkv.split(
@ -123,36 +123,38 @@ class Qwen2Attention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
reshape_and_cache(
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
)
# Prefill
if cu_seqlen_prefill is not None:
# sdpa
# flash attention
attn_output = attention(
query=query,
key=kv[:, 0],
value=kv[:, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
seqlen,
block_tables,
self.softmax_scale,
window_size_left=self.max_past,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
max_s,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -221,11 +223,13 @@ class Qwen2Layer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
prefill_cache_indices,
):
normed_hidden_states, residual = self.input_layernorm(hidden_states)
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
# Self Attention
attn_output = self.self_attn(
@ -234,17 +238,21 @@ class Qwen2Layer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
prefill_cache_indices,
)
hidden_states = attn_output + residual
# faster post attention rms norm
hidden_states, residual = self.post_attention_layernorm(hidden_states)
mlp_output = self.mlp(hidden_states)
hidden_states = mlp_output + residual
return hidden_states
normed_attn_res_output, attn_res = self.post_attention_layernorm(
attn_output, res
)
mlp_output = self.mlp(normed_attn_res_output)
return mlp_output, attn_res
class Qwen2Model(torch.nn.Module):
@ -256,6 +264,9 @@ class Qwen2Model(torch.nn.Module):
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens", weights=weights
)
self.layers = nn.ModuleList(
[
Qwen2Layer(
@ -279,35 +290,42 @@ class Qwen2Model(torch.nn.Module):
def forward(
self,
inputs_embeds: torch.Tensor,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor:
hidden_states = inputs_embeds
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids,
position_ids, true_max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.layers):
hidden_states = layer(
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
prefill_cache_indices,
)
hidden_states, _ = self.norm(hidden_states)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
@ -328,12 +346,6 @@ class Qwen2ForCausalLM(torch.nn.Module):
prefix=f"{prefix}.{suffix}" if prefix else suffix,
weights=weights,
)
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens" if prefix else "model.embed_tokens",
weights=weights,
)
self.max_past = config.sliding_window
self.max_past_tensor = (
torch.tensor(config.sliding_window, device=weights.device)
@ -347,23 +359,34 @@ class Qwen2ForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
inputs_embeds = self.embed_tokens(input_ids)
true_max_s = max_s
if prefill_cache_indices is not None:
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
seqlen = seqlen.clamp(max=self.max_past_tensor)
hidden_states = self.model(
inputs_embeds,
input_ids,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
true_max_s,
prefill_cache_indices,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -12,14 +12,14 @@ from text_generation_server.layers import (
TensorParallelRowLinear,
get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.layernorm import FastLayerNorm
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.attention import (
attention,
paged_attention,
reshape_and_cache,
Seqlen,
HPUPagedAttentionMetadata,
)
@ -79,7 +79,6 @@ class RWConfig(PretrainedConfig):
self.alibi = False
self.rotary = True
self.rope_theta = rope_theta
self.max_position_embeddings = 2048
self.vocab_size = vocab_size
# Backward compatibility with n_embed kwarg
@ -161,7 +160,6 @@ class FlashRWAttention(torch.nn.Module):
weights=weights,
bias=config.bias,
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
)
@ -182,9 +180,10 @@ class FlashRWAttention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
qkv = self.query_key_value(hidden_states)
@ -201,35 +200,30 @@ class FlashRWAttention(torch.nn.Module):
# Inplace rotary
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# Prefill
if cu_seqlen_prefill is not None:
# sdpa
# flash attention
attn_output = attention(
query=query,
key=kv[:, 0],
value=kv[:, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
max_s,
)
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@ -284,7 +278,6 @@ class FlashRWLargeAttention(torch.nn.Module):
weights=weights,
bias=config.bias,
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
)
@ -300,9 +293,10 @@ class FlashRWLargeAttention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
@ -318,35 +312,36 @@ class FlashRWLargeAttention(torch.nn.Module):
# Inplace rotary
self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)
kv_cache.store(
key=kv[:, :, 0].contiguous(),
value=kv[:, :, 1].contiguous(),
slots=slots,
kv_scales=self.kv_scales,
reshape_and_cache(
kv[:, :, 0].contiguous(),
kv[:, :, 1].contiguous(),
kv_cache[0],
kv_cache[1],
slots,
)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query=query,
key=kv[:, :, 0],
value=kv[:, :, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(),
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(),
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
max_s,
)
return self.dense(
@ -429,9 +424,10 @@ class FlashRWLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
if self.parallel_attn:
ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
@ -442,9 +438,10 @@ class FlashRWLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
mlp_output = self.mlp(ln_hidden_states)
@ -463,9 +460,10 @@ class FlashRWLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
if self.post_attention_layernorm is not None:
@ -549,9 +547,10 @@ class FlashRWLargeLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
# Layer norm.
ln_attn, ln_mlp, residual = self.ln_layer(hidden_states, residual)
@ -563,9 +562,10 @@ class FlashRWLargeLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
# MLP.
@ -623,15 +623,18 @@ class FlashRWModel(FlashRWPreTrainedModel):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(position_ids)
cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.h):
@ -642,9 +645,10 @@ class FlashRWModel(FlashRWPreTrainedModel):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
hidden_states, _ = self.ln_f(hidden_states, residual)
@ -671,9 +675,11 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
@ -682,9 +688,10 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -8,8 +8,8 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -18,7 +18,7 @@ from text_generation_server.layers import (
TensorParallelEmbedding,
get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.gptq import GPTQWeightsLoader
from text_generation_server.layers.layernorm import (
FastLayerNorm,
@ -32,6 +32,10 @@ def load_multi_mqa(
return _load_multi_mqa_gptq(
config, prefix, weights, bias, head_size, num_heads, hidden_size
)
elif config.quantize == "marlin":
raise RuntimeError(
"santacoder models with marlin quantization are not yet supported"
)
else:
return _load_multi_mqa(
config, prefix, weights, bias, head_size, num_heads, hidden_size
@ -255,7 +259,6 @@ class FlashMQAttention(torch.nn.Module):
self.c_proj = load_row(
config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.kv_head_mapping = torch.zeros(
self.num_heads, dtype=torch.int32, device=weights.device
)
@ -265,9 +268,10 @@ class FlashMQAttention(torch.nn.Module):
hidden_states,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
qkv = self.c_attn(hidden_states)
@ -280,35 +284,32 @@ class FlashMQAttention(torch.nn.Module):
query = query.view(-1, self.num_heads, self.head_size)
key_value = key_value.view(-1, 2, 1, self.head_size)
kv_cache.store(
key=key_value[:, 0],
value=key_value[:, 1],
slots=slots,
kv_scales=self.kv_scales,
reshape_and_cache(
key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
)
# Prefill
if cu_seqlen_prefill is not None:
# sdpa
# flash attention
attn_output = attention(
query=query,
key=key_value[:, 0],
value=key_value[:, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else key_value[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else key_value[:, 1],
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
max_s,
)
return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -370,18 +371,20 @@ class Block(nn.Module):
residual,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
hidden_states, residual = self.ln_1(hidden_states, residual)
hidden_states = self.self_attn(
hidden_states,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
hidden_states, residual = self.ln_2(hidden_states, residual)
@ -432,9 +435,10 @@ class FlashSantacoderModel(nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
) -> torch.Tensor:
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
@ -448,9 +452,10 @@ class FlashSantacoderModel(nn.Module):
residual,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
hidden_states, _ = self.ln_f(hidden_states, residual)
@ -479,9 +484,11 @@ class FlashSantacoderForCausalLM(nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
@ -490,9 +497,10 @@ class FlashSantacoderForCausalLM(nn.Module):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -29,19 +29,17 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.layernorm import (
FastLayerNorm,
FastRMSNorm,
@ -112,31 +110,17 @@ class Starcoder2Config(PretrainedConfig):
)
def load_attention(config, prefix, weights, layer_id):
prefixes = [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
head_size = config.hidden_size // config.num_attention_heads
sizes = [
head_size * config.num_attention_heads,
head_size * config.num_key_value_heads,
head_size * config.num_key_value_heads,
]
def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
base_layer = _load_gqa(config, prefix, weights)
return _load_gqa(config, prefix, weights)
else:
base_layer = TensorParallelColumnLinear.load_multi(
return TensorParallelColumnLinear.load_multi(
config,
prefixes=prefixes,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=config.use_bias,
)
return TensorParallelMultiAdapterLinear.load(
base_layer=base_layer,
layer_id=layer_id,
layer_names=prefixes,
sizes=sizes,
process_group=weights.process_group,
)
def _load_gqa(config, prefix: str, weights):
@ -174,7 +158,6 @@ def _load_gqa(config, prefix: str, weights):
class Starcoder2Attention(torch.nn.Module):
def __init__(
self,
index: int,
prefix: str,
config,
weights,
@ -206,23 +189,14 @@ class Starcoder2Attention(torch.nn.Module):
config.num_key_value_heads // weights.process_group.size()
)
self.query_key_value = load_attention(config, prefix, weights, index)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.query_key_value = load_attention(config, prefix, weights)
o_proj = TensorParallelRowLinear.load(
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=getattr(config, "use_bias", False),
bias=config.use_bias,
)
self.o_proj = TensorParallelAdapterRowLinear.load(
o_proj,
index,
"o_proj",
process_group=weights.process_group,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@ -235,12 +209,13 @@ class Starcoder2Attention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
adapter_data,
hpu_attention_meta,
max_s,
prefill_cache_indices,
):
qkv = self.query_key_value(hidden_states, adapter_data)
qkv = self.query_key_value(hidden_states)
query, kv = qkv.split(
[
self.head_size * self.num_heads,
@ -253,45 +228,45 @@ class Starcoder2Attention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
reshape_and_cache(
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
)
# Prefill
if cu_seqlen_prefill is not None:
# sdpa
# flash attention
attn_output = attention(
query=query,
key=kv[:, 0],
value=kv[:, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
seqlen,
block_tables,
self.softmax_scale,
window_size_left=self.max_past,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
max_s,
)
return self.o_proj(
attn_output.view(-1, self.num_heads * self.head_size), adapter_data
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
class Starcoder2MLP(nn.Module):
def __init__(self, prefix, config, weights, index):
def __init__(self, prefix, config, weights):
super().__init__()
act = config.hidden_act
self.act = (
@ -305,42 +280,27 @@ class Starcoder2MLP(nn.Module):
)
)
# Fuse gate and up proj
c_fc = TensorParallelColumnLinear.load(
self.c_fc = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.c_fc",
weights=weights,
bias=config.use_bias,
)
c_proj = TensorParallelRowLinear.load(
self.c_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.c_proj",
weights=weights,
bias=config.use_bias,
)
self.c_fc = TensorParallelMultiAdapterLinear.load(
c_fc,
layer_id=index,
layer_names=[f"{prefix}.c_fc"],
sizes=[config.intermediate_size, config.intermediate_size],
process_group=weights.process_group,
)
self.c_proj = TensorParallelAdapterRowLinear.load(
c_proj,
index,
"c_proj",
process_group=weights.process_group,
)
def forward(self, hidden_states, adapter_data):
hidden_states = self.c_fc(hidden_states, adapter_data)
def forward(self, hidden_states):
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
return self.c_proj(hidden_states, adapter_data)
return self.c_proj(hidden_states)
class Starcoder2GatedMLP(nn.Module):
def __init__(self, index, prefix, config, weights):
def __init__(self, prefix, config, weights):
super().__init__()
act = config.hidden_act
self.act = (
@ -354,47 +314,27 @@ class Starcoder2GatedMLP(nn.Module):
)
)
# Fuse gate and up proj
prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"]
sizes = [
config.intermediate_size,
config.intermediate_size,
]
gate_up_proj = TensorParallelColumnLinear.load_multi(
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=prefixes,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=config.use_bias,
)
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
gate_up_proj,
index,
layer_names=prefixes,
sizes=sizes,
process_group=weights.process_group,
)
down_proj = TensorParallelRowLinear.load(
self.down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=config.use_bias,
)
self.down_proj = TensorParallelAdapterRowLinear.load(
down_proj,
index,
"down_proj",
process_group=weights.process_group,
)
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
)
def forward(self, hidden_states, adapter_data):
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
def forward(self, hidden_states):
gate_up_states = self.gate_up_proj(hidden_states)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
)
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
STARCODER2_NORMALIZATION_CLASSES = {
@ -413,11 +353,11 @@ class Starcoder2Layer(nn.Module):
super().__init__()
prefix = f"model.layers.{layer_id}"
self.self_attn = Starcoder2Attention(
prefix=f"{prefix}.self_attn", config=config, weights=weights, index=layer_id
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](
prefix=f"{prefix}.mlp", config=config, weights=weights, index=layer_id
prefix=f"{prefix}.mlp", config=config, weights=weights
)
self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
@ -439,10 +379,11 @@ class Starcoder2Layer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
adapter_data,
hpu_attention_meta,
max_s,
prefill_cache_indices,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -453,10 +394,11 @@ class Starcoder2Layer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
adapter_data,
hpu_attention_meta,
max_s,
prefill_cache_indices,
)
# faster post attention rms norm
@ -464,7 +406,7 @@ class Starcoder2Layer(nn.Module):
attn_output, res
)
mlp_output = self.mlp(normed_attn_res_output, adapter_data)
mlp_output = self.mlp(normed_attn_res_output)
return mlp_output, attn_res
@ -505,16 +447,20 @@ class Starcoder2Model(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
adapter_data,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, true_max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.layers):
@ -525,10 +471,11 @@ class Starcoder2Model(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
adapter_data,
hpu_attention_meta,
max_s,
prefill_cache_indices,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -572,22 +519,34 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
true_max_s = max_s
if prefill_cache_indices is not None:
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
seqlen = seqlen.clamp(max=self.max_past_tensor)
hidden_states = self.model(
input_ids,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
adapter_data,
hpu_attention_meta,
max_s,
true_max_s,
prefill_cache_indices,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

Some files were not shown because too many files have changed in this diff Show More