Compare commits

..

1 Commits
main ... v3.1.1

Author SHA1 Message Date
Nicolas Patry
c34bd9d8d9
3.1.1 Release. 2025-03-04 18:11:30 +01:00
343 changed files with 38164 additions and 35529 deletions

View File

@ -45,7 +45,7 @@ jobs:
uses: actions/github-script@v7
with:
script: |
core.exportVariable('ACTIONS_RESULTS_URL', process.env.ACTIONS_RESULTS_URL || '');
core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || '');
core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || '');
- name: Extract TensorRT-LLM version
@ -124,15 +124,6 @@ jobs:
export extra_pytest="--neuron"
export target=""
;;
gaudi)
export dockerfile="Dockerfile_gaudi"
export label_extension="-gaudi"
export docker_volume="/mnt/cache"
export docker_devices=""
export runs_on="itac-bm-emr-gaudi3-dell-2gaudi"
export platform=""
export extra_pytest="--gaudi"
export target=""
esac
echo $dockerfile
echo "Dockerfile=${dockerfile}"
@ -223,7 +214,7 @@ jobs:
PLATFORM=${{ env.PLATFORM }}
build_type=${{ env.BUILD_TYPE }}
sccache_gha_enabled=on
actions_results_url=${{ env.ACTIONS_RESULTS_URL }}
actions_cache_url=${{ env.ACTIONS_CACHE_URL }}
actions_runtime_token=${{ env.ACTIONS_RUNTIME_TOKEN }}
target: ${{ env.TARGET }}
tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
@ -233,12 +224,7 @@ jobs:
- name: Final
id: final
run: |
if [ "${{ github.event_name }}" = "pull_request" ]; then
echo "docker_image=docker.io/huggingface/text-generation-inference-ci:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL_EXTENSION }}" >> "$GITHUB_OUTPUT"
else
echo "docker_image=ghcr.io/huggingface/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL_EXTENSION }}" >> "$GITHUB_OUTPUT"
fi
echo "docker_image=docker.io/huggingface/text-generation-inference-ci:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL_EXTENSION }}" >> "$GITHUB_OUTPUT"
echo "docker_devices=${{ env.DOCKER_DEVICES }}" >> "$GITHUB_OUTPUT"
echo "docker_volume=${{ env.DOCKER_VOLUME }}" >> "$GITHUB_OUTPUT"
echo "runs_on=${{ env.RUNS_ON }}" >> "$GITHUB_OUTPUT"

View File

@ -20,8 +20,6 @@ on:
- "Dockerfile"
- "Dockerfile_amd"
- "Dockerfile_intel"
- "Dockerfile.neuron"
- "Dockerfile_gaudi"
branches:
- "main"
workflow_dispatch:
@ -39,7 +37,7 @@ jobs:
# fail-fast is true by default
fail-fast: false
matrix:
hardware: ["cuda", "cuda-trtllm", "rocm", "intel-xpu", "intel-cpu", "neuron", "gaudi"]
hardware: ["cuda", "cuda-trtllm", "rocm", "intel-xpu", "intel-cpu", "neuron"]
uses: ./.github/workflows/build.yaml # calls the one above ^
permissions:
contents: write

View File

@ -1,53 +0,0 @@
name: "Nix Build Docker image"
on:
pull_request:
push:
branches:
- 'main'
tags:
- 'v*'
concurrency:
group: nix-image-${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
build_nix_image:
runs-on:
group: aws-highmemory-32-plus-priv
steps:
- uses: actions/checkout@v4
- uses: cachix/install-nix-action@v27
with:
nix_path: nixpkgs=channel:nixos-unstable
- uses: cachix/cachix-action@v14
with:
name: huggingface
# If you chose signing key for write access
authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}'
env:
USER: github_runner
- name: Build
run: nix build .#dockerImage
- name: Initialize Docker Buildx
uses: docker/setup-buildx-action@v3
with:
install: true
buildkitd-config: /tmp/buildkitd.toml
- name: Inject slug/short variables
uses: rlespinasse/github-slug-action@v4.4.1
- name: Login to internal Container Registry
# if: github.event_name != 'pull_request'
uses: docker/login-action@v3
with:
username: ${{ secrets.REGISTRY_USERNAME }}
password: ${{ secrets.REGISTRY_PASSWORD }}
registry: registry.internal.huggingface.tech
- name: Push to docker
run: |
if [ "${{ github.event_name }}" = "pull_request" ]; then
export TAG=nix-sha-${{ env.GITHUB_SHA_SHORT }}
else
export TAG=${{ github.ref_name }}-nix
fi
export IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:$TAG
nix-shell -p skopeo --command "skopeo --insecure-policy copy docker-archive:$(readlink -f ./result) docker://$IMAGE --dest-compress-format zstd"

View File

@ -20,7 +20,7 @@ jobs:
nix_path: nixpkgs=channel:nixos-unstable
- uses: cachix/cachix-action@v14
with:
name: huggingface
name: text-generation-inference
# If you chose signing key for write access
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
env:

View File

@ -7,7 +7,6 @@ on:
- "proto/**"
- "router/**"
- "launcher/**"
- "backends/**"
- "Cargo.lock"
- "rust-toolchain.toml"
concurrency:
@ -25,7 +24,7 @@ jobs:
nix_path: nixpkgs=channel:nixos-unstable
- uses: cachix/cachix-action@v14
with:
name: huggingface
name: text-generation-inference
# If you chose signing key for write access
authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}'
env:

View File

@ -8,7 +8,6 @@ on:
- "proto/**"
- "router/**"
- "launcher/**"
- "backends/**"
- "Cargo.lock"
- "rust-toolchain.toml"
@ -47,7 +46,7 @@ jobs:
- name: Download locked kernels
run: |
source ./.venv/bin/activate
kernels download server
hf-kernels download server
- name: Run server tests
run: |
source ./.venv/bin/activate

1
.gitignore vendored
View File

@ -28,4 +28,3 @@ server/fbgemmm
hl-smi_log*.txt
.graph_dumps
out
hqt_output

774
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -21,7 +21,7 @@ default-members = [
resolver = "2"
[workspace.package]
version = "3.3.4-dev0"
version = "3.1.1"
edition = "2021"
authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference"
@ -29,7 +29,7 @@ homepage = "https://github.com/huggingface/text-generation-inference"
[workspace.dependencies]
base64 = "0.22.0"
tokenizers = { version = "0.20.0", features = ["http"] }
hf-hub = { version = "0.4.2", features = ["tokio"] }
hf-hub = { version = "0.4.1", features = ["tokio"] }
metrics = { version = "0.23.0" }
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
minijinja = { version = "2.2.0", features = ["json"] }

View File

@ -1,5 +1,5 @@
# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
FROM lukemathwalker/cargo-chef:latest-rust-1.85.0 AS chef
WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@ -48,7 +48,7 @@ FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS pytorch-install
WORKDIR /usr/src/
# NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099
ARG PYTORCH_VERSION=2.7
ARG PYTORCH_VERSION=2.6
ARG PYTHON_VERSION=3.11
# Keep in sync with `server/pyproject.toml
@ -65,7 +65,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
COPY --from=ghcr.io/astral-sh/uv:0.5.31 /uv /uvx /bin/
ENV PATH="$PATH:/root/.local/bin"
RUN uv python install ${PYTHON_VERSION}
RUN uv venv --python ${PYTHON_VERSION} && uv pip install torch==${PYTORCH_VERSION} torchvision pip setuptools packaging
RUN uv venv --python ${PYTHON_VERSION} && uv pip install torch==${PYTORCH_VERSION} pip setuptools packaging
ENV VIRTUAL_ENV=/usr/src/.venv/
ENV PATH="$PATH:/usr/src/.venv/bin/"
@ -121,6 +121,13 @@ COPY server/Makefile-awq Makefile
# Build specific version of transformers
RUN . .venv/bin/activate && make build-awq
# Build Lorax Punica kernels
FROM kernel-builder AS lorax-punica-builder
WORKDIR /usr/src
COPY server/Makefile-lorax-punica Makefile
# Build specific version of transformers
RUN . .venv/bin/activate && TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica
# Build Transformers CUDA kernels
FROM kernel-builder AS custom-kernels-builder
WORKDIR /usr/src
@ -158,9 +165,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
git \
&& rm -rf /var/lib/apt/lists/*
# RUN curl -LsSf https://astral.sh/uv/install.sh | sh
# ENV PATH="$PATH:/root/.local/bin"
COPY --from=ghcr.io/astral-sh/uv:0.5.31 /uv /uvx /bin/
RUN curl -LsSf https://astral.sh/uv/install.sh | sh
ENV PATH="$PATH:/root/.local/bin"
# Install flash-attention dependencies
# RUN pip install einops --no-cache-dir
@ -177,12 +183,12 @@ COPY server server
COPY server/Makefile server/Makefile
ENV HF_KERNELS_CACHE=/kernels
RUN cd server && \
uv sync --frozen --extra gen --extra bnb --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines --extra torch --no-install-project --active && \
uv sync --frozen --extra gen --extra attention --extra bnb --extra accelerate --extra compressed-tensors --extra marlin --extra moe --extra quantize --extra peft --extra outlines --no-install-project --active && \
make gen-server-raw && \
kernels download .
hf-kernels download .
RUN cd server && \
uv sync --frozen --extra gen --extra bnb --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines --extra torch --active --python=${PYTHON_VERSION} && \
uv sync --frozen --extra gen --extra attention --extra bnb --extra accelerate --extra compressed-tensors --extra marlin --extra moe --extra quantize --extra peft --extra outlines --active --python=${PYTHON_VERSION} && \
uv pip install nvidia-nccl-cu12==2.25.1 && \
pwd && \
text-generation-server --help
@ -203,6 +209,8 @@ COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311
COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
# Copy build artifacts from awq kernels builder
COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
# Copy build artifacts from lorax punica kernels builder
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
# Copy build artifacts from mamba builder
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages

View File

@ -5,7 +5,7 @@ RUN mkdir -p /tgi
# Fetch the optimum-neuron sources directly to avoid relying on pypi deployments
FROM alpine AS optimum-neuron
RUN mkdir -p /optimum-neuron
ADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.2.0.tar.gz /optimum-neuron/sources.tar.gz
ADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.0.28.tar.gz /optimum-neuron/sources.tar.gz
RUN tar -C /optimum-neuron -xf /optimum-neuron/sources.tar.gz --strip-components=1
# Build cargo components (adapted from TGI original Dockerfile)
@ -18,7 +18,8 @@ RUN apt-get update -y \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain 1.85.1 --profile minimal -y
COPY rust-toolchain.toml rust-toolchain.toml
RUN curl -sSf https://sh.rustup.rs | sh -s -- -y --no-modify-path --default-toolchain none
ENV PATH="/root/.cargo/bin:${PATH}"
RUN cargo install cargo-chef --locked
@ -108,10 +109,10 @@ RUN wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEU
# Install neuronx packages
RUN apt-get update -y \
&& apt-get install -y --no-install-recommends \
aws-neuronx-dkms=2.20.28.0 \
aws-neuronx-collectives=2.24.59.0-838c7fc8b \
aws-neuronx-runtime-lib=2.24.53.0-f239092cc \
aws-neuronx-tools=2.22.61.0 \
aws-neuronx-dkms=2.18.20.0 \
aws-neuronx-collectives=2.22.33.0-d2128d1aa \
aws-neuronx-runtime-lib=2.22.19.0-5856c0b42 \
aws-neuronx-tools=2.19.0.0 \
libxml2 \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
@ -120,15 +121,16 @@ ENV PATH="/opt/bin/:/opt/aws/neuron/bin:${PATH}"
# Install manually torch CPU version to avoid pulling CUDA
RUN pip3 install \
torch==2.5.1 \
torchvision==0.20.1 \
torch==2.1.2 \
torchvision==0.16.2 \
--index-url https://download.pytorch.org/whl/cpu
RUN pip3 install \
neuronx-cc==2.17.194.0 \
torch-neuronx==2.5.1.2.6.0 \
neuronx-distributed==0.11.0 \
libneuronxla==2.2.1630.0 \
neuronx-cc==2.15.143.0 \
torch-neuronx==2.1.2.2.3.2 \
transformers-neuronx==0.12.313 \
neuronx-distributed==0.9.0 \
libneuronxla==2.0.5347.0 \
--extra-index-url=https://pip.repos.neuron.amazonaws.com
# Install HuggingFace packages
@ -159,7 +161,7 @@ RUN pip install dist/text_generation_server*.tar.gz
# Final image
FROM neuron
COPY backends/neuron/tgi_entry_point.py /tgi_entry_point.py
COPY backends/neuron/tgi_env.py /tgi_env.py
COPY backends/neuron/tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh

View File

@ -6,7 +6,7 @@
FROM nixos/nix:2.18.8 AS builder
RUN echo "experimental-features = nix-command flakes" >> /etc/nix/nix.conf
RUN nix profile install nixpkgs#cachix
RUN cachix use huggingface
RUN cachix use text-generation-inference
WORKDIR /root
ADD . .
RUN nix build .

View File

@ -1,5 +1,5 @@
# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
FROM lukemathwalker/cargo-chef:latest-rust-1.85.0 AS chef
WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@ -41,244 +41,303 @@ COPY backends backends
COPY launcher launcher
RUN cargo build --profile release-opt --frozen
FROM rocm/dev-ubuntu-22.04:6.3.1-complete AS base
# Text Generation Inference base image for RoCm
FROM rocm/dev-ubuntu-22.04:6.2 AS base
ARG HIPBLASLT_BRANCH="4d40e36"
ARG HIPBLAS_COMMON_BRANCH="7c1566b"
ARG LEGACY_HIPBLASLT_OPTION=
ARG RCCL_BRANCH="648a58d"
ARG RCCL_REPO="https://github.com/ROCm/rccl"
ARG TRITON_BRANCH="e5be006"
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
ARG PYTORCH_BRANCH="3a585126"
ARG PYTORCH_VISION_BRANCH="v0.19.1"
ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG FA_BRANCH="b7d29fb"
ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
ARG AITER_BRANCH="21d47a9"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
ENV PATH=/opt/rocm/llvm/bin:$PATH
ENV ROCM_PATH=/opt/rocm
ENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib:
ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942
ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
ARG PYTHON_VERSION=3.11
RUN mkdir -p /app
WORKDIR /app
ENV DEBIAN_FRONTEND=noninteractive
# Install Python and other dependencies
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \
ca-certificates \
ccache \
curl \
git \
ninja-build \
cmake \
software-properties-common \
python3.11-dev \
python3.11-venv && \
rm -rf /var/lib/apt/lists/*
build-essential \
ca-certificates \
ccache \
curl \
git \
make \
libmsgpack-dev \
libssl-dev \
llvm-dev \
g++ \
# Needed to build VLLM & flash.
rocthrust-dev \
hipsparse-dev \
hipblas-dev \
hipcub-dev \
rocblas-dev \
hiprand-dev \
hipfft-dev \
rocrand-dev \
miopen-hip-dev \
hipsolver-dev \
rccl-dev \
cmake \
python3.11-venv && \
rm -rf /var/lib/apt/lists/*
COPY --from=ghcr.io/astral-sh/uv:0.5.31 /uv /uvx /bin/
ENV PATH="$PATH:/root/.local/bin"
RUN uv python install ${PYTHON_VERSION}
RUN uv venv --python ${PYTHON_VERSION} && uv pip install pip setuptools packaging
ENV VIRTUAL_ENV=/usr/src/.venv/
ENV PATH="$PATH:/usr/src/.venv/bin/"
# Keep in sync with `server/pyproject.toml
ARG MAMBA_VERSION=23.1.0-1
ARG PYTHON_VERSION='3.11.10'
# Automatically set by buildx
ARG TARGETPLATFORM
ENV PATH=/opt/conda/bin:$PATH
RUN . .venv/bin/activate && pip install -U packaging cmake ninja wheel setuptools pybind11 Cython
ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
# Install mamba
# translating Docker's TARGETPLATFORM into mamba arches
RUN case ${TARGETPLATFORM} in \
"linux/arm64") MAMBA_ARCH=aarch64 ;; \
*) MAMBA_ARCH=x86_64 ;; \
esac && \
curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
RUN chmod +x ~/mambaforge.sh && \
bash ~/mambaforge.sh -b -p /opt/conda && \
mamba init && \
rm ~/mambaforge.sh
# RUN conda install intel::mkl-static intel::mkl-include
# Install pytorch
# On arm64 we exit with an error code
RUN case ${TARGETPLATFORM} in \
"linux/arm64") exit 1 ;; \
*) /opt/conda/bin/conda update -y conda && \
/opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
esac && \
/opt/conda/bin/conda clean -ya
# Install flash-attention, torch dependencies
RUN python3 -m pip install --upgrade pip uv && pip install numpy einops ninja joblib msgpack cmake --no-cache-dir && rm -rf /var/lib/apt/lists/*
RUN conda install mkl=2021
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/opt/conda/lib/python3.11/site-packages/torch/lib:/opt/conda/lib/
ARG COMMON_WORKDIR=/
WORKDIR ${COMMON_WORKDIR}
# Install HIPBLASLt
FROM base AS build_hipblaslt
ARG HIPBLASLT_BRANCH
ARG HIPBLAS_COMMON_BRANCH
# Set to "--legacy_hipblas_direct" for ROCm<=6.2
ARG LEGACY_HIPBLASLT_OPTION
RUN git clone https://github.com/ROCm/hipBLAS-common.git
RUN . .venv/bin/activate && cd hipBLAS-common \
&& git checkout ${HIPBLAS_COMMON_BRANCH} \
&& mkdir build \
&& cd build \
&& cmake .. \
&& make package \
&& dpkg -i ./*.deb
RUN git clone https://github.com/ROCm/hipBLASLt
RUN . .venv/bin/activate && cd hipBLASLt \
ARG HIPBLASLT_BRANCH="e6da924"
RUN git clone https://github.com/ROCm/hipBLASLt.git \
&& cd hipBLASLt \
&& git checkout ${HIPBLASLT_BRANCH} \
&& ./install.sh -d --architecture ${PYTORCH_ROCM_ARCH} ${LEGACY_HIPBLASLT_OPTION} \
&& SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} --legacy_hipblas_direct \
&& cd build/release \
&& make package
RUN mkdir -p /app/install && cp /app/hipBLASLt/build/release/*.deb /app/hipBLAS-common/build/*.deb /app/install
FROM scratch AS export_hipblaslt
ARG COMMON_WORKDIR
COPY --from=build_hipblaslt ${COMMON_WORKDIR}/hipBLASLt/build/release/*.deb /
# RCCL build stages
FROM base AS build_rccl
ARG RCCL_BRANCH
ARG RCCL_REPO
RUN git clone ${RCCL_REPO}
RUN . .venv/bin/activate && cd rccl \
ARG RCCL_BRANCH="rocm-6.2.0"
RUN git clone https://github.com/ROCm/rccl \
&& cd rccl \
&& git checkout ${RCCL_BRANCH} \
&& ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH}
RUN mkdir -p /app/install && cp /app/rccl/build/release/*.deb /app/install
FROM scratch AS export_rccl
ARG COMMON_WORKDIR
COPY --from=build_rccl ${COMMON_WORKDIR}/rccl/build/release/*.deb /
# Triton build stages
FROM base AS build_triton
ARG TRITON_BRANCH
ARG TRITON_REPO
RUN git clone ${TRITON_REPO}
RUN . .venv/bin/activate && cd triton \
ARG TRITON_BRANCH="e192dba"
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
RUN python3 -m pip install ninja cmake wheel pybind11 && git clone ${TRITON_REPO} \
&& cd triton \
&& git checkout ${TRITON_BRANCH} \
&& cd python \
&& python3 setup.py bdist_wheel --dist-dir=dist
RUN mkdir -p /app/install && cp /app/triton/python/dist/*.whl /app/install
FROM scratch AS export_triton
ARG COMMON_WORKDIR
COPY --from=build_triton ${COMMON_WORKDIR}/triton/python/dist/*.whl /
# # AMD-SMI build stages
FROM base AS build_amdsmi
RUN . .venv/bin/activate && cd /opt/rocm/share/amd_smi \
RUN cd /opt/rocm/share/amd_smi \
&& pip wheel . --wheel-dir=dist
RUN mkdir -p /app/install && cp /opt/rocm/share/amd_smi/dist/*.whl /app/install
FROM scratch AS export_amdsmi
COPY --from=build_amdsmi /opt/rocm/share/amd_smi/dist/*.whl /
FROM base AS build_pytorch
ARG PYTORCH_BRANCH
ARG PYTORCH_VISION_BRANCH
ARG PYTORCH_REPO
ARG PYTORCH_VISION_REPO
ARG FA_BRANCH
ARG FA_REPO
RUN git clone ${PYTORCH_REPO} pytorch
RUN . .venv/bin/activate && cd pytorch && git checkout ${PYTORCH_BRANCH} && \
pip install -r requirements.txt && git submodule update --init --recursive \
&& python3 tools/amd_build/build_amd.py \
&& CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \
&& pip install dist/*.whl
RUN git clone ${PYTORCH_VISION_REPO} vision
RUN . .venv/bin/activate && cd vision && git checkout ${PYTORCH_VISION_BRANCH} \
&& python3 setup.py bdist_wheel --dist-dir=dist \
&& pip install dist/*.whl
RUN git clone ${FA_REPO}
RUN . .venv/bin/activate && cd flash-attention \
&& git checkout ${FA_BRANCH} \
&& git submodule update --init \
&& MAX_JOBS=64 GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist
RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \
&& cp /app/vision/dist/*.whl /app/install \
&& cp /app/flash-attention/dist/*.whl /app/install
FROM base AS final
RUN --mount=type=bind,from=build_hipblaslt,src=/app/install/,target=/install \
dpkg -i /install/*deb \
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status
RUN --mount=type=bind,from=build_rccl,src=/app/install/,target=/install \
dpkg -i /install/*deb \
&& sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
&& sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status
RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
. .venv/bin/activate && \
pip install /install/*.whl
RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
. .venv/bin/activate && \
pip install /install/*.whl
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
. .venv/bin/activate && \
pip install /install/*.whl
FROM base as build_pytorch
ARG AITER_REPO
ARG AITER_BRANCH
RUN git clone --recursive ${AITER_REPO}
RUN . .venv/bin/activate && cd aiter \
&& git checkout ${AITER_BRANCH} \
&& git submodule update --init --recursive \
&& pip install -r requirements.txt \
&& PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop && pip show aiter
RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
if ls /install/*.deb; then \
dpkg -i /install/*.deb \
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
fi
RUN rm -rf /var/lib/apt/lists/*
ARG BUILD_ENVIRONMENT=pytorch-linux-jammy-rocm6.2-py3.11
ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
# A commit to fix the output scaling factor issue in _scaled_mm
# Not yet in 2.5.0-rc1
ARG PYTORCH_BRANCH="cedc116"
ARG PYTORCH_VISION_BRANCH="v0.19.1"
ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
RUN git clone ${PYTORCH_REPO} pytorch \
&& cd pytorch && git checkout ${PYTORCH_BRANCH} && git submodule update --init --recursive \
&& pip install -r requirements.txt --no-cache-dir \
&& python tools/amd_build/build_amd.py \
&& CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist
FROM scratch as export_pytorch
ARG COMMON_WORKDIR
COPY --from=build_pytorch ${COMMON_WORKDIR}/pytorch/dist/*.whl /
FROM base AS install_deps
ARG COMMON_WORKDIR
# Install hipblaslt
RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
if ls /install/*.deb; then \
dpkg -i /install/*.deb \
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
fi
RUN --mount=type=bind,from=export_rccl,src=/,target=/install \
if ls /install/*.deb; then \
dpkg -i /install/*.deb \
# RCCL needs to be installed twice
&& dpkg -i /install/*.deb \
&& sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
&& sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status; \
fi
RUN --mount=type=bind,from=export_triton,src=/,target=/install \
if ls /install/*.whl; then \
# Preemptively uninstall to prevent pip same-version no-installs
pip uninstall -y triton \
&& pip install /install/*.whl; \
fi
RUN --mount=type=bind,from=export_amdsmi,src=/,target=/install \
# Preemptively uninstall to prevent pip same-version no-installs
pip uninstall -y amdsmi \
&& pip install /install/*.whl;
RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \
if ls /install/*.whl; then \
# Preemptively uninstall to prevent pip same-version no-installs
pip uninstall -y torch torchvision \
&& pip install /install/*.whl; \
fi
FROM install_deps AS kernel-builder
FROM final AS kernel-builder
# # Build vllm kernels
FROM kernel-builder AS vllm-builder
WORKDIR /usr/src
COPY server/Makefile-vllm Makefile
RUN . .venv/bin/activate && pip install setuptools_scm
RUN pip install setuptools_scm
# Build specific version of vllm
RUN . .venv/bin/activate && make build-vllm-rocm
RUN make build-vllm-rocm
# Build Flash Attention v2 kernels
FROM kernel-builder AS flash-att-v2-builder
WORKDIR /usr/src
COPY server/Makefile-flash-att-v2 Makefile
# Build specific version of flash attention v2
RUN make build-flash-attention-v2-rocm
# Build Transformers CUDA kernels (gpt-neox and bloom)
FROM kernel-builder AS custom-kernels-builder
WORKDIR /usr/src
COPY server/custom_kernels/ .
RUN . .venv/bin/activate && python3 setup.py bdist_wheel --dist-dir=dist
RUN python setup.py build
# Build exllama kernels
FROM kernel-builder AS exllama-kernels-builder
WORKDIR /usr/src
COPY server/exllama_kernels/ .
RUN . .venv/bin/activate && python3 setup.py bdist_wheel --dist-dir=dist
RUN python setup.py build
# Build exllama v2 kernels
FROM kernel-builder AS exllamav2-kernels-builder
WORKDIR /usr/src
COPY server/exllamav2_kernels/ .
RUN . .venv/bin/activate && python3 setup.py bdist_wheel --dist-dir=dist
RUN python setup.py build
FROM kernel-builder AS marlin-kernels
WORKDIR /usr/src
ENV MARLIN_KERNELS_BRANCH=v0.3.6
ENV VLLM_TARGET_DEVICE=rocm
RUN . .venv/bin/activate && git clone https://github.com/danieldk/marlin-kernels.git && \
RUN git clone https://github.com/danieldk/marlin-kernels.git && \
cd marlin-kernels && \
git checkout ${MARLIN_KERNELS_BRANCH} && \
python3 setup.py bdist_wheel --dist-dir=dist
python setup.py install
FROM kernel-builder AS moe-kernels
WORKDIR /usr/src
ENV MOE_KERNELS_BRANCH=v0.8.2
ENV VLLM_TARGET_DEVICE=rocm
RUN . .venv/bin/activate && git clone https://github.com/danieldk/moe-kernels.git && \
RUN git clone https://github.com/danieldk/moe-kernels.git && \
cd moe-kernels && \
git checkout ${MOE_KERNELS_BRANCH} && \
python3 setup.py bdist_wheel --dist-dir=dist
python setup.py install
FROM final AS base-copy
FROM install_deps AS base-copy
# Text Generation Inference base env
ENV HF_HOME=/data \
HF_HUB_ENABLE_HF_TRANSFER=1 \
PORT=80
ENV VIRTUAL_ENV=/app/.venv/
ENV PATH="$PATH:/app/.venv/bin/"
# Copy builds artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from flash attention v2 builder
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from exllama kernels builder
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from exllamav2 kernels builder
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from marlin kernels
COPY --from=marlin-kernels /usr/src/marlin-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from moe kernels
COPY --from=moe-kernels /usr/src/moe-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Install server
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
ENV UV_SYSTEM_PYTHON=1
RUN cd server && \
uv pip install grpcio-tools mypy-protobuf && \
uv pip install -e ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir && \
pip install -U pip uv && \
uv sync --frozen --extra gen --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines --no-install-project && \
. ./.venv/bin/activate && \
make gen-server-raw
RUN cd server && \
uv sync --frozen --extra gen --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines && \
. ./.venv/bin/activate && \
pwd && \
text-generation-server --help
RUN --mount=type=bind,from=vllm-builder,src=/app/vllm/dist,target=/install \
uv pip install /install/*.whl
RUN --mount=type=bind,from=custom-kernels-builder,src=/app/dist,target=/install \
uv pip install /install/*.whl
RUN --mount=type=bind,from=custom-kernels-builder,src=/app/dist,target=/install \
uv pip install /install/*.whl
RUN --mount=type=bind,from=exllama-kernels-builder,src=/app/dist,target=/install \
uv pip install /install/*.whl
RUN --mount=type=bind,from=exllamav2-kernels-builder,src=/app/dist,target=/install \
uv pip install /install/*.whl
RUN --mount=type=bind,from=marlin-kernels,src=/app/marlin-kernels/dist,target=/install \
uv pip install /install/*.whl
RUN --mount=type=bind,from=moe-kernels,src=/app/moe-kernels/dist,target=/install \
uv pip install /install/*.whl
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
# AWS Sagemaker compatible image
FROM base AS sagemaker
@ -309,6 +368,4 @@ COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh
ENTRYPOINT ["/tgi-entrypoint.sh"]
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/root/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib"
ENV PYTHONPATH=/app/.venv/lib/python3.11/site-packages
# CMD ["--json-output"]
CMD ["--json-output"]

View File

@ -1,9 +1,9 @@
# Those arguments are required to build the image
ARG HABANA_VERSION=1.21.0
ARG PYTORCH_VERSION=2.6.0
ARG HABANA_VERSION
ARG PYTORCH_VERSION
# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
FROM lukemathwalker/cargo-chef:latest-rust-1.85.0 AS chef
WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@ -57,12 +57,9 @@ ARG PYTORCH_VERSION
FROM vault.habana.ai/gaudi-docker/${HABANA_VERSION}/ubuntu22.04/habanalabs/pytorch-installer-${PYTORCH_VERSION}:latest AS base
ENV ATTENTION=paged
ENV ATTENTION=default
ENV PREFIX_CACHING=0
ENV PREFILL_CHUNKING=0
ENV PT_HPU_LAZY_MODE=1
ENV PT_HPU_WEIGHT_SHARING=0
ENV VLLM_EXPONENTIAL_BUCKETING=true
# Text Generation Inference base env
ENV HF_HOME=/data \
@ -95,9 +92,10 @@ RUN cd server && \
make gen-server && \
pip install --no-deps -r requirements.txt && \
bash ./dill-0.3.8-patch.sh && \
pip install outlines~=0.0.34 && \
pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
pip install . --no-cache-dir
RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git@bmax_fix
RUN pip install compressed-tensors==0.9.1
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark

View File

@ -1,6 +1,6 @@
ARG PLATFORM=xpu
FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
FROM lukemathwalker/cargo-chef:latest-rust-1.85.0 AS chef
WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@ -45,7 +45,7 @@ RUN cargo build --profile release-opt --frozen
# Text Generation Inference base image for Intel
FROM intel/oneapi-basekit:2025.0.1-0-devel-ubuntu22.04 AS xpu
FROM intel/oneapi-basekit:2024.2.1-0-devel-ubuntu22.04 AS xpu
USER root
@ -87,7 +87,7 @@ RUN echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https:/
RUN mv /tmp/intel-for-pytorch-gpu-dev.list /etc/apt/sources.list.d
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-ocloc libnl-genl-3-200
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-pti-dev-0.9
# Text Generation Inference base env
ENV HF_HOME=/data \
@ -96,10 +96,13 @@ ENV HF_HOME=/data \
WORKDIR /usr/src
RUN pip install https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/torch-2.5.0a0%2Bgite84e33f-cp311-cp311-linux_x86_64.whl --no-cache-dir
RUN pip install https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/torchaudio-2.5.0a0%2B56bc006-cp311-cp311-linux_x86_64.whl --no-cache-dir
RUN pip install https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/torchvision-0.20.0a0%2B8e8a208-cp311-cp311-linux_x86_64.whl --no-cache-dir
RUN pip install https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/oneccl_bind_pt-2.5.0%2Bxpu-cp311-cp311-linux_x86_64.whl --no-cache-dir
RUN pip install torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/xpu
RUN pip install triton-xpu==3.0.0b2 --no-cache-dir
# Install server
COPY proto proto
@ -111,14 +114,15 @@ RUN cd server && \
pip install -U pip uv && \
uv pip install -e ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/conda/lib
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/pti/0.9/lib:/opt/conda/lib
ENV CCL_ZE_IPC_EXCHANGE=sockets
ENV TORCH_LLM_ALLREDUCE=1
ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
#ENV TORCH_LLM_ALLREDUCE=1
#ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
ENV TORCH_DEVICE_BACKEND_AUTOLOAD=0
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/oneccl_bind_pt-2.7.0%2Bxpu-cp311-cp311-linux_x86_64.whl
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/intel_extension_for_pytorch-2.7.10%2Bxpu-cp311-cp311-linux_x86_64.whl
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout 1ccf72b2d11cd00b47aef6d6cd054c088aa6f083
RUN cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc,ats-m150' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
@ -181,13 +185,13 @@ RUN case ${TARGETPLATFORM} in \
RUN conda install -c conda-forge gperftools mkl
RUN pip install torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cpu
RUN pip install triton==3.2.0 py-libnuma
RUN pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cpu
RUN pip install triton==3.1.0 py-libnuma
WORKDIR /usr/src
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/cpu/intel_extension_for_pytorch-2.7.0%2Bcpu-cp311-cp311-linux_x86_64.whl
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/cpu/oneccl_bind_pt-2.7.0%2Bcpu-cp311-cp311-linux_x86_64.whl
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/cpu/intel_extension_for_pytorch-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/cpu/oneccl_bind_pt-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl
ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so

View File

@ -1,15 +1,13 @@
FROM nvidia/cuda:12.8.0-cudnn-devel-ubuntu24.04 AS deps
ARG llamacpp_version=b4827
ARG llamacpp_version=b4651
ARG llamacpp_cuda=OFF
ARG llamacpp_native=ON
ARG llamacpp_cpu_arm_arch=native
ARG cuda_arch=75-real;80-real;86-real;89-real;90-real
WORKDIR /opt/src
ENV DEBIAN_FRONTEND=noninteractive
RUN apt update && apt upgrade -y && apt install -y \
RUN apt update && apt install -y \
clang \
cmake \
curl \
@ -19,10 +17,9 @@ RUN apt update && apt upgrade -y && apt install -y \
pkg-config \
tar
ADD https://github.com/ggml-org/llama.cpp/archive/refs/tags/${llamacpp_version}.tar.gz /opt/src/
RUN mkdir -p llama.cpp \
&& tar -xzf ${llamacpp_version}.tar.gz -C llama.cpp --strip-components=1 \
&& cd llama.cpp \
ADD https://github.com/ggerganov/llama.cpp/archive/refs/tags/${llamacpp_version}.tar.gz /opt/src/
RUN tar -xzf ${llamacpp_version}.tar.gz \
&& cd llama.cpp-${llamacpp_version} \
&& cmake -B build \
-DCMAKE_INSTALL_PREFIX=/usr \
-DCMAKE_INSTALL_LIBDIR=/usr/lib \
@ -30,8 +27,6 @@ RUN mkdir -p llama.cpp \
-DCMAKE_CXX_COMPILER=clang++ \
-DCMAKE_CUDA_ARCHITECTURES=${cuda_arch} \
-DGGML_CUDA=${llamacpp_cuda} \
-DGGML_NATIVE=${llamacpp_native} \
-DGGML_CPU_ARM_ARCH=${llamacpp_cpu_arm_arch} \
-DLLAMA_BUILD_COMMON=OFF \
-DLLAMA_BUILD_TESTS=OFF \
-DLLAMA_BUILD_EXAMPLES=OFF \
@ -41,7 +36,7 @@ RUN mkdir -p llama.cpp \
WORKDIR /app
COPY rust-toolchain.toml rust-toolchain.toml
RUN curl -sSf https://sh.rustup.rs | sh -s -- --no-modify-path --default-toolchain 1.85.1 --profile minimal -y
RUN curl -sSf https://sh.rustup.rs | sh -s -- --no-modify-path --default-toolchain 1.85.0 --profile minimal -y
ENV PATH="/root/.cargo/bin:$PATH"
RUN cargo install cargo-chef --locked
@ -53,18 +48,16 @@ FROM deps AS builder
COPY --from=planner /app/recipe.json recipe.json
RUN cargo chef cook \
--recipe-path recipe.json \
--profile release \
--profile release-opt \
--package text-generation-router-llamacpp
COPY . .
RUN cargo build \
--profile release \
--profile release-opt \
--package text-generation-router-llamacpp --frozen
FROM nvidia/cuda:12.8.0-cudnn-runtime-ubuntu24.04
WORKDIR /app
ENV DEBIAN_FRONTEND=noninteractive
RUN apt update && apt upgrade -y && apt install -y \
RUN apt update && apt install -y \
python3-venv \
python3-pip
@ -72,16 +65,11 @@ RUN python3 -m venv /venv
ENV PATH="/venv/bin:$PATH"
COPY backends/llamacpp/requirements.txt requirements.txt
COPY --from=builder /opt/src/llama.cpp/gguf-py gguf-py
COPY --from=builder /opt/src/llama.cpp/convert_hf_to_gguf.py /bin/
RUN pip3 install --no-cache-dir \
-r requirements.txt \
-e gguf-py
RUN pip3 install --no-cache-dir -r requirements.txt
COPY --from=builder /usr/lib/libllama.so /usr/lib/
COPY --from=builder /usr/lib/libggml*.so /usr/lib/
COPY --from=builder /app/target/release/text-generation-router-llamacpp /usr/bin/
COPY --from=builder /app/target/release-opt/text-generation-router-llamacpp /usr/bin/
ENV HF_HUB_ENABLE_HF_TRANSFER=1

View File

@ -3,9 +3,10 @@ ARG cuda_base=12.8.0
ARG build_type=release
ARG ompi_version=4.1.7
ARG sccache_gha_enabled=off
ARG actions_results_url=""
ARG actions_cache_url=""
ARG actions_runtime_token=""
# CUDA dependent dependencies resolver stage
FROM nvidia/cuda:${cuda_base}-cudnn-devel-ubuntu24.04 AS cuda-builder
@ -65,15 +66,15 @@ WORKDIR /usr/src/text-generation-inference
ARG cuda_arch_list
ARG build_type
ARG sccache_gha_enabled
ARG actions_results_url
ARG actions_cache_url
ARG actions_runtime_token
# Install Rust
ENV PATH="/root/.cargo/bin:$PATH"
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain 1.85.1 --profile minimal -y && \
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain 1.85.0 --profile minimal -y && \
chmod -R a+w /root/.rustup && \
chmod -R a+w /root/.cargo && \
cargo install sccache --version ">=0.10.0" --locked
cargo install sccache --locked
ENV LD_LIBRARY_PATH="/usr/local/mpi/lib:$LD_LIBRARY_PATH"
ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig"
@ -84,7 +85,7 @@ ENV CUDA_ARCH_LIST=${cuda_arch_list}
# SCCACHE Specifics args - before finding a better, more generic, way...
ENV SCCACHE_GHA_ENABLED=${sccache_gha_enabled}
ENV ACTIONS_RESULTS_URL=${actions_results_url}
ENV ACTIONS_CACHE_URL=${actions_cache_url}
ENV ACTIONS_RUNTIME_TOKEN=${actions_runtime_token}
COPY Cargo.lock Cargo.lock

View File

@ -53,6 +53,3 @@ run-falcon-7b-instruct-quantize:
clean:
rm -rf target aml
preview_doc:
doc-builder preview text-generation-inference docs/source --not_python_module

View File

@ -14,7 +14,7 @@
</a>
A Rust, Python and gRPC server for text generation inference. Used in production at [Hugging Face](https://huggingface.co)
to power Hugging Chat, the Inference API and Inference Endpoints.
to power Hugging Chat, the Inference API and Inference Endpoint.
</div>
@ -84,7 +84,7 @@ model=HuggingFaceH4/zephyr-7b-beta
volume=$PWD/data
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.3.4 --model-id $model
ghcr.io/huggingface/text-generation-inference:3.1.1 --model-id $model
```
And then you can make requests like
@ -121,7 +121,7 @@ curl localhost:8080/v1/chat/completions \
**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.4-rocm --model-id $model` instead of the command above.
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.1.1-rocm --model-id $model` instead of the command above.
To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
```
@ -152,7 +152,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
token=<your cli READ token>
docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.3.4 --model-id $model
ghcr.io/huggingface/text-generation-inference:3.1.1 --model-id $model
```
### A note on Shared Memory (shm)
@ -256,7 +256,7 @@ Another option is to install `text-generation-inference` locally using [Nix](htt
we only support Nix on x86_64 Linux with CUDA GPUs. When using Nix, all dependencies can
be pulled from a binary cache, removing the need to build them locally.
First follow the instructions to [install Cachix and enable the Hugging Face cache](https://app.cachix.org/cache/huggingface).
First follow the instructions to [install Cachix and enable the TGI cache](https://app.cachix.org/cache/text-generation-inference).
Setting up the cache is important, otherwise Nix will build many of the dependencies
locally, which can take hours.
@ -264,7 +264,7 @@ After that you can run TGI with `nix run`:
```shell
cd text-generation-inference
nix run --extra-experimental-features nix-command --extra-experimental-features flakes . -- --model-id meta-llama/Llama-3.1-8B-Instruct
nix run --extra-experimental-features nix-command --extra-experimental-features flakes . -- --model-idmeta-llama/Llama-3.1-8B-Instruct
```
**Note:** when you are using Nix on a non-NixOS system, you have to [make some symlinks](https://danieldk.eu/Nix-CUDA-on-non-NixOS-systems#make-runopengl-driverlib-and-symlink-the-driver-library)

View File

@ -1,14 +1,14 @@
mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
mkfile_dir := $(dir $(mkfile_path))
root_dir := ${mkfile_dir}/../..
root_dir := "${mkfile_dir}/../.."
HABANA_VERSION := 1.21.0
PYTORCH_VERSION := 2.6.0
HABANA_VERSION := 1.19.0
PYTORCH_VERSION := 2.5.1
.PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install
image:
docker build --ulimit nofile=4096 -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION)
docker build -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION)
run-local-dev-container:
docker run -it \
@ -47,21 +47,3 @@ local-dev-install: install-dependencies
make install-server && \
make install-router && \
make install-launcher'
# In order to run the integration tests, you need to first build the image (make -C backends/gaudi image)
run-integration-tests:
DOCKER_VOLUME=${root_dir}/data \
HF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \
pytest --durations=0 -s -vv ${root_dir}/integration-tests --gaudi
run-integration-tests-with-all-models:
DOCKER_VOLUME=${root_dir}/data \
HF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \
pytest --durations=0 -s -vv ${root_dir}/integration-tests --gaudi --gaudi-all-models
# This is used to capture the expected outputs for the integration tests offering an easy way to add more models to the integration tests
capture-expected-outputs-for-integration-tests:
pip install -U pip uv
DOCKER_VOLUME=${root_dir}/data \
HF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \
uv run pytest --durations=0 -sv ${root_dir}/backends/gaudi/server/integration-tests/capture_expected_outputs.py

View File

@ -96,57 +96,3 @@ curl 127.0.0.1:8080/generate \
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
-H 'Content-Type: application/json'
```
### Integration tests
Install the dependencies:
```bash
pip install -r integration-tests/requirements.txt
```
To run the integration tests, you need to first build the image:
```bash
make -C backends/gaudi image
```
Then run the following command to run the integration tests (CI tests):
```bash
make -C backends/gaudi run-integration-tests
```
To run the integration tests with all models, you can run the following command:
```bash
make -C backends/gaudi run-integration-tests-with-all-models
```
To capture the expected outputs for the integration tests, you can run the following command:
```bash
make -C backends/gaudi capture-expected-outputs-for-integration-tests
```
#### How the integration tests works
The integration tests works as follows:
1. Start a tgi server in a container, similar to the command:
```bash
docker run --runtime=habana --ipc=host --cap-add=sys_nice \
-p 8080:80 -v $volume:/data \
-e LOG_LEVEL=debug -e HF_TOKEN=$hf_token \
tgi-gaudi --model-id $model \
--max-input-tokens 512 --max-total-tokens 1024 --max-batch-size 4 --max-batch-prefill-tokens 2048
```
2. Do a /generate request to the server, similar to the command:
```bash
curl 127.0.0.1:8080/generate \
-X POST \
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
-H 'Content-Type: application/json'
```
3. Check the output of the server against the expected output:
```python
assert curl_output == expected_output
```
This is the repeated for a set of models and configurations.

View File

@ -1,283 +0,0 @@
# Examples of Docker Commands for Gaudi Backend
This page gives a list of examples of docker run commands for some of the most popular models.
> **Note:** The parameters are chosen for Gaudi2 hardware to maximize performance on this given hardware, please adjust the parameters based on your hardware. For example, if you are using Gaudi3, you may want to increase the batch size.
## Default Precision (BF16)
### Llama3.1-8B on 1 card (BF16)
```bash
model=meta-llama/Meta-Llama-3.1-8B-Instruct
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-e HF_TOKEN=$hf_token \
-e MAX_TOTAL_TOKENS=2048 \
-e PREFILL_BATCH_BUCKET_SIZE=2 \
-e BATCH_BUCKET_SIZE=32 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 2048 --max-batch-size 32 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
```
### Llama3.1-70B 8 cards (BF16)
```bash
model=meta-llama/Meta-Llama-3.1-70B-Instruct
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-e HF_TOKEN=$hf_token \
-e MAX_TOTAL_TOKENS=2048 \
-e BATCH_BUCKET_SIZE=256 \
-e PREFILL_BATCH_BUCKET_SIZE=4 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--sharded true --num-shard 8 \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 4096 --max-batch-size 256 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
```
### Llama2-7B on 1 Card (BF16)
```bash
model=meta-llama/Llama-2-7b-chat-hf
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-e HF_TOKEN=$hf_token \
-e MAX_TOTAL_TOKENS=2048 \
-e PREFILL_BATCH_BUCKET_SIZE=2 \
-e BATCH_BUCKET_SIZE=32 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 2048 --max-batch-size 32 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
```
### Llama2-70B on 8 cards (BF16)
```bash
model=meta-llama/Llama-2-70b-chat-hf
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-e HF_TOKEN=$hf_token \
-e MAX_TOTAL_TOKENS=2048 \
-e BATCH_BUCKET_SIZE=256 \
-e PREFILL_BATCH_BUCKET_SIZE=4 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--sharded true --num-shard 8 \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 4096 --max-batch-size 256 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
```
### Llava-v1.6-Mistral-7B on 1 card (BF16)
```bash
model=llava-hf/llava-v1.6-mistral-7b-hf
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-e PREFILL_BATCH_BUCKET_SIZE=1 \
-e BATCH_BUCKET_SIZE=1 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
--max-total-tokens 8192 --max-batch-size 4
```
## FP8 Precision
Please refer to the [FP8 Precision](https://huggingface.co/docs/text-generation-inference/backends/gaudi_new#how-to-use-different-precision-formats) section for more details. You need to measure the statistics of the model first before running the model in FP8 precision.
## Llama3.1-8B on 1 Card (FP8)
```bash
model=meta-llama/Meta-Llama-3.1-8B-Instruct
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-v $PWD/quantization_config:/usr/src/quantization_config \
-v $PWD/hqt_output:/usr/src/hqt_output \
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
-e HF_TOKEN=$hf_token \
-e MAX_TOTAL_TOKENS=2048 \
-e PREFILL_BATCH_BUCKET_SIZE=2 \
-e BATCH_BUCKET_SIZE=32 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 2048 --max-batch-size 32 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
```
## Llama3.1-70B on 8 cards (FP8)
```bash
model=meta-llama/Meta-Llama-3.1-70B-Instruct
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-v $PWD/quantization_config:/usr/src/quantization_config \
-v $PWD/hqt_output:/usr/src/hqt_output \
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
-e HF_TOKEN=$hf_token \
-e MAX_TOTAL_TOKENS=2048 \
-e BATCH_BUCKET_SIZE=256 \
-e PREFILL_BATCH_BUCKET_SIZE=4 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--sharded true --num-shard 8 \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 4096 --max-batch-size 256 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
```
## Llama2-7B on 1 Card (FP8)
```bash
model=meta-llama/Llama-2-7b-chat-hf
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-v $PWD/quantization_config:/usr/src/quantization_config \
-v $PWD/hqt_output:/usr/src/hqt_output \
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
-e HF_TOKEN=$hf_token \
-e MAX_TOTAL_TOKENS=2048 \
-e PREFILL_BATCH_BUCKET_SIZE=2 \
-e BATCH_BUCKET_SIZE=32 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 2048 --max-batch-size 32 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
```
## Llama2-70B on 8 Cards (FP8)
```bash
model=meta-llama/Llama-2-70b-chat-hf
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-v $PWD/quantization_config:/usr/src/quantization_config \
-v $PWD/hqt_output:/usr/src/hqt_output \
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
-e HF_TOKEN=$hf_token \
-e MAX_TOTAL_TOKENS=2048 \
-e BATCH_BUCKET_SIZE=256 \
-e PREFILL_BATCH_BUCKET_SIZE=4 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--sharded true --num-shard 8 \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 4096 --max-batch-size 256 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
```
## Llava-v1.6-Mistral-7B on 1 Card (FP8)
```bash
model=llava-hf/llava-v1.6-mistral-7b-hf
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-v $PWD/quantization_config:/usr/src/quantization_config \
-v $PWD/hqt_output:/usr/src/hqt_output \
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
-e PREFILL_BATCH_BUCKET_SIZE=1 \
-e BATCH_BUCKET_SIZE=1 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
--max-total-tokens 8192 --max-batch-size 4
```
## Llava-v1.6-Mistral-7B on 8 Cards (FP8)
```bash
model=llava-hf/llava-v1.6-mistral-7b-hf
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-v $PWD/quantization_config:/usr/src/quantization_config \
-v $PWD/hqt_output:/usr/src/hqt_output \
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
-e PREFILL_BATCH_BUCKET_SIZE=1 \
-e BATCH_BUCKET_SIZE=1 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--sharded true --num-shard 8 \
--max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
--max-total-tokens 8192 --max-batch-size 4
```

File diff suppressed because it is too large Load Diff

View File

@ -9,29 +9,30 @@ text-generation-server = 'text_generation_server.cli:app'
[tool.poetry.dependencies]
python = ">=3.9,<3.13"
protobuf = "^5.0"
grpcio = "^1.71.1"
protobuf = "^3.20.3"
grpcio = "^1.51.1"
grpcio-status = "*"
grpcio-reflection = "*"
grpc-interceptor = "^0.15.0"
typer = "^0.15.0"
loguru = "^0.7.3"
opentelemetry-api = "^1.32.0"
opentelemetry-exporter-otlp = "^1.32.0"
opentelemetry-instrumentation-grpc = "^0.53b0"
hf-transfer = "^0.1.9"
sentencepiece = "^0.2.0"
peft = "^0.15"
transformers = "^4.52.4"
numpy = "^1.26"
accelerate = "^1.7.0"
typer = "^0.7.0"
loguru = "^0.6.0"
opentelemetry-api = "^1.15.0"
opentelemetry-exporter-otlp = "^1.15.0"
opentelemetry-instrumentation-grpc = "^0.36b0"
hf-transfer = "^0.1.2"
sentencepiece = "^0.1.97"
peft = "^0.10"
optimum-habana = "1.15.0"
transformers = "4.45.2"
numpy = "1.26.4"
accelerate = "0.33.0"
outlines= { version = "^0.0.36", optional = true }
prometheus-client = "^0.21.1"
prometheus-client = "^0.20.0"
py-cpuinfo = "^9.0.0"
[tool.poetry.group.dev.dependencies]
grpcio-tools = "*"
pytest = "^8.3.5"
pytest = "^7.3.0"
[tool.pytest.ini_options]
markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]
@ -39,6 +40,3 @@ markers = ["private: marks tests as requiring an admin hf token (deselect with '
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.poetry.requires-plugins]
poetry-plugin-export = ">=1.8"

View File

@ -1,86 +1,89 @@
accelerate==1.7.0 ; python_version >= "3.9" and python_version < "3.13"
annotated-types==0.7.0 ; python_version >= "3.9" and python_version < "3.13"
attrs==25.3.0 ; python_version >= "3.9" and python_version < "3.13"
certifi==2025.1.31 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.4.1 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.8 ; python_version >= "3.9" and python_version < "3.13"
cloudpickle==3.1.1 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Windows" or python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
deprecated==1.2.18 ; python_version >= "3.9" and python_version < "3.13"
accelerate==0.33.0 ; python_version >= "3.9" and python_version < "3.13"
aiohappyeyeballs==2.4.3 ; python_version >= "3.9" and python_version < "3.13"
aiohttp==3.10.10 ; python_version >= "3.9" and python_version < "3.13"
aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "3.13"
async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.11"
attrs==24.2.0 ; python_version >= "3.9" and python_version < "3.13"
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
coloredlogs==15.0.1 ; python_version >= "3.9" and python_version < "3.13"
datasets==3.0.1 ; python_version >= "3.9" and python_version < "3.13"
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
diffusers==0.31.0 ; python_version >= "3.9" and python_version < "3.13"
diskcache==5.6.3 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.18.0 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2025.3.2 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.70.0 ; python_version >= "3.9" and python_version < "3.13"
dill==0.3.7 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13"
frozenlist==1.4.1 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13"
fsspec[http]==2024.6.1 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.71.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.71.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.72.0rc1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.9 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.30.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.48.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.48.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.67.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.26.1 ; python_version >= "3.9" and python_version < "3.13"
humanfriendly==10.0 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==8.6.1 ; python_version >= "3.9" and python_version < "3.13"
interegular==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
jinja2==3.1.6 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==8.5.0 ; python_version >= "3.9" and python_version < "3.13"
jinja2==3.1.4 ; python_version >= "3.9" and python_version < "3.13"
joblib==1.4.2 ; python_version >= "3.9" and python_version < "3.13"
jsonschema-specifications==2024.10.1 ; python_version >= "3.9" and python_version < "3.13"
jsonschema==4.23.0 ; python_version >= "3.9" and python_version < "3.13"
lark==1.2.2 ; python_version >= "3.9" and python_version < "3.13"
llvmlite==0.43.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.7.3 ; python_version >= "3.9" and python_version < "3.13"
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
markupsafe==3.0.2 ; python_version >= "3.9" and python_version < "3.13"
mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13"
mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
nest-asyncio==1.6.0 ; python_version >= "3.9" and python_version < "3.13"
multidict==6.1.0 ; python_version >= "3.9" and python_version < "3.13"
multiprocess==0.70.15 ; python_version >= "3.9" and python_version < "3.13"
networkx==3.2.1 ; python_version >= "3.9" and python_version < "3.13"
numba==0.60.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-common==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.53b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.53b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.53b0 ; python_version >= "3.9" and python_version < "3.13"
optimum==1.24.0 ; python_version >= "3.9" and python_version < "3.13"
outlines==0.0.36 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.2 ; python_version >= "3.9" and python_version < "3.13"
peft==0.15.1 ; python_version >= "3.9" and python_version < "3.13"
pillow==11.2.1 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.21.1 ; python_version >= "3.9" and python_version < "3.13"
protobuf==5.29.4 ; python_version >= "3.9" and python_version < "3.13"
psutil==7.0.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
optimum-habana==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
optimum==1.23.2 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
pandas==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
peft==0.10.0 ; python_version >= "3.9" and python_version < "3.13"
pillow==11.0.0 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
propcache==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==3.20.3 ; python_version >= "3.9" and python_version < "3.13"
psutil==6.1.0 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pydantic-core==2.33.1 ; python_version >= "3.9" and python_version < "3.13"
pydantic==2.11.3 ; python_version >= "3.9" and python_version < "3.13"
pygments==2.19.1 ; python_version >= "3.9" and python_version < "3.13"
pyarrow==17.0.0 ; python_version >= "3.9" and python_version < "3.13"
pyreadline3==3.5.4 ; sys_platform == "win32" and python_version >= "3.9" and python_version < "3.13"
python-dateutil==2.9.0.post0 ; python_version >= "3.9" and python_version < "3.13"
pytz==2024.2 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
referencing==0.36.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.11.6 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==14.0.0 ; python_version >= "3.9" and python_version < "3.13"
rpds-py==0.24.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.5.3 ; python_version >= "3.9" and python_version < "3.13"
scikit-learn==1.6.1 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scikit-learn==1.5.2 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentence-transformers==3.3.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==78.1.0 ; python_version >= "3.12" and python_version < "3.13"
shellingham==1.5.4 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
threadpoolctl==3.6.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.21.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.67.1 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.52.4 ; python_version >= "3.9" and python_version < "3.13"
triton==3.2.0 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
typer==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.13.2 ; python_version >= "3.9" and python_version < "3.13"
typing-inspection==0.4.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.4.0 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.2.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.17.2 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.21.0 ; python_version >= "3.9" and python_version < "3.13"
sentence-transformers[train]==3.2.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13"
six==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.12.1 ; python_version >= "3.9" and python_version < "3.13"
threadpoolctl==3.5.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
transformers[sentencepiece]==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
triton==3.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version < "3.13" and python_version >= "3.9"
typer==0.7.0 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
tzdata==2024.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
xxhash==3.5.0 ; python_version >= "3.9" and python_version < "3.13"
yarl==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -0,0 +1,23 @@
import pytest
import os
from text_generation_server.pb import generate_pb2
os.environ["USE_PREFIX_CACHING"] = "1"
os.environ["ATTENTION"] = "flashinfer"
@pytest.fixture
def default_pb_parameters():
return generate_pb2.NextTokenChooserParameters(
temperature=1.0,
repetition_penalty=1.0,
top_k=0,
top_p=1.0,
typical_p=1.0,
do_sample=False,
)
@pytest.fixture
def default_pb_stop_parameters():
return generate_pb2.StoppingCriteriaParameters(stop_sequences=[], max_new_tokens=10)

View File

@ -0,0 +1,363 @@
import pytest
import torch
from copy import copy
from transformers import AutoTokenizer
from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.utils import weight_hub_files, download_weights
from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded
from text_generation_server.models.custom_modeling.bloom_modeling import (
BloomForCausalLM,
)
@pytest.fixture(scope="session")
def default_bloom():
model_id = "bigscience/bloom-560m"
revision = "main"
filenames = weight_hub_files(model_id, revision, ".safetensors")
download_weights(filenames, model_id, revision)
return BLOOMSharded(
model_id,
model_class=BloomForCausalLM,
)
@pytest.fixture(scope="session")
def bloom_560m_tokenizer():
return AutoTokenizer.from_pretrained("bigscience/bloom-560m", padding_side="left")
@pytest.fixture
def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
)
@pytest.fixture
def default_pb_batch(default_pb_request):
return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)
@pytest.fixture
def default_bloom_batch(default_pb_batch, bloom_560m_tokenizer):
return BloomCausalLMBatch.from_pb(
default_pb_batch, bloom_560m_tokenizer, torch.float32, torch.device("cpu")
)
@pytest.fixture
def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer):
req_0 = copy(default_pb_request)
req_0.id = 1
req_1 = default_pb_request
req_1.id = 2
req_1.stopping_parameters.max_new_tokens = 5
batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
return BloomCausalLMBatch.from_pb(
batch_pb, bloom_560m_tokenizer, torch.float32, torch.device("cpu")
)
def test_batch_from_pb(default_pb_batch, default_bloom_batch):
batch = default_bloom_batch
assert batch.batch_id == default_pb_batch.id
assert batch.requests == default_pb_batch.requests
assert len(batch.input_ids) == default_pb_batch.size
assert batch.input_ids[0][-1] == 10264
assert torch.all(batch.input_ids[0][:-1] == 3)
assert batch.attention_mask[0][0] == 1
assert torch.all(batch.attention_mask[0][1:] == 0)
assert batch.past_key_values is None
assert all(
[
torch.equal(input_ids, all_input_ids[:, 0])
for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)
]
)
assert batch.input_lengths == [1]
assert len(batch) == default_pb_batch.size
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)
assert batch.max_input_length == batch.input_lengths[0]
def test_batch_concatenate_no_prefill(default_bloom_batch):
with pytest.raises(ValueError):
BloomCausalLMBatch.concatenate([default_bloom_batch, default_bloom_batch])
def test_causal_lm_batch_type(default_bloom):
assert default_bloom.batch_type == BloomCausalLMBatch
def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
sequence_length = len(default_bloom_batch.all_input_ids[0])
generations, next_batch, _ = default_bloom.generate_token(default_bloom_batch)
assert len(generations) == len(default_bloom_batch)
assert isinstance(next_batch, CausalLMBatch)
assert not next_batch.keys_head_dim_last
assert len(next_batch.all_input_ids) == len(next_batch)
assert len(next_batch.all_input_ids[0]) == sequence_length + 1
assert len(next_batch.attention_mask[0]) == 11
assert torch.all(next_batch.all_input_ids[0][-2:] == 10264)
assert torch.all(next_batch.all_input_ids[0][:-2] == 3)
assert torch.all(next_batch.attention_mask[0][:2] == 1)
assert torch.all(next_batch.attention_mask[0][2:] == 0)
assert next_batch.input_ids.shape == (len(next_batch), 1)
assert next_batch.input_ids[0, 0] == 10264
assert next_batch.input_lengths == [2]
assert next_batch.max_input_length == next_batch.input_lengths[0]
assert next_batch.past_key_values is not None
assert all(
[p[0].shape == (16, 64, sequence_length) for p in next_batch.past_key_values]
)
assert all(
[p[1].shape == (16, sequence_length, 64) for p in next_batch.past_key_values]
)
assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
assert all(
[
token_id.item() == 10264
for generation in generations
for token_id in generation.tokens.token_ids
]
)
assert all(
[
token_text == "Test"
for generation in generations
for token_text in generation.tokens.texts
]
)
assert generations[0].request_id == 0
def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch):
next_batch = default_bloom_batch
for _ in range(default_bloom_batch.stopping_criterias[0].max_new_tokens - 1):
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert len(generations) == len(default_bloom_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
assert (
generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
)
assert generations[0].request_id == default_bloom_batch.requests[0].id
assert (
generations[0].generated_text.generated_tokens
== default_bloom_batch.stopping_criterias[0].max_new_tokens
)
def test_causal_lm_generate_token_completion_multi(
default_bloom, default_multi_requests_bloom_batch
):
next_batch = default_multi_requests_bloom_batch
for i in range(
default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1
):
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert len(generations) == len(default_multi_requests_bloom_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 2
assert generations[1].generated_text.text == "TestTestTestTestTest"
assert (
generations[1].request_id == default_multi_requests_bloom_batch.requests[1].id
)
assert (
generations[1].generated_text.generated_tokens
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
)
# Copy stopping_criterias before filtering
stopping_criterias = default_multi_requests_bloom_batch.stopping_criterias.copy()
next_batch = next_batch.filter([next_batch.requests[0].id])
for _ in range(
stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
):
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
assert (
generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
)
assert (
generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
)
assert (
generations[0].generated_text.generated_tokens
== default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
)
def test_batch_concatenate(
default_bloom, default_bloom_batch, default_multi_requests_bloom_batch
):
next_batch_0 = default_bloom_batch
_, next_batch_0, _ = default_bloom.generate_token(next_batch_0)
_, next_batch_0, _ = default_bloom.generate_token(next_batch_0)
next_batch_1 = default_multi_requests_bloom_batch
_, next_batch_1, _ = default_bloom.generate_token(next_batch_1)
# Clone past_key_values before concatenating to compare after,
# because they are removed from the concatenated batches
next_batch_0_past_key_values = [
(k.clone(), v.clone()) for (k, v) in next_batch_0.past_key_values
]
next_batch_1_past_key_values = [
(k.clone(), v.clone()) for (k, v) in next_batch_1.past_key_values
]
next_batch = BloomCausalLMBatch.concatenate([next_batch_0, next_batch_1])
assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0])
assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0])
assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1])
assert torch.all(
next_batch.attention_mask[0, : -next_batch.padding_right_offset] == 1
)
assert torch.all(
next_batch.attention_mask[1:, 1 : -next_batch.padding_right_offset] == 1
)
assert torch.all(next_batch.attention_mask[1:, 3:] == 0)
assert next_batch.batch_id == 0
assert torch.all(next_batch.input_ids == 10264)
assert next_batch.input_lengths == [3, 2, 2]
assert next_batch.max_input_length == 3
assert next_batch.requests[0] == next_batch_0.requests[0]
assert next_batch.requests[1:] == list(next_batch_1.requests)
assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers
assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0]
assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias
assert next_batch.past_key_values is not None
assert all([p[0].shape == (3, 16, 64, 2) for p in next_batch.past_key_values])
assert all([p[1].shape == (3, 16, 2, 64) for p in next_batch.past_key_values])
for i, past in enumerate(next_batch.past_key_values):
assert torch.equal(next_batch_0_past_key_values[i][0][:, :, -2:], past[0][0])
assert torch.equal(
next_batch_1_past_key_values[i][0][:, :, -1:],
past[0][1:, :, :, -1].reshape(-1, 64, 1),
)
assert torch.equal(next_batch_0_past_key_values[i][1][:, -2:, :], past[1][0])
assert torch.equal(
next_batch_1_past_key_values[i][1][:, -1:, :],
past[1][1:, :, -1, :].reshape(-1, 1, 64),
)
for _ in range(
default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2
):
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 3
assert generations[2].generated_text.text == "TestTestTestTestTest"
assert (
generations[2].request_id == default_multi_requests_bloom_batch.requests[1].id
)
assert (
generations[2].generated_text.generated_tokens
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
)
next_batch = next_batch.filter(
[next_batch.requests[0].id, next_batch.requests[1].id]
)
for _ in range(
default_bloom_batch.stopping_criterias[0].max_new_tokens
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
- 2
):
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 2
assert (
generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
)
assert generations[0].request_id == default_bloom_batch.requests[0].id
assert (
generations[0].generated_text.generated_tokens
== default_bloom_batch.stopping_criterias[0].max_new_tokens
)
next_batch = next_batch.filter([next_batch.requests[1].id])
for _ in range(
default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
- default_bloom_batch.stopping_criterias[0].max_new_tokens
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
- 4
):
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
assert (
generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
)
assert (
generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
)
assert (
generations[0].generated_text.generated_tokens
== default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
)

View File

@ -0,0 +1,354 @@
import pytest
import torch
from copy import copy
from transformers import AutoTokenizer
from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch
@pytest.fixture(scope="session")
def default_causal_lm():
return CausalLM.fallback("gpt2")
@pytest.fixture(scope="session")
def gpt2_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
tokenizer.pad_token_id = 50256
return tokenizer
@pytest.fixture
def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
)
@pytest.fixture
def default_pb_batch(default_pb_request):
return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)
@pytest.fixture
def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer):
return CausalLMBatch.from_pb(
default_pb_batch, gpt2_tokenizer, torch.float32, torch.device("cpu")
)
@pytest.fixture
def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
req_0 = copy(default_pb_request)
req_0.id = 1
req_1 = default_pb_request
req_1.id = 2
req_1.stopping_parameters.max_new_tokens = 5
batch_pb = generate_pb2.Batch(id=1, requests=[req_0, req_1], size=2)
return CausalLMBatch.from_pb(
batch_pb, gpt2_tokenizer, torch.float32, torch.device("cpu")
)
def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
batch = default_causal_lm_batch
assert batch.batch_id == default_pb_batch.id
assert batch.requests == default_pb_batch.requests
assert len(batch.input_ids) == default_pb_batch.size
assert batch.input_ids[0][-1] == 14402
assert torch.all(batch.input_ids[0][:-1] == 50256)
assert batch.attention_mask[0, 0] == 1
assert torch.all(batch.attention_mask[0, 1:] == 0)
assert batch.past_key_values is None
assert all(
[
torch.equal(input_ids, all_input_ids[:, 0])
for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)
]
)
assert batch.input_lengths == [1]
assert len(batch) == default_pb_batch.size
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)
assert batch.max_input_length == batch.input_lengths[0]
def test_batch_concatenate_no_prefill(default_causal_lm_batch):
with pytest.raises(ValueError):
CausalLMBatch.concatenate([default_causal_lm_batch, default_causal_lm_batch])
def test_causal_lm_batch_type(default_causal_lm):
assert default_causal_lm.batch_type == CausalLMBatch
def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
sequence_length = len(default_causal_lm_batch.all_input_ids[0])
generations, next_batch, _ = default_causal_lm.generate_token(
default_causal_lm_batch
)
assert len(generations) == len(next_batch)
assert isinstance(next_batch, CausalLMBatch)
assert len(next_batch.all_input_ids) == len(next_batch)
assert len(next_batch.all_input_ids[0]) == sequence_length + 1
assert len(next_batch.attention_mask[0]) == 11
assert next_batch.all_input_ids[0][-1] == 13
assert next_batch.all_input_ids[0][-2] == 14402
assert torch.all(next_batch.all_input_ids[0][:-2] == 50256)
assert torch.all(next_batch.attention_mask[0][0:2] == 1)
assert torch.all(next_batch.attention_mask[0][2:] == 0)
assert next_batch.input_ids.shape == (len(next_batch), 1)
assert next_batch.input_ids[0, 0] == 13
assert next_batch.input_lengths == [2]
assert next_batch.max_input_length == next_batch.input_lengths[0]
assert next_batch.past_key_values is not None
assert all(
[p[0].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values]
)
assert all(
[p[1].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values]
)
assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
assert all(
[
token_id.item() == 13
for generation in generations
for token_id in generation.tokens.token_ids
]
)
assert all(
[
token_text == "."
for generation in generations
for token_text in generation.tokens.texts
]
)
assert generations[0].request_id == 0
def test_causal_lm_generate_token_completion(
default_causal_lm, default_causal_lm_batch
):
next_batch = default_causal_lm_batch
for _ in range(default_causal_lm_batch.stopping_criterias[0].max_new_tokens - 1):
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
assert generations[0].generated_text.text == ".java:784) at net.minecraft."
assert generations[0].request_id == default_causal_lm_batch.requests[0].id
assert (
generations[0].generated_text.generated_tokens
== default_causal_lm_batch.stopping_criterias[0].max_new_tokens
)
def test_causal_lm_generate_token_completion_multi(
default_causal_lm, default_multi_requests_causal_lm_batch
):
next_batch = default_multi_requests_causal_lm_batch
for i in range(
default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1
):
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 2
assert generations[1].generated_text.text == ".java:784)"
assert (
generations[1].request_id
== default_multi_requests_causal_lm_batch.requests[1].id
)
assert (
generations[1].generated_text.generated_tokens
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
)
# Copy stopping_criterias before filtering
stopping_criterias = (
default_multi_requests_causal_lm_batch.stopping_criterias.copy()
)
next_batch = next_batch.filter([next_batch.requests[0].id])
for _ in range(
stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
):
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
assert generations[0].generated_text.text == ".java:784) at net.minecraft."
assert (
generations[0].request_id
== default_multi_requests_causal_lm_batch.requests[0].id
)
assert (
generations[0].generated_text.generated_tokens
== default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
)
def test_batch_concatenate(
default_causal_lm, default_causal_lm_batch, default_multi_requests_causal_lm_batch
):
next_batch_0 = default_causal_lm_batch
_, next_batch_0, _ = default_causal_lm.generate_token(next_batch_0)
_, next_batch_0, _ = default_causal_lm.generate_token(next_batch_0)
next_batch_1 = default_multi_requests_causal_lm_batch
_, next_batch_1, _ = default_causal_lm.generate_token(next_batch_1)
# Clone past_key_values before concatenating to compare after,
# because they are removed from the concatenated batches
next_batch_0_past_key_values = [
(k.clone(), v.clone()) for (k, v) in next_batch_0.past_key_values
]
next_batch_1_past_key_values = [
(k.clone(), v.clone()) for (k, v) in next_batch_1.past_key_values
]
next_batch = CausalLMBatch.concatenate([next_batch_0, next_batch_1])
assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0])
assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0])
assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1])
assert torch.all(
next_batch.attention_mask[0, : -next_batch.padding_right_offset] == 1
)
assert torch.all(
next_batch.attention_mask[1:, 1 : -next_batch.padding_right_offset] == 1
)
assert torch.all(next_batch.attention_mask[1:, 3:] == 0)
assert next_batch.batch_id == 0
assert next_batch.input_ids[0, 0] == 12355
assert torch.all(next_batch.input_ids[1:] == 13)
assert next_batch.input_lengths == [3, 2, 2]
assert next_batch.max_input_length == 3
assert next_batch.requests[0] == next_batch_0.requests[0]
assert next_batch.requests[1:] == list(next_batch_1.requests)
assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers
assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0]
assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias
assert next_batch.past_key_values is not None
assert all([p[0].shape == (3, 12, 2, 64) for p in next_batch.past_key_values])
assert all([p[1].shape == (3, 12, 2, 64) for p in next_batch.past_key_values])
for i, past in enumerate(next_batch.past_key_values):
assert torch.equal(next_batch_0_past_key_values[i][0][0, :, -2:], past[0][0])
assert torch.equal(
next_batch_1_past_key_values[i][0][:, :, -1:], past[0][1:, :, -1:, :]
)
assert torch.equal(next_batch_0_past_key_values[i][1][0, :, -2:], past[1][0])
assert torch.equal(
next_batch_1_past_key_values[i][1][:, :, -1:], past[1][1:, :, -1:, :]
)
for _ in range(
default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2
):
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 3
assert generations[2].generated_text.text == ".java:784)"
assert (
generations[2].request_id
== default_multi_requests_causal_lm_batch.requests[1].id
)
assert (
generations[2].generated_text.generated_tokens
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
)
next_batch = next_batch.filter(
[next_batch.requests[0].id, next_batch.requests[1].id]
)
for _ in range(
default_causal_lm_batch.stopping_criterias[0].max_new_tokens
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
- 2
):
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 2
assert generations[0].generated_text.text == ".java:784) at net.minecraft."
assert generations[0].request_id == default_causal_lm_batch.requests[0].id
assert (
generations[0].generated_text.generated_tokens
== default_causal_lm_batch.stopping_criterias[0].max_new_tokens
)
next_batch = next_batch.filter([next_batch.requests[1].id])
for _ in range(
default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
- default_causal_lm_batch.stopping_criterias[0].max_new_tokens
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
- 4
):
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
assert generations[0].generated_text.text == ".java:784) at net.minecraft."
assert (
generations[0].request_id
== default_multi_requests_causal_lm_batch.requests[0].id
)
assert (
generations[0].generated_text.generated_tokens
== default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
)

View File

@ -0,0 +1,83 @@
import pytest
import torch
from transformers import AutoTokenizer
from text_generation_server.models import Model
def get_test_model():
class TestModel(Model):
def batch_type(self):
raise NotImplementedError
def generate_token(self, batch):
raise NotImplementedError
tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b")
model = TestModel(
"test_model_id",
torch.nn.Linear(1, 1),
tokenizer,
False,
torch.float32,
torch.device("cpu"),
)
return model
@pytest.mark.private
def test_decode_streaming_english_spaces():
model = get_test_model()
truth = "Hello here, this is a simple test"
all_input_ids = [15043, 1244, 29892, 445, 338, 263, 2560, 1243]
assert (
all_input_ids == model.tokenizer(truth, add_special_tokens=False)["input_ids"]
)
decoded_text = ""
offset = 0
token_offset = 0
for i in range(len(all_input_ids)):
text, offset, token_offset = model.decode_token(
all_input_ids[: i + 1], offset, token_offset
)
decoded_text += text
assert decoded_text == truth
@pytest.mark.private
def test_decode_streaming_chinese_utf8():
model = get_test_model()
truth = "我很感谢你的热情"
all_input_ids = [
30672,
232,
193,
139,
233,
135,
162,
235,
179,
165,
30919,
30210,
234,
134,
176,
30993,
]
decoded_text = ""
offset = 0
token_offset = 0
for i in range(len(all_input_ids)):
text, offset, token_offset = model.decode_token(
all_input_ids[: i + 1], offset, token_offset
)
decoded_text += text
assert decoded_text == truth

View File

@ -0,0 +1,108 @@
import pytest
from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch, CausalLM
@pytest.fixture(scope="session")
def default_santacoder():
return CausalLM.fallback(model_id="bigcode/santacoder")
@pytest.fixture
def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="def",
input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="def")]),
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
)
@pytest.fixture
def default_pb_batch(default_pb_request):
return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)
@pytest.fixture
def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
input_chunks=generate_pb2.Input(
chunks=[
generate_pb2.InputChunk(
text="<fim-prefix>def<fim-suffix>world<fim-middle>"
)
]
),
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
)
@pytest.fixture
def default_fim_pb_batch(default_fim_pb_request):
return generate_pb2.Batch(id=0, requests=[default_fim_pb_request], size=1)
@pytest.mark.skip
def test_santacoder_generate_token_completion(default_santacoder, default_pb_batch):
batch = CausalLMBatch.from_pb(
default_pb_batch,
default_santacoder.tokenizer,
default_santacoder.dtype,
default_santacoder.device,
)
next_batch = batch
for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
generations, next_batch, _ = default_santacoder.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_santacoder.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
assert generations[0].generated_text.text == " test_get_all_users_with_"
assert generations[0].request_id == batch.requests[0].id
assert (
generations[0].generated_text.generated_tokens
== batch.stopping_criterias[0].max_new_tokens
)
@pytest.mark.skip
def test_fim_santacoder_generate_token_completion(
default_santacoder, default_fim_pb_batch
):
batch = CausalLMBatch.from_pb(
default_fim_pb_batch,
default_santacoder.tokenizer,
default_santacoder.dtype,
default_santacoder.device,
)
next_batch = batch
for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
generations, next_batch, _ = default_santacoder.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_santacoder.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
assert (
generations[0].generated_text.text
== """ineProperty(exports, "__esModule", { value"""
)
assert generations[0].request_id == batch.requests[0].id
assert (
generations[0].generated_text.generated_tokens
== batch.stopping_criterias[0].max_new_tokens
)

View File

@ -0,0 +1,365 @@
import pytest
import torch
from copy import copy
from transformers import AutoTokenizer
from text_generation_server.pb import generate_pb2
from text_generation_server.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
@pytest.fixture(scope="session")
def mt0_small_tokenizer():
tokenizer = AutoTokenizer.from_pretrained(
"bigscience/mt0-small", padding_side="left"
)
tokenizer.bos_token_id = 0
return tokenizer
@pytest.fixture(scope="session")
def default_seq2seq_lm():
return Seq2SeqLM.fallback("bigscience/mt0-small")
@pytest.fixture
def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
)
@pytest.fixture
def default_pb_batch(default_pb_request):
return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)
@pytest.fixture
def default_seq2seq_lm_batch(default_pb_batch, mt0_small_tokenizer):
return Seq2SeqLMBatch.from_pb(
default_pb_batch, mt0_small_tokenizer, torch.float32, torch.device("cpu")
)
@pytest.fixture
def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokenizer):
req_0 = copy(default_pb_request)
req_0.id = 1
req_1 = default_pb_request
req_1.id = 2
req_1.stopping_parameters.max_new_tokens = 5
batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
return Seq2SeqLMBatch.from_pb(
batch_pb, mt0_small_tokenizer, torch.float32, torch.device("cpu")
)
def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
batch = default_seq2seq_lm_batch
sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
assert batch.batch_id == default_pb_batch.id
assert batch.requests == default_pb_batch.requests
assert batch.input_ids.shape == (default_pb_batch.size, sequence_length)
assert batch.input_ids[0][-2] == 4268
assert batch.input_ids[0][-1] == 1
assert torch.all(batch.input_ids[0][:-2] == 0)
assert torch.all(batch.attention_mask[0][-2:] == 1)
assert torch.all(batch.attention_mask[0][:-2] == 0)
assert len(batch.decoder_input_ids) == default_pb_batch.size
assert batch.decoder_attention_mask is None
assert batch.encoder_last_hidden_state is None
assert batch.past_key_values is None
assert batch.input_lengths == [2]
assert batch.decoder_input_lengths == [1]
assert len(batch) == default_pb_batch.size
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)
assert batch.max_input_length == batch.input_lengths[0]
assert batch.max_decoder_input_length == batch.decoder_input_lengths[0]
def test_batch_concatenate_no_prefill(default_seq2seq_lm_batch):
with pytest.raises(ValueError):
Seq2SeqLMBatch.concatenate([default_seq2seq_lm_batch, default_seq2seq_lm_batch])
def test_seq2seq_lm_batch_type(default_seq2seq_lm):
assert default_seq2seq_lm.batch_type == Seq2SeqLMBatch
def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
generations, next_batch, _ = default_seq2seq_lm.generate_token(
default_seq2seq_lm_batch
)
assert len(generations) == len(next_batch)
assert isinstance(next_batch, Seq2SeqLMBatch)
assert next_batch.input_ids is None
assert torch.equal(
next_batch.attention_mask, default_seq2seq_lm_batch.attention_mask
)
assert next_batch.input_lengths == default_seq2seq_lm_batch.input_lengths
assert next_batch.max_input_length == default_seq2seq_lm_batch.max_input_length
assert (
next_batch.next_token_choosers == default_seq2seq_lm_batch.next_token_choosers
)
assert next_batch.stopping_criterias == default_seq2seq_lm_batch.stopping_criterias
assert len(next_batch.decoder_input_ids) == len(next_batch)
assert next_batch.all_decoder_input_ids[0][0] == 0
assert next_batch.all_decoder_input_ids[0][1] == 259
assert next_batch.decoder_attention_mask is None
assert next_batch.encoder_last_hidden_state.shape == (1, sequence_length, 512)
assert next_batch.decoder_input_lengths == [2]
assert next_batch.max_decoder_input_length == 2
assert next_batch.past_key_values is not None
assert all(
[p[0].shape == (len(next_batch), 6, 1, 64) for p in next_batch.past_key_values]
)
assert all(
[p[1].shape == (len(next_batch), 6, 1, 64) for p in next_batch.past_key_values]
)
assert all(
[
p[2].shape == (len(next_batch), 6, sequence_length, 64)
for p in next_batch.past_key_values
]
)
assert all(
[
p[3].shape == (len(next_batch), 6, sequence_length, 64)
for p in next_batch.past_key_values
]
)
assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
assert all(
[
token_id.item() == 259
for generation in generations
for token_id in generation.tokens.token_ids
]
)
assert all(
[
token_text == " "
for generation in generations
for token_text in generation.tokens.texts
]
)
assert generations[0].request_id == 0
def test_seq2seq_lm_generate_token_completion(
default_seq2seq_lm, default_seq2seq_lm_batch
):
next_batch = default_seq2seq_lm_batch
for _ in range(6):
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
assert generations[0].generated_text.text == "a few weeks"
assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
assert generations[0].generated_text.generated_tokens == 7
def test_seq2seq_lm_generate_token_completion_multi(
default_seq2seq_lm, default_multi_requests_seq2seq_lm_batch
):
next_batch = default_multi_requests_seq2seq_lm_batch
for i in range(4):
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 2
assert generations[1].generated_text.text == "a few "
assert (
generations[1].request_id
== default_multi_requests_seq2seq_lm_batch.requests[1].id
)
assert generations[1].generated_text.generated_tokens == 5
next_batch = next_batch.filter([next_batch.requests[0].id])
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
assert generations[0].generated_text.text == "a few weeks"
assert (
generations[0].request_id
== default_multi_requests_seq2seq_lm_batch.requests[0].id
)
assert generations[0].generated_text.generated_tokens == 7
def test_batch_concatenate(
default_seq2seq_lm,
default_seq2seq_lm_batch,
default_multi_requests_seq2seq_lm_batch,
):
next_batch_0 = default_seq2seq_lm_batch
_, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0)
_, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0)
next_batch_1 = default_multi_requests_seq2seq_lm_batch
_, next_batch_1, _ = default_seq2seq_lm.generate_token(next_batch_1)
# Copy hidden state because it is removed from the concatenated branches
next_batch_0_encoder_last_hidden_state = next_batch_0.encoder_last_hidden_state
next_batch_1_encoder_last_hidden_state = next_batch_1.encoder_last_hidden_state
# Clone past_key_values before concatenating to compare after,
# because they are removed from the concatenated batches
next_batch_0_past_key_values = [
[t.clone() for t in layer] for layer in next_batch_0.past_key_values
]
next_batch_1_past_key_values = [
[t.clone() for t in layer] for layer in next_batch_1.past_key_values
]
next_batch = Seq2SeqLMBatch.concatenate([next_batch_0, next_batch_1])
assert next_batch.batch_id == 0
assert torch.equal(
next_batch.decoder_input_ids[0], next_batch_0.decoder_input_ids[0]
)
assert next_batch.all_decoder_input_ids[1][0] == 0
assert next_batch.all_decoder_input_ids[2][0] == 0
assert torch.equal(
next_batch.decoder_input_ids[1:, -2:], next_batch_1.decoder_input_ids
)
assert torch.all(next_batch.decoder_attention_mask[0, :3] == 1)
assert torch.all(next_batch.decoder_attention_mask[0, 3:] == 0)
assert torch.all(next_batch.decoder_attention_mask[1:, 0] == 0)
assert torch.all(next_batch.decoder_attention_mask[1:, 1:3] == 1)
assert torch.equal(
next_batch.encoder_last_hidden_state[0],
next_batch_0_encoder_last_hidden_state[0, -2:],
)
assert torch.equal(
next_batch.encoder_last_hidden_state[1:],
next_batch_1_encoder_last_hidden_state[:, -2:],
)
assert next_batch.input_lengths == [2, 2, 2]
assert next_batch.decoder_input_lengths == [3, 2, 2]
assert next_batch.max_input_length == 2
assert next_batch.max_decoder_input_length == 3
assert next_batch.requests[0] == next_batch_0.requests[0]
assert next_batch.requests[1:] == list(next_batch_1.requests)
assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers
assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0]
assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias
assert next_batch.past_key_values is not None
assert all(
[p[0].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
)
assert all(
[p[1].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
)
assert all(
[p[2].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
)
assert all(
[p[3].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
)
for i, past in enumerate(next_batch.past_key_values):
assert torch.equal(next_batch_0_past_key_values[i][0][0, :, -2:, :], past[0][0])
assert torch.equal(
next_batch_1_past_key_values[i][0][:, :, -1:, :], past[0][1:, :, -1:, :]
)
assert torch.equal(next_batch_0_past_key_values[i][1][0, :, -2:, :], past[1][0])
assert torch.equal(
next_batch_1_past_key_values[i][1][:, :, -1:, :], past[1][1:, :, -1:, :]
)
assert torch.equal(next_batch_0_past_key_values[i][2][0, :, -2:, :], past[2][0])
assert torch.equal(
next_batch_1_past_key_values[i][2][:, :, -2:, :], past[2][1:]
)
assert torch.equal(next_batch_0_past_key_values[i][3][0, :, -2:, :], past[3][0])
assert torch.equal(
next_batch_1_past_key_values[i][3][:, :, -2:, :], past[3][1:]
)
for _ in range(3):
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 3
assert generations[2].generated_text.text == "a few "
assert (
generations[2].request_id
== default_multi_requests_seq2seq_lm_batch.requests[1].id
)
assert generations[2].generated_text.generated_tokens == 5
next_batch = next_batch.filter(
[next_batch.requests[0].id, next_batch.requests[1].id]
)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 2
assert generations[0].generated_text.text == "a few weeks"
assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
assert generations[0].generated_text.generated_tokens == 7
next_batch = next_batch.filter([next_batch.requests[1].id])
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
assert generations[0].generated_text.text == "a few weeks"
assert (
generations[0].request_id
== default_multi_requests_seq2seq_lm_batch.requests[0].id
)
assert generations[0].generated_text.generated_tokens == 7

View File

@ -0,0 +1,247 @@
import pytest
from unittest.mock import Mock
from text_generation_server.utils.adapter import (
get_attn_weights,
get_mlp_weights,
parse_lora_adapters,
AdapterInfo,
)
def test_parse_lora_adapters_empty():
assert parse_lora_adapters(None) == []
assert parse_lora_adapters("") == []
def test_parse_lora_adapters_single():
result = parse_lora_adapters("adapter1")
assert result == [AdapterInfo(id="adapter1", path=None, revision=None)]
def test_parse_lora_adapters_with_path():
result = parse_lora_adapters("adapter1=path/to/adapter1")
assert result == [
AdapterInfo(id="adapter1", path="path/to/adapter1", revision=None)
]
def test_parse_lora_adapters_with_path_and_revision():
result = parse_lora_adapters("adapter1=path/to/adapter1@main")
assert result == [
AdapterInfo(id="adapter1", path="path/to/adapter1", revision="main")
]
def test_parse_lora_adapters_multiple():
result = parse_lora_adapters(
"adapter1,adapter2=path/to/adapter2,adapter3=path/to/adapter3@dev"
)
assert result == [
AdapterInfo(id="adapter1", path=None, revision=None),
AdapterInfo(id="adapter2", path="path/to/adapter2", revision=None),
AdapterInfo(id="adapter3", path="path/to/adapter3", revision="dev"),
]
def test_parse_lora_adapters_invalid_format():
try:
parse_lora_adapters("adapter1,invalid=format=test,adapter3")
assert False, "Should have raised ValueError"
except ValueError as e:
assert str(e) == "Invalid LoRA adapter format: invalid=format=test"
def test_get_attn_weights():
# create a mock layer
mock_layer = Mock()
mock_layer.self_attn.query_key_value = Mock()
mock_layer.self_attn.o_proj = Mock()
# call the function
result = get_attn_weights(2, mock_layer)
# assert the result
expected = {
(2, "q_proj"): (
"model.layers.2.self_attn.q_proj",
mock_layer.self_attn.query_key_value,
),
(2, "k_proj"): (
"model.layers.2.self_attn.k_proj",
mock_layer.self_attn.query_key_value,
),
(2, "qkv_proj"): (
"model.layers.2.self_attn.qkv_proj",
mock_layer.self_attn.query_key_value,
),
(2, "v_proj"): (
"model.layers.2.self_attn.v_proj",
mock_layer.self_attn.query_key_value,
),
(2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj),
}
assert result == expected
def test_get_mlp_weights_with_gate_up_proj():
# create a mock layer with gate_up_proj
mock_layer = Mock()
mock_layer.mlp.gate_up_proj = Mock()
mock_layer.mlp.down_proj = Mock()
# call the function
result = get_mlp_weights(3, mock_layer)
# assert the result
expected = {
(3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
(3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
(3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
}
assert result == expected
def test_get_mlp_weights_without_gate_up_proj():
# create a mock layer without gate_up_proj
mock_layer = Mock()
mock_layer.mlp = Mock(spec=[])
# call the function
result = get_mlp_weights(1, mock_layer)
# assert the result
assert result == {}
@pytest.mark.parametrize("layer_index", [0, 1, 5])
def test_get_attn_weights_different_layers(layer_index):
mock_layer = Mock()
mock_layer.self_attn.query_key_value = Mock()
mock_layer.self_attn.o_proj = Mock()
result = get_attn_weights(layer_index, mock_layer)
for k in ["q", "k", "v"]:
assert (layer_index, f"{k}_proj") in result
assert (
result[(layer_index, f"{k}_proj")][0]
== f"model.layers.{layer_index}.self_attn.{k}_proj"
)
assert (layer_index, "o_proj") in result
assert (
result[(layer_index, "o_proj")][0]
== f"model.layers.{layer_index}.self_attn.o_proj"
)
@pytest.mark.parametrize("layer_index", [0, 1, 5])
def test_get_mlp_weights_different_layers(layer_index):
mock_layer = Mock()
mock_layer.mlp.gate_up_proj = Mock()
mock_layer.mlp.down_proj = Mock()
result = get_mlp_weights(layer_index, mock_layer)
for k in ["gate", "up", "down"]:
assert (layer_index, f"{k}_proj") in result
assert (
result[(layer_index, f"{k}_proj")][0]
== f"model.layers.{layer_index}.mlp.{k}_proj"
)
def test_get_attn_weights_llama_compatibility():
mock_layer = Mock()
mock_layer.self_attn.query_key_value = Mock()
mock_layer.self_attn.o_proj = Mock()
result = get_attn_weights(2, mock_layer)
expected = {
(2, "q_proj"): (
"model.layers.2.self_attn.q_proj",
mock_layer.self_attn.query_key_value,
),
(2, "k_proj"): (
"model.layers.2.self_attn.k_proj",
mock_layer.self_attn.query_key_value,
),
(2, "qkv_proj"): (
"model.layers.2.self_attn.qkv_proj",
mock_layer.self_attn.query_key_value,
),
(2, "v_proj"): (
"model.layers.2.self_attn.v_proj",
mock_layer.self_attn.query_key_value,
),
(2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj),
}
assert result == expected
def test_get_mlp_weights_llama_compatibility():
mock_layer = Mock()
mock_layer.mlp.gate_up_proj = Mock()
mock_layer.mlp.down_proj = Mock()
result = get_mlp_weights(3, mock_layer)
expected = {
(3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
(3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
(3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
}
assert result == expected
def test_get_attn_weights_gemma_compatibility():
mock_layer = Mock()
mock_layer.self_attn.query_key_value = Mock()
mock_layer.self_attn.o_proj = Mock()
result = get_attn_weights(2, mock_layer)
expected = {
(2, "q_proj"): (
"model.layers.2.self_attn.q_proj",
mock_layer.self_attn.query_key_value,
),
(2, "k_proj"): (
"model.layers.2.self_attn.k_proj",
mock_layer.self_attn.query_key_value,
),
(2, "qkv_proj"): (
"model.layers.2.self_attn.qkv_proj",
mock_layer.self_attn.query_key_value,
),
(2, "v_proj"): (
"model.layers.2.self_attn.v_proj",
mock_layer.self_attn.query_key_value,
),
(2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj),
}
assert result == expected
def test_get_mlp_weights_gemma_compatibility():
mock_layer = Mock()
mock_layer.mlp.gate_proj = Mock()
mock_layer.mlp.up_proj = Mock()
mock_layer.mlp.down_proj = Mock()
# ensure that the mock_layer.mlp.gate_up_proj attribute does not exist.
# This is necessary because the use of `Mock` automatically creates any
# attributes that are accessed, even if they don't exist in the actual
# implementation. If `gate_up_proj` were created, `get_mlp_weights` might
# follow the wrong execution path and return an incorrect result.
del mock_layer.mlp.gate_up_proj
result = get_mlp_weights(3, mock_layer)
expected = {
(3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_proj),
(3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.up_proj),
(3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
}
assert result == expected

View File

@ -0,0 +1,21 @@
from text_generation_server.utils.hub import (
download_weights,
weight_hub_files,
weight_files,
)
from text_generation_server.utils.convert import convert_files
def test_convert_files():
model_id = "bigscience/bloom-560m"
pt_filenames = weight_hub_files(model_id, extension=".bin")
local_pt_files = download_weights(pt_filenames, model_id)
local_st_files = [
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files
]
convert_files(local_pt_files, local_st_files, discard_names=[])
found_st_files = weight_files(model_id)
assert all([p in found_st_files for p in local_st_files])

View File

@ -0,0 +1,103 @@
import os
import tempfile
import pytest
import huggingface_hub.constants
import text_generation_server.utils.hub
from text_generation_server.utils.hub import (
weight_hub_files,
download_weights,
weight_files,
EntryNotFoundError,
LocalEntryNotFoundError,
RevisionNotFoundError,
)
@pytest.fixture()
def offline():
current_value = text_generation_server.utils.hub.HF_HUB_OFFLINE
text_generation_server.utils.hub.HF_HUB_OFFLINE = True
yield "offline"
text_generation_server.utils.hub.HF_HUB_OFFLINE = current_value
@pytest.fixture()
def fresh_cache():
with tempfile.TemporaryDirectory() as d:
current_value = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE
huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = d
text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = d
os.environ["HUGGINGFACE_HUB_CACHE"] = d
yield
huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = current_value
os.environ["HUGGINGFACE_HUB_CACHE"] = current_value
text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = current_value
@pytest.fixture()
def prefetched():
model_id = "bert-base-uncased"
huggingface_hub.snapshot_download(
repo_id=model_id,
revision="main",
local_files_only=False,
repo_type="model",
allow_patterns=["*.safetensors"],
)
yield model_id
def test_weight_hub_files_offline_error(offline, fresh_cache):
# If the model is not prefetched then it will raise an error
with pytest.raises(EntryNotFoundError):
weight_hub_files("gpt2")
def test_weight_hub_files_offline_ok(prefetched, offline):
# If the model is prefetched then we should be able to get the weight files from local cache
filenames = weight_hub_files(prefetched)
root = None
assert len(filenames) == 1
for f in filenames:
curroot, filename = os.path.split(f)
if root is None:
root = curroot
else:
assert root == curroot
assert filename == "model.safetensors"
def test_weight_hub_files():
filenames = weight_hub_files("bigscience/bloom-560m")
assert filenames == ["model.safetensors"]
def test_weight_hub_files_llm():
filenames = weight_hub_files("bigscience/bloom")
assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)]
def test_weight_hub_files_empty():
with pytest.raises(EntryNotFoundError):
weight_hub_files("bigscience/bloom", extension=".errors")
def test_download_weights():
model_id = "bigscience/bloom-560m"
filenames = weight_hub_files(model_id)
files = download_weights(filenames, model_id)
local_files = weight_files("bigscience/bloom-560m")
assert files == local_files
def test_weight_files_revision_error():
with pytest.raises(RevisionNotFoundError):
weight_files("bigscience/bloom-560m", revision="error")
def test_weight_files_not_cached_error(fresh_cache):
with pytest.raises(LocalEntryNotFoundError):
weight_files("bert-base-uncased")

View File

@ -0,0 +1,82 @@
import torch
from text_generation_server.layers import (
TensorParallelEmbedding,
)
class ProcessGroup:
def __init__(self, rank: int, world_size: int):
self._rank = rank
self.world_size = world_size
def size(self) -> int:
return self.world_size
def rank(self) -> int:
return self._rank
class Weights:
def __init__(self, rank: int, world_size: int, vocab_size: int, hidden_dim: int):
self.weight = (
torch.arange(vocab_size * hidden_dim).float().view(vocab_size, hidden_dim)
)
self.process_group = ProcessGroup(rank, world_size)
def get_partial_sharded(self, name: str, dim: int):
assert dim == 0
rank = self.process_group.rank()
world_size = self.process_group.size()
size = self.weight.shape[dim]
block_size = (size + world_size - 1) // world_size
start = rank * block_size
stop = (rank + 1) * block_size
return self.weight[start:stop]
def get_shape(self, name: str):
return self.weight.shape
def test_weight_hub_files_offline_error():
vocab_size = 17
weights = Weights(
rank=0,
world_size=1,
vocab_size=vocab_size,
hidden_dim=256,
)
embeddings = TensorParallelEmbedding("", weights)
input_ids = torch.arange(vocab_size)
output = embeddings.forward(input_ids)
assert embeddings.min_id == 0
assert embeddings.max_id == 17
torch.testing.assert_close(output, torch.arange(256 * 17).float().view(17, 256))
weights_0_2 = Weights(rank=0, world_size=2, vocab_size=vocab_size, hidden_dim=256)
weights_1_2 = Weights(rank=1, world_size=2, vocab_size=vocab_size, hidden_dim=256)
embeddings_0_2 = TensorParallelEmbedding("", weights_0_2, reduce=False)
assert embeddings_0_2.min_id == 0
assert embeddings_0_2.max_id == 9
torch.testing.assert_close(
embeddings_0_2.weight,
torch.cat([torch.arange(9 * 256), torch.zeros(256)], dim=0)
.view(10, 256)
.float(),
)
embeddings_1_2 = TensorParallelEmbedding("", weights_1_2, reduce=False)
assert embeddings_1_2.min_id == 9
assert embeddings_1_2.max_id == 17
torch.testing.assert_close(
embeddings_1_2.weight,
torch.cat([torch.arange(8 * 256) + 9 * 256, torch.zeros(256)], dim=0)
.view(9, 256)
.float(),
)
output_tp_0 = embeddings_0_2.forward(input_ids)
output_tp_1 = embeddings_1_2.forward(input_ids)
torch.testing.assert_close(output, output_tp_0 + output_tp_1)

View File

@ -0,0 +1,88 @@
import torch
from text_generation_server.utils.tokens import (
StopSequenceCriteria,
StoppingCriteria,
FinishReason,
batch_top_tokens,
)
def test_stop_sequence_criteria():
criteria = StopSequenceCriteria("/test;")
assert not criteria("/")
assert not criteria("/test")
assert criteria("/test;")
assert not criteria("/test; ")
def test_stop_sequence_criteria_escape():
criteria = StopSequenceCriteria("<|stop|>")
assert not criteria("<")
assert not criteria("<|stop")
assert criteria("<|stop|>")
assert not criteria("<|stop|> ")
def test_stopping_criteria():
criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
assert criteria(65827, "/test") == (False, None)
assert criteria(30, ";") == (True, FinishReason.FINISH_REASON_STOP_SEQUENCE)
def test_stopping_criteria_eos():
criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
assert criteria(1, "") == (False, None)
assert criteria(0, "") == (True, FinishReason.FINISH_REASON_EOS_TOKEN)
def test_stopping_criteria_max():
criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
assert criteria(1, "") == (False, None)
assert criteria(1, "") == (False, None)
assert criteria(1, "") == (False, None)
assert criteria(1, "") == (False, None)
assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
def test_batch_top_tokens():
top_n_tokens = [0, 2, 3, 4, 5]
top_n_tokens_tensor = torch.tensor(top_n_tokens)
inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5)
accepted_ids = torch.ones_like(top_n_tokens_tensor)
topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids
)
assert topn_tok_ids[0] == [[]]
assert topn_tok_ids[1] == [[0, 3]]
assert topn_tok_ids[2] == [[0, 3, 1, 4]]
assert topn_tok_ids[3] == [[0, 3, 1, 4]]
assert topn_tok_ids[4] == [[0, 3, 1, 4, 2]]
assert topn_tok_logprobs[0] == [[]]
assert topn_tok_logprobs[1] == [[-1, -2]]
assert topn_tok_logprobs[2] == [[-1, -2, -3, -3]]
assert topn_tok_logprobs[3] == [[-1, -2, -3, -3]]
assert topn_tok_logprobs[4] == [[-1, -2, -3, -3, -4]]
# Now let's make second member of the batch be speculated
inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5 * 2)
accepted_ids[1] = 2
topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids
)
assert topn_tok_ids[0] == [[]]
assert topn_tok_ids[1] == [[0, 3], [0, 3]]
assert topn_tok_ids[2] == [[0, 3, 1, 4]]
assert topn_tok_ids[3] == [[0, 3, 1, 4]]
assert topn_tok_ids[4] == [[0, 3, 1, 4, 2]]
assert topn_tok_logprobs[0] == [[]]
assert topn_tok_logprobs[1] == [[-1, -2], [-1, -2]]
assert topn_tok_logprobs[2] == [[-1, -2, -3, -3]]
assert topn_tok_logprobs[3] == [[-1, -2, -3, -3]]
assert topn_tok_logprobs[4] == [[-1, -2, -3, -3, -4]]

View File

@ -0,0 +1,54 @@
# test_watermark_logits_processor.py
import os
import numpy as np
import torch
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
GAMMA = os.getenv("WATERMARK_GAMMA", 0.5)
DELTA = os.getenv("WATERMARK_DELTA", 2.0)
def test_seed_rng():
input_ids = [101, 2036, 3731, 102, 2003, 103]
processor = WatermarkLogitsProcessor()
processor._seed_rng(input_ids)
assert isinstance(processor.rng, torch.Generator)
def test_get_greenlist_ids():
input_ids = [101, 2036, 3731, 102, 2003, 103]
processor = WatermarkLogitsProcessor()
result = processor._get_greenlist_ids(input_ids, 10, torch.device("cpu"))
assert max(result) <= 10
assert len(result) == int(10 * 0.5)
def test_calc_greenlist_mask():
processor = WatermarkLogitsProcessor()
scores = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]])
greenlist_token_ids = torch.tensor([2, 3])
result = processor._calc_greenlist_mask(scores, greenlist_token_ids)
assert result.tolist() == [[False, False, False, False], [False, False, True, True]]
assert result.shape == scores.shape
def test_bias_greenlist_logits():
processor = WatermarkLogitsProcessor()
scores = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]])
green_tokens_mask = torch.tensor(
[[False, False, True, True], [False, False, False, True]]
)
greenlist_bias = 2.0
result = processor._bias_greenlist_logits(scores, green_tokens_mask, greenlist_bias)
assert np.allclose(result.tolist(), [[0.5, 0.3, 2.2, 2.8], [0.1, 0.2, 0.7, 2.9]])
assert result.shape == scores.shape
def test_call():
input_ids = [101, 2036, 3731, 102, 2003, 103]
processor = WatermarkLogitsProcessor()
scores = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]])
result = processor(input_ids, scores)
assert result.shape == scores.shape

File diff suppressed because it is too large Load Diff

View File

@ -1,4 +1,6 @@
import os
import psutil
import signal
import sys
import typer
@ -14,10 +16,15 @@ app = typer.Typer()
class Quantization(str, Enum):
bitsandbytes = "bitsandbytes"
bitsandbytes_nf4 = "bitsandbytes-nf4"
bitsandbytes_fp4 = "bitsandbytes-fp4"
gptq = "gptq"
awq = "awq"
eetq = "eetq"
exl2 = "exl2"
fp8 = "fp8"
compressed_tensors = "compressed-tensors"
marlin = "marlin"
class Dtype(str, Enum):
@ -25,11 +32,6 @@ class Dtype(str, Enum):
bloat16 = "bfloat16"
class KVCacheDtype(str, Enum):
fp8_e4m3fn = "fp8_e4m3fn"
fp8_e5m2 = "fp8_e5m2"
@app.command()
def serve(
model_id: str,
@ -38,7 +40,6 @@ def serve(
quantize: Optional[Quantization] = None,
speculate: Optional[int] = None,
dtype: Optional[Dtype] = None,
kv_cache_dtype: Optional[KVCacheDtype] = None,
trust_remote_code: bool = False,
uds_path: Path = "/tmp/text-generation-server",
logger_level: str = "INFO",
@ -98,34 +99,89 @@ def serve(
# Downgrade enum into str for easier management later on
quantize = None if quantize is None else quantize.value
dtype = "bfloat16" if dtype is None else dtype.value
kv_cache_dtype = None if kv_cache_dtype is None else kv_cache_dtype.value
logger.info(f"quantize={quantize} kv_cache_dtype={kv_cache_dtype}")
logger.info(f"quantize={quantize}")
if dtype is not None and quantize not in {
None,
"bitsandbytes",
"bitsandbytes-nf4",
"bitsandbytes-fp4",
"gptq",
"awq",
"fp8",
"compressed-tensors",
}:
raise RuntimeError(
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
)
server.serve(
model_id,
lora_adapters,
revision,
sharded,
quantize,
speculate,
dtype,
kv_cache_dtype,
trust_remote_code,
uds_path,
max_input_tokens,
)
logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype))
if sharded:
tgi_file = Path(__file__).resolve().parent / "tgi_service.py"
num_shard = int(os.getenv("WORLD_SIZE", "1"))
logger.info("CLI SHARDED = {}".format(num_shard))
import subprocess
cmd = (
f"deepspeed --num_nodes 1 --num_gpus {num_shard} --no_local_rank {tgi_file}"
)
cmd += f" --model_id {model_id} --revision {revision} --sharded {sharded}"
cmd += f" --dtype {dtype} --trust_remote_code {trust_remote_code} --uds_path {uds_path}"
cmd += f" --quantize {quantize} --max_input_tokens {max_input_tokens}"
if speculate is not None:
cmd += f"--speculate {speculate}"
logger.info("CLI server start deepspeed ={} ".format(cmd))
sys.stdout.flush()
sys.stderr.flush()
with subprocess.Popen(cmd, shell=True, executable="/bin/bash") as proc:
do_terminate = False
current_handler = signal.getsignal(signal.SIGTERM)
def terminate_handler(sig, frame):
nonlocal do_terminate
do_terminate = True
if callable(current_handler):
current_handler(sig, frame)
signal.signal(signal.SIGTERM, terminate_handler)
finished = False
while not finished:
try:
if do_terminate:
parent = psutil.Process(proc.pid)
all_procs = parent.children(recursive=True) + [parent]
for p in all_procs:
try:
p.terminate()
except psutil.NoSuchProcess:
pass
_, alive = psutil.wait_procs(all_procs, timeout=30)
for p in alive:
p.kill()
do_terminate = False
proc.wait(timeout=3)
except subprocess.TimeoutExpired:
pass
else:
finished = True
sys.stdout.flush()
sys.stderr.flush()
if proc.returncode != 0:
logger.error(f"{cmd} exited with status = {proc.returncode}")
return proc.returncode
else:
server.serve(
model_id,
lora_adapters,
revision,
sharded,
quantize,
speculate,
dtype,
trust_remote_code,
uds_path,
max_input_tokens,
)
@app.command()

View File

@ -0,0 +1,53 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
import os
import habana_frameworks.torch as htorch
quant_config = os.getenv("QUANT_CONFIG", "")
is_quantization_enabled = quant_config != ""
if is_quantization_enabled:
os.environ.setdefault("ENABLE_EXPERIMENTAL_FLAGS", "true")
os.environ.setdefault("USE_DEFAULT_QUANT_PARAM", "true")
os.environ.setdefault("UPDATE_GRAPH_OUTPUT_MME", "false")
os.environ.setdefault("ENABLE_CALC_DYNAMIC_RANGE", "false")
os.environ.setdefault("UPDATE_MME_OUTPUT_PRECISION_FILTER", "v_proj,matmul_av")
os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE")
def patch_scoped_linear_all_reduce(model):
from deepspeed.module_inject.layers import LinearAllreduce
from optimum.habana.transformers.models.modeling_all_models import (
ScopedLinearAllReduce,
)
for name, module in model.named_children():
if type(module) is LinearAllreduce:
SL = ScopedLinearAllReduce(mod=module)
setattr(model, name, SL)
patch_scoped_linear_all_reduce(module)
def setup_quantization(model):
if is_quantization_enabled:
htorch.core.quantization._mark_params_as_const(model)
htorch.core.quantization._check_params_as_const(model)
htorch.core.hpu_initialize(model)
return model
def prepare_model_for_quantization(model):
if is_quantization_enabled:
if model.config.model_type in [
"llama",
"falcon",
"qwen2",
"starcoder2",
"gemma",
]:
patch_scoped_linear_all_reduce(model)
from neural_compressor.torch.quantization import FP8Config, convert
config = FP8Config.from_json_file(quant_config)
model = convert(model, config)
return model

View File

@ -12,7 +12,6 @@ from text_generation_server.layers.speculative import SpeculativeHead
# Just to add the `load` methods.
from text_generation_server.layers.layernorm import load_layer_norm
from text_generation_server.layers.conv import load_conv2d
from text_generation_server.layers.fp8 import Fp8Linear
from text_generation_server.layers.lora import (
LoraLinear,
@ -28,7 +27,6 @@ __all__ = [
"TensorParallelEmbedding",
"SpeculativeHead",
"LoraLinear",
"Fp8Linear",
"TensorParallelMultiAdapterLinear",
"TensorParallelAdapterRowLinear",
"load_layer_norm",

View File

@ -1,35 +1,43 @@
from .common import (
Seqlen,
HPUPagedAttentionMetadata,
trim_attn_metadata,
trim_seqlen_metadata,
_async_h2d_tensor_copy,
)
from text_generation_server.utils.import_utils import SYSTEM
import os
from .hpu import (
SUPPORTS_WINDOWING,
attention,
paged_attention,
paged_attention_mla,
set_block_mapping,
)
from .common import Seqlen
if os.getenv("USE_FLASH_ATTENTION", "true").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.")
if SYSTEM == "cuda":
from .cuda import (
attention,
paged_attention,
reshape_and_cache,
SUPPORTS_WINDOWING,
PREFILL_IN_KV_CACHE,
)
elif SYSTEM == "rocm":
from .rocm import (
attention,
paged_attention,
reshape_and_cache,
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
)
elif SYSTEM == "ipex":
from .ipex import (
attention,
paged_attention,
reshape_and_cache,
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
)
else:
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
from .kv_cache import KVCache, get_kv_scales, KVCompressCache
__all__ = [
"attention",
"get_kv_scales",
"paged_attention",
"paged_attention_mla",
"set_block_mapping",
"reshape_and_cache",
"PREFILL_IN_KV_CACHE",
"SUPPORTS_WINDOWING",
"KVCache",
"KVCompressCache",
"Seqlen",
"HPUPagedAttentionMetadata",
"trim_seqlen_metadata",
"trim_attn_metadata",
"_async_h2d_tensor_copy",
]

View File

@ -1,186 +1,72 @@
from dataclasses import dataclass
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import ATTENTION
import torch
from typing import Optional, List, Dict
import collections
import torch.nn.functional as F
_TYPE_CACHE = {}
from typing import Optional
@dataclass
class HPUPagedAttentionMetadata:
"""Metadata for PagedAttention."""
if ATTENTION in {"flashinfer", "flashdecoding"}:
block_list: Optional[torch.Tensor]
block_mapping: Optional[torch.Tensor]
block_usage: Optional[torch.Tensor]
block_groups: Optional[torch.Tensor]
attn_bias: Optional[torch.Tensor]
slots_in_window_mask: Optional[torch.Tensor] = None
block_list_in_window: Optional[torch.Tensor] = None
block_mapping_in_window: Optional[torch.Tensor] = None
block_usage_in_window: Optional[torch.Tensor] = None
block_groups_in_window: Optional[torch.Tensor] = None
attn_bias_in_window: Optional[torch.Tensor] = None
@dataclass
class Seqlen:
input_lengths: torch.Tensor
prefix_lengths: torch.Tensor
cu_seqlen_q: Optional[torch.Tensor]
cu_seqlen_k: Optional[torch.Tensor]
max_q: int
max_k: int
def subtuple(
obj: object,
typename: str,
to_copy: List[str],
to_override: Optional[Dict[str, object]] = None,
):
if obj is None:
return None
if to_override is None:
to_override = {}
fields = set(to_copy) | set(to_override.keys())
if isinstance(obj, dict):
values = {key: obj[key] for key in fields if key in obj}
else:
values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
if typename not in _TYPE_CACHE:
_TYPE_CACHE[typename] = collections.namedtuple(typename, " ".join(fields))
return _TYPE_CACHE[typename](**values)
def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object:
# NOTE(kzawora): To anyone working on this in the future:
# Trimming metadata is required when using HPUGraphs.
# Attention metadata is going to be hashed by PT bridge, and
# appropriate HPUGraphs will be matched based on all inputs' hash.
# Before you put more keys in here, make sure you know their
# value type and make sure you know how it's going to be hashed.
# You can find that information in input_hash function
# in habana_frameworks/torch/hpu/graphs.py. You can also hash
# it manually with torch.hpu.graphs.input_hash(attention_metadata)
# If you use primitive types here - they will get hashed based
# on their value. You *will* get lots of excessive graph captures
# (and an OOM eventually) if you decide to put something like
# seq_len int here.
# If you absolutely need a scalar, put it in a tensor. Tensors
# get hashed using their metadata, not their values:
# input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
# input_hash(123) != input_hash(321)
# input_hash("abc") != input_hash("cba")
attention_metadata = subtuple(
metadata,
"TrimmedAttentionMetadata",
[
"block_list",
"block_mapping",
"block_usage",
"block_groups",
"attn_bias",
"slots_in_window_mask",
"block_list_in_window",
"block_mapping_in_window",
"block_usage_in_window",
"block_groups_in_window",
"attn_bias_in_window",
],
)
return attention_metadata
@dataclass
class Seqlen:
input_lengths: torch.Tensor
attn_mask: Optional[torch.Tensor] = None
def __init__(
self,
input_lengths,
):
self.input_lengths = input_lengths
def clamp(self, max):
# Flash decoding doesn't need to clamp
return self
def make_sliding_window_bias(
self,
seq_lens: List[int],
window_size: Optional[int],
dtype: torch.dtype,
padded_input_len: Optional[int],
padded_bs: Optional[int],
) -> List[torch.Tensor]:
attn_biases = []
for seq_len in seq_lens:
if seq_len != 0:
tensor = torch.full(
(1, seq_len, seq_len),
dtype=dtype,
fill_value=1,
)
shift = 0
mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore
if window_size is not None:
mask = torch.triu(mask, diagonal=shift - window_size + 1)
mask = F.pad(
mask,
(
padded_input_len - seq_len,
0,
padded_input_len - seq_len,
0,
0,
0,
),
value=0,
def __init__(
self,
input_lengths,
prefix_lengths,
cu_seqlen_q=None,
max_q=None,
max_k=None,
):
self.input_lengths = input_lengths
self.prefix_lengths = prefix_lengths
device = self.input_lengths.device
shape = self.input_lengths.shape
if cu_seqlen_q is None:
cu_seqlen_q = torch.arange(
shape[0] + 1,
device=device,
dtype=torch.int32,
)
max_q = 1
else:
mask = torch.full(
(1, padded_input_len, padded_input_len),
dtype=dtype,
fill_value=0,
)
attn_biases.append(mask)
attn_biases = torch.stack(attn_biases, dim=0)
return attn_biases.to(torch.bool)
assert max_q is not None
assert max_k is not None
cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
# cuda graphs don't like this and this is necessary to clamp within mistral
# Although FA2 might not want the clamping
# cu_seqlen_k[0] = 0
total = self.input_lengths + self.prefix_lengths
torch.cumsum(total, -1, out=cu_seqlen_k[1:])
def _async_h2d_tensor_copy(source, device="hpu"):
if source is None:
return None
if source.device.type == "hpu":
return source
assert source.device.type == "cpu", "Source tensor is not present in host memory!"
target = torch.empty(source.shape, dtype=source.dtype, device=device)
target.copy_(source, non_blocking=True)
return target
self.cu_seqlen_q = cu_seqlen_q
self.cu_seqlen_k = cu_seqlen_k
self.max_q = max_q
self.max_k = max_k
def clamp(self, max):
# Flash decoding doesn't need to clamp
return self
def trim_seqlen_metadata(metadata: Seqlen) -> object:
# NOTE(kzawora): To anyone working on this in the future:
# Trimming metadata is required when using HPUGraphs.
# Attention metadata is going to be hashed by PT bridge, and
# appropriate HPUGraphs will be matched based on all inputs' hash.
else:
# Before you put more keys in here, make sure you know their
# value type and make sure you know how it's going to be hashed.
# You can find that information in input_hash function
# in habana_frameworks/torch/hpu/graphs.py. You can also hash
# it manually with torch.hpu.graphs.input_hash(attention_metadata)
@dataclass
class Seqlen:
input_lengths: torch.Tensor
prefix_lengths: torch.Tensor
cu_seqlen_q: torch.Tensor
max_q: int
max_k: int
# If you use primitive types here - they will get hashed based
# on their value. You *will* get lots of excessive graph captures
# (and an OOM eventually) if you decide to put something like
# seq_len int here.
# If you absolutely need a scalar, put it in a tensor. Tensors
# get hashed using their metadata, not their values:
# input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
# input_hash(123) != input_hash(321)
# input_hash("abc") != input_hash("cba")
attention_metadata = subtuple(
metadata,
"TrimmedSeqlen",
[
"input_lengths",
"attn_mask",
],
)
return attention_metadata
def clamp(self, max):
if SYSTEM == "rocm":
return self
raise NotImplementedError("Not implemented seqlen for paged")
return Seqlen(torch.clamp(self.input_lengths, max=max))

View File

@ -0,0 +1,357 @@
import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import (
ATTENTION,
BLOCK_SIZE,
)
from text_generation_server.layers.attention import Seqlen
from typing import Optional
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
_PARTITION_SIZE = 512
try:
from vllm._C import cache_ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
)
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
):
if ATTENTION in {"flashdecoding", "flashinfer"}:
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0
)
def paged_attention(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
seqlen: Seqlen,
max_s: int,
softcap: Optional[float] = None,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
# Copyright 2023 The vLLM team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# value_cache => [num_blocks, num_heads, head_size, block_size]
# block_size = value_cache.shape[3]
block_size = BLOCK_SIZE
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
if ATTENTION == "flashinfer":
from text_generation_server.layers.attention.flashinfer import decode_state
return decode_state.get().forward(
query.contiguous(),
paged_kv_cache=(key_cache, value_cache),
logits_soft_cap=softcap,
sm_scale=softmax_scale,
)
elif ATTENTION == "flashdecoding":
max_q = 1
max_k = max_s
import flash_attn_2_cuda
# TODO fixme when flash contains the fix.
# Number of splits is not correctly handled
# by the current path
# https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577
# This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied.
if softcap is None:
softcap = 0.0
out = flash_attn_2_cuda.varlen_fwd(
query,
key_cache,
value_cache,
None,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
None, # pad_k
None,
block_tables,
None,
max_q,
max_k,
0.0, # dropout
softmax_scale,
False, # zero_tensors
True, # causal
-1, # Window_left
-1, # Window right
softcap,
False, # return softmax
None, # generator
)
return out[0]
else:
if softcap is not None:
raise RuntimeError("Paged attention doesn't support softcapping")
input_lengths = seqlen.input_lengths
from vllm._C import ops
out = torch.empty_like(query)
use_v1 = max_s <= 8192 and (
max_num_partitions == 1 or num_seqs * num_heads > 512
)
if use_v1:
ops.paged_attention_v1(
out,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=out.dtype,
device=out.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=out.device,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_v2(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
return out
try:
is_ampere_or_newer = major >= 8 and minor >= 0
if not is_ampere_or_newer:
raise ImportError("FlashAttention only supports Ampere GPUs or newer.")
import flash_attn_2_cuda
V2 = True
except ImportError:
try:
import flash_attn_cuda
V2 = False
except ImportError as e:
if major >= 8:
architecture_suffix = f"-{SYSTEM}"
raise ImportError(
"Flash Attention V2 is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
)
elif is_sm75:
raise ImportError(
"Flash Attention is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
"or install flash attention with `cd server && make install install-flash-attention`"
) from e
else:
raise ImportError(
f"GPU with CUDA capability {major} {minor} is not supported"
) from e
SUPPORTS_WINDOWING = V2
if ATTENTION == "flashinfer":
def attention(
q: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap=0.0,
):
from text_generation_server.layers.attention.flashinfer import (
prefill_with_paged_kv_state,
)
return prefill_with_paged_kv_state.get().forward(
q.contiguous(),
causal=causal,
paged_kv_cache=(key_cache, value_cache),
logits_soft_cap=softcap,
sm_scale=softmax_scale,
window_left=window_size_left,
)
elif V2:
def attention(
q,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap=0.0,
):
out = torch.empty_like(q)
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
return flash_attn_2_cuda.varlen_fwd(
q,
key_cache,
value_cache,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
None,
None,
block_tables,
None,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,
causal,
window_size_left,
0,
softcap,
False,
None,
)[0]
else:
def attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap=None,
):
if window_size_left != -1:
raise NotImplementedError(
"window_size_left is only available with flash attn v2"
)
if softcap is not None:
raise NotImplementedError("softcap is only available with flash attn v2")
# Flash attention v1 requires q, k and v to have the same number of heads
if k.shape[1] != q.shape[1]:
# MQA expand
if k.shape[1] == 1:
k = k.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = k.shape
k = (
k.unsqueeze(2)
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
if v.shape[1] != q.shape[1]:
# MQA expand
if v.shape[1] == 1:
v = v.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = v.shape
v = (
v.unsqueeze(2)
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
out = torch.empty_like(q)
flash_attn_cuda.fwd(
q,
k,
v,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,
causal,
False,
0,
None,
)
return out
# Prefill in the cache with every kind of attention, unless we
# have a configuration that requires flash-attention v1, which
# does not support block tables.
PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2

View File

@ -0,0 +1,813 @@
#!/usr/bin/env python
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
(https://tridao.me/publications/flash2/flash2.pdf)
Credits: OpenAI kernel team, AMD ML Frameworks Triton team
Features supported:
1) Fwd with causal masking
2) Any sequence lengths without padding (currently fwd kernel only)
3) Support for different sequence lengths for q and k
4) Nested tensor API currently does not support dropout or bias.
Not currently supported:
1) Non power of two head dims
"""
import torch
import triton
import triton.language as tl
torch_dtype: tl.constexpr = torch.float16
@triton.jit
def cdiv_fn(x, y):
return (x + y - 1) // y
@triton.jit
def max_fn(x, y):
return tl.math.max(x, y)
@triton.jit
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
ms = tl.arange(0, m)
ns = tl.arange(0, n)
return philox_offset + ms[:, None] * stride + ns[None, :]
@triton.jit
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
rng_offsets = dropout_offsets(
philox_seed, philox_offset, dropout_p, m, n, stride
).to(tl.uint32)
# TODO: use tl.randint for better performance
return tl.rand(philox_seed, rng_offsets)
@triton.jit
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride)
rng_keep = rng_output > dropout_p
return rng_keep
@triton.jit
def load_fn(block_ptr, first, second, pad):
if first and second:
tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)
elif first:
tensor = tl.load(block_ptr, boundary_check=(0,), padding_option=pad)
elif second:
tensor = tl.load(block_ptr, boundary_check=(1,), padding_option=pad)
else:
tensor = tl.load(block_ptr)
return tensor
@triton.jit
def _attn_fwd_inner(
acc,
l_i,
m_i,
q,
K_block_ptr,
V_block_ptr,
start_m,
actual_seqlen_k,
dropout_p,
philox_seed,
batch_philox_offset,
encoded_softmax_block_ptr,
block_min,
block_max,
offs_n_causal,
masked_blocks,
n_extra_tokens,
bias_ptr,
IS_CAUSAL: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
OFFS_M: tl.constexpr,
OFFS_N: tl.constexpr,
PRE_LOAD_V: tl.constexpr,
MASK_STEPS: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr,
RETURN_ENCODED_SOFTMAX: tl.constexpr,
PADDED_HEAD: tl.constexpr,
):
# loop over k, v, and update accumulator
for start_n in range(block_min, block_max, BLOCK_N):
# For padded blocks, we will overrun the tensor size if
# we load all BLOCK_N. For others, the blocks are all within range.
k = load_fn(
K_block_ptr,
PADDED_HEAD,
MASK_STEPS and (n_extra_tokens != 0),
"zero",
)
if PRE_LOAD_V:
v = load_fn(
V_block_ptr,
MASK_STEPS and (n_extra_tokens != 0),
PADDED_HEAD,
"zero",
)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
# We start from end of seqlen_k so only the first iteration would need
# to be checked for padding if it is not a multiple of block_n
# TODO: This can be optimized to only be true for the padded block.
if MASK_STEPS: # noqa: SIM102
# If this is the last block / iteration, we want to
# mask if the sequence length is not a multiple of block size
# a solution is to always do BLOCK_M // BLOCK_N + 1 steps
# if not is_modulo_mn. last step might get wasted but that is okay.
# check if this masking works for that case.
if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32)
size_n = start_n + OFFS_N[None, :]
mask = size_n < boundary_m[:, None]
qk = tl.where(mask, qk, float("-inf"))
if IS_CAUSAL:
causal_boundary = start_n + offs_n_causal
causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
qk = tl.where(causal_mask, qk, float("-inf"))
# -- compute qk ----
qk += tl.dot(q, k)
if bias_ptr is not None:
bias = load_fn(
bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero"
)
# While bias is added after multiplying qk with sm_scale, our
# optimization to use 2^x instead of e^x results in an additional
# scale factor of log2(e) which we must also multiply the bias with.
qk += bias * 1.44269504089
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk = qk - m_ij[:, None]
p = tl.math.exp2(qk)
# CAVEAT: Must update l_ij before applying dropout
l_ij = tl.sum(p, 1)
if ENABLE_DROPOUT:
philox_offset = (
batch_philox_offset
+ start_m * BLOCK_M * actual_seqlen_k
+ start_n
- BLOCK_N
)
keep = dropout_mask(
philox_seed,
philox_offset,
dropout_p,
BLOCK_M,
BLOCK_N,
actual_seqlen_k,
)
if RETURN_ENCODED_SOFTMAX:
tl.store(
encoded_softmax_block_ptr,
tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty),
)
p = tl.where(keep, p, 0.0)
elif RETURN_ENCODED_SOFTMAX:
tl.store(
encoded_softmax_block_ptr,
p.to(encoded_softmax_block_ptr.type.element_ty),
)
# -- update output accumulator --
alpha = tl.math.exp2(m_i - m_ij)
acc = acc * alpha[:, None]
if not PRE_LOAD_V:
v = load_fn(
V_block_ptr,
MASK_STEPS and (n_extra_tokens != 0),
PADDED_HEAD,
"zero",
)
# -- update m_i and l_i
l_i = l_i * alpha + l_ij
# update m_i and l_i
m_i = m_ij
acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
if bias_ptr is not None:
bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
if RETURN_ENCODED_SOFTMAX:
encoded_softmax_block_ptr = tl.advance(
encoded_softmax_block_ptr, (0, BLOCK_N)
)
return acc, l_i, m_i
@triton.autotune(
configs=[
triton.Config(
{
"BLOCK_M": 256,
"BLOCK_N": 64,
"waves_per_eu": 2,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=8,
),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 128,
"waves_per_eu": 2,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=4,
),
triton.Config(
{
"BLOCK_M": 256,
"BLOCK_N": 128,
"waves_per_eu": 2,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=8,
),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 64,
"waves_per_eu": 3,
"PRE_LOAD_V": True,
},
num_stages=1,
num_warps=4,
),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 64,
"waves_per_eu": 3,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=4,
),
triton.Config(
{
"BLOCK_M": 64,
"BLOCK_N": 64,
"waves_per_eu": 4,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=8,
),
triton.Config(
{
"BLOCK_M": 32,
"BLOCK_N": 32,
"waves_per_eu": 4,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=8,
),
# TODO: This config fails with head_size not pow2 with data mismatches.
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
# 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
triton.Config(
{
"BLOCK_M": 16,
"BLOCK_N": 16,
"waves_per_eu": 1,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=4,
),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 64,
"waves_per_eu": 1,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=4,
),
],
key=["IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"],
)
@triton.jit
def attn_fwd(
Q,
K,
V,
bias,
sm_scale,
L,
Out,
stride_qz,
stride_qh,
stride_qm,
stride_qk,
stride_kz,
stride_kh,
stride_kn,
stride_kk,
stride_vz,
stride_vh,
stride_vk,
stride_vn,
stride_oz,
stride_oh,
stride_om,
stride_on,
stride_bz,
stride_bh,
stride_bm,
stride_bn,
cu_seqlens_q,
cu_seqlens_k,
dropout_p,
philox_seed,
philox_offset_base,
encoded_softmax,
HQ: tl.constexpr,
HK: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr,
MAX_SEQLENS_Q: tl.constexpr,
MAX_SEQLENS_K: tl.constexpr,
VARLEN: tl.constexpr,
IS_CAUSAL: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
PRE_LOAD_V: tl.constexpr,
BIAS_TYPE: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr,
RETURN_ENCODED_SOFTMAX: tl.constexpr,
):
start_m = tl.program_id(0)
off_h_q = tl.program_id(1)
off_z = tl.program_id(2)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
if VARLEN:
cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
# We have a one-size-fits-all grid in id(0). Some seqlens might be too
# small for all start_m so for those we return early.
if start_m * BLOCK_M > seqlen_q:
return
cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
else:
cu_seqlens_q_start = 0
cu_seqlens_k_start = 0
seqlen_q = MAX_SEQLENS_Q
seqlen_k = MAX_SEQLENS_K
# Now we compute whether we need to exit early due to causal masking.
# This is because for seqlen_q > seqlen_k, M rows of the attn scores
# are completely masked, resulting in 0s written to the output, and
# inf written to LSE. We don't need to do any GEMMs in this case.
# This block of code determines what N is, and if this WG is operating
# on those M rows.
n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
if IS_CAUSAL:
# If seqlen_q == seqlen_k, the attn scores are a square matrix.
# If seqlen_q != seqlen_k, attn scores are rectangular which means
# the causal mask boundary is bottom right aligned, and ends at either
# the top edge (seqlen_q < seqlen_k) or left edge.
# This captures the decrease in n_blocks if we have a rectangular attn
# matrix
n_blocks_seqlen = cdiv_fn(
(start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N
)
# This is what adjusts the block_max for the current WG, only
# if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
n_blocks = min(n_blocks, n_blocks_seqlen)
# If we have no blocks after adjusting for seqlen deltas, this WG is
# part of the blocks that are all 0. We exit early.
if n_blocks <= 0:
o_offset = (
off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh
)
O_block_ptr = tl.make_block_ptr(
base=Out + o_offset,
shape=(seqlen_q, BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
# We still need to write 0s to the result
# tl.store(O_block_ptr,
# acc.to(Out.type.element_ty), boundary_check=(0,1))
# l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
# + offs_m
# We store inf to LSE, not -inf because in the bwd pass,
# we subtract this
# from qk which makes it -inf, such that exp(qk - inf) = 0
# for these masked blocks.
# l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
# tl.store(l_ptrs, l)
# TODO: Should dropout and return encoded softmax be handled here?
return
# If MQA / GQA, set the K and V head offsets appropriately.
GROUP_SIZE: tl.constexpr = HQ // HK
if GROUP_SIZE != 1:
off_h_k = off_h_q // GROUP_SIZE
else:
off_h_k = off_h_q
n_extra_tokens = 0
if seqlen_k < BLOCK_N:
n_extra_tokens = BLOCK_N - seqlen_k
elif seqlen_k % BLOCK_N:
n_extra_tokens = seqlen_k % BLOCK_N
PADDED_HEAD: tl.constexpr = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL
# Compute pointers for all the tensors used in this kernel.
q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm
Q_block_ptr = tl.make_block_ptr(
base=Q + q_offset,
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn
K_block_ptr = tl.make_block_ptr(
base=K + k_offset,
shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1),
)
v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk
V_block_ptr = tl.make_block_ptr(
base=V + v_offset,
shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),
strides=(stride_vk, stride_vn),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0),
)
if BIAS_TYPE != 0:
bias_ptr = tl.make_block_ptr(
base=bias + off_h_q * stride_bh,
shape=(seqlen_q, seqlen_k),
strides=(stride_bm, stride_bn),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_N),
order=(1, 0),
)
else:
bias_ptr = None
if ENABLE_DROPOUT:
batch_philox_offset = (
philox_offset_base + (off_z * HQ + off_h_q) * seqlen_q * seqlen_k
)
else:
batch_philox_offset = 0
# We can ask to return the dropout mask without actually doing any dropout.
# In this case, we return an invalid pointer so indicate the mask is not i
# valid.
# TODO: Fix encoded softmax. It currently uses just h_q in the base offset.
if RETURN_ENCODED_SOFTMAX:
encoded_softmax_block_ptr = tl.make_block_ptr(
base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,
shape=(seqlen_q, seqlen_k),
strides=(seqlen_k, 1),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_N),
order=(1, 0),
)
else:
encoded_softmax_block_ptr = 0
# initialize pointer to m and l
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# scale sm_scale by log_2(e) and use 2^x in the loop as we do not
# have native e^x support in HW.
qk_scale = sm_scale * 1.44269504089
# Q is loaded once at the beginning and shared by all N blocks.
q = load_fn(Q_block_ptr, True, PADDED_HEAD, "zero")
q = (q * qk_scale).to(Q_block_ptr.type.element_ty)
# Here we compute how many full and masked blocks we have.
padded_block_k = n_extra_tokens != 0
is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
if IS_CAUSAL:
# There are always at least BLOCK_M // BLOCK_N masked blocks.
# Additionally there might be one more due to dissimilar seqlens.
masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
else:
# Padding on Q does not need to be masked in the FA loop.
masked_blocks = padded_block_k
# if IS_CAUSAL, not is_modulo_mn does not always result in an additional
# block. In this case we might exceed n_blocks so pick the min.
masked_blocks = min(masked_blocks, n_blocks)
n_full_blocks = n_blocks - masked_blocks
block_min = 0
block_max = n_blocks * BLOCK_N
# Compute for full blocks. Here we set causal to false regardless of its
# value because there is no masking. Similarly we do not need padding.
if n_full_blocks > 0:
block_max = (n_blocks - masked_blocks) * BLOCK_N
acc, l_i, m_i = _attn_fwd_inner(
acc,
l_i,
m_i,
q,
K_block_ptr,
V_block_ptr,
start_m,
seqlen_k,
dropout_p,
philox_seed,
batch_philox_offset,
encoded_softmax_block_ptr,
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
block_min,
block_max,
0,
0,
0,
bias_ptr,
# IS_CAUSAL, ....
False,
BLOCK_M,
BLOCK_DMODEL,
BLOCK_N,
offs_m,
offs_n,
# _, MASK_STEPS, ...
PRE_LOAD_V,
False,
ENABLE_DROPOUT,
RETURN_ENCODED_SOFTMAX,
PADDED_HEAD,
)
block_min = block_max
block_max = n_blocks * BLOCK_N
tl.debug_barrier()
# Remaining blocks, if any, are full / not masked.
if masked_blocks > 0:
offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0
K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0))
if bias_ptr is not None:
bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))
if RETURN_ENCODED_SOFTMAX:
encoded_softmax_block_ptr = tl.advance(
encoded_softmax_block_ptr, (0, n_full_blocks)
)
acc, l_i, m_i = _attn_fwd_inner(
acc,
l_i,
m_i,
q,
K_block_ptr,
V_block_ptr,
start_m,
seqlen_k,
dropout_p,
philox_seed,
batch_philox_offset,
encoded_softmax_block_ptr,
block_min,
block_max,
offs_n_causal,
masked_blocks,
n_extra_tokens,
bias_ptr,
IS_CAUSAL,
BLOCK_M,
BLOCK_DMODEL,
BLOCK_N,
offs_m,
offs_n,
# _, MASK_STEPS, ...
PRE_LOAD_V,
True,
ENABLE_DROPOUT,
RETURN_ENCODED_SOFTMAX,
PADDED_HEAD,
)
# epilogue
acc = acc / l_i[:, None]
if ENABLE_DROPOUT:
acc = acc / (1 - dropout_p)
# If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
# then we have one block with a row of all NaNs which come from computing
# softmax over a row of all -infs (-inf - inf = NaN). We check for that here
# and store 0s where there are NaNs as these rows should've been zeroed out.
end_m_idx = (start_m + 1) * BLOCK_M
start_m_idx = start_m * BLOCK_M
causal_start_idx = seqlen_q - seqlen_k
acc = acc.to(Out.type.element_ty)
if IS_CAUSAL: # noqa: SIM102
if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
out_mask_boundary = tl.full(
(BLOCK_DMODEL,), causal_start_idx, dtype=tl.int32
)
mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :]
z = 0.0
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
# write back LSE
# l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last
# few rows. This is only true for the last M block. For others,
# overflow_size will be -ve
# overflow_size = end_m_idx - seqlen_q
# if overflow_size > 0:
# boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
# # This is a > check because mask being 0 blocks the store.
# l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
# tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
# else:
# tl.store(l_ptrs, m_i + tl.math.log2(l_i))
# write back O
o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh
O_block_ptr = tl.make_block_ptr(
base=Out + o_offset,
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
# Need boundary check on this to make sure the padding from the
# Q and KV tensors in both dims are not part of what we store back.
# TODO: Do the boundary check optionally.
tl.store(O_block_ptr, acc, boundary_check=(0, 1))
def check_args(
q,
k,
v,
o,
varlen=True,
max_seqlens=None,
cu_seqlens_q=None,
cu_seqlens_k=None,
):
assert q.dim() == k.dim() and q.dim() == v.dim()
if varlen:
assert q.dim() == 3
total_q, nheads_q, head_size = q.shape
total_k, nheads_k, _ = k.shape
assert cu_seqlens_q is not None
assert cu_seqlens_k is not None
assert len(cu_seqlens_q) == len(cu_seqlens_k)
else:
assert q.dim() == 4
batch, nheads_q, seqlen_q, head_size = q.shape
_, nheads_k, seqlen_k, _ = k.shape
assert max_seqlens > 0
assert k.shape == v.shape
assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
# TODO: Change assert if we support qkl f8 and v f16
assert q.dtype == k.dtype and q.dtype == v.dtype
# TODO: Fix assert to check head size <=256 once supported
assert head_size <= 128
assert o.shape == q.shape
assert (nheads_q % nheads_k) == 0
class _attention(torch.autograd.Function):
@staticmethod
def forward(
ctx,
q,
k,
v,
o,
cu_seqlens_q,
cu_seqlens_k,
max_seqlens_q,
max_seqlens_k,
causal=False,
sm_scale=1.0,
bias=None,
):
if o is None:
o = torch.empty_like(q, dtype=v.dtype)
check_args(
q,
k,
v,
o,
varlen=True,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
)
if True: # varlen
total_q, nheads_q, head_size = q.shape
total_k, nheads_k, _ = k.shape
batch = len(cu_seqlens_q) - 1
q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
else:
batch, seqlen_q, nheads_q, head_size = q.shape
_, seqlen_k, nheads_k, _ = k.shape
q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
# Get closest power of 2 over or equal to 32.
padded_d_model = 1 << (head_size - 1).bit_length()
padded_d_model = max(padded_d_model, 16)
def grid(META):
return triton.cdiv(max_seqlens_q, META["BLOCK_M"]), nheads_q, batch
encoded_softmax = None
# Seed the RNG so we get reproducible results for testing.
philox_seed = 0x1BF52
philox_offset = 0x1D4B42
if bias is not None:
bias_strides = (
bias.stride(0),
bias.stride(1),
bias.stride(2),
bias.stride(3),
)
else:
bias_strides = (0, 0, 0, 0)
attn_fwd[grid](
q,
k,
v,
bias,
sm_scale,
None,
o,
*q_strides,
*k_strides,
*v_strides,
*o_strides,
*bias_strides,
cu_seqlens_q,
cu_seqlens_k,
dropout_p=0.0,
philox_seed=philox_seed,
philox_offset_base=philox_offset,
encoded_softmax=encoded_softmax,
HQ=nheads_q,
HK=nheads_k,
ACTUAL_BLOCK_DMODEL=head_size,
MAX_SEQLENS_Q=max_seqlens_q,
MAX_SEQLENS_K=max_seqlens_k,
IS_CAUSAL=causal,
VARLEN=True,
BLOCK_DMODEL=padded_d_model,
BIAS_TYPE=0 if bias is None else 1,
ENABLE_DROPOUT=False,
RETURN_ENCODED_SOFTMAX=False,
)
ctx.grid = grid
ctx.sm_scale = sm_scale
ctx.BLOCK_DMODEL = head_size
ctx.causal = causal
ctx.dropout_p = 0.0
ctx.philox_seed = philox_seed
ctx.philox_offset = philox_offset
ctx.encoded_softmax = encoded_softmax
ctx.return_encoded_softmax = False
return o, encoded_softmax
triton_attention = _attention.apply

View File

@ -0,0 +1,251 @@
from typing import Optional
from contextvars import ContextVar
from contextlib import contextmanager
import flashinfer
import torch
prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = ContextVar(
"prefill_state"
)
prefill_with_paged_kv_state: ContextVar[
flashinfer.BatchPrefillWithPagedKVCacheWrapper
] = ContextVar("prefill_with_paged_kv_state")
decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar(
"decode_state"
)
workspace: Optional[torch.Tensor] = None
def get_workspace(device):
"""Get shared flashinfer workspace."""
global workspace
if workspace is None:
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
return workspace
def create_prefill_with_paged_kv_state(
*,
device: torch.device,
):
"""Create a prefill state that uses the KV cache."""
workspace_buffer = get_workspace(device)
return flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, kv_layout="NHD", use_cuda_graph=False
)
@contextmanager
def use_prefill_with_paged_kv_state(
*,
state: flashinfer.BatchPrefillWithPagedKVCacheWrapper,
block_tables: torch.Tensor,
cu_seqlens: torch.Tensor,
input_lengths: torch.Tensor,
num_heads: int,
num_kv_heads: int,
head_size: int,
page_size: int,
dtype: torch.dtype,
window_left: int,
):
"""
Context manager to set the active flashinfer prefill state to the given
`state` and parameters. This state will be used by all calls to the
`attention` function while the context manager is active.
"""
indptr = torch.zeros(
input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32
)
# Round up to page size and then calculate the cumulative sum to get
# the indices into the block table.
torch.add(input_lengths, page_size - 1, out=indptr[1:])
indptr[1:].div_(page_size, rounding_mode="floor")
indptr[1:].cumsum_(-1)
# Get the lengths of the last page in a block.
if page_size == 1:
last_page_len = torch.ones(
input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
)
else:
last_page_len = torch.empty(
input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
)
torch.sub(input_lengths, 1, out=last_page_len)
last_page_len.remainder_(page_size)
last_page_len += 1
token = prefill_with_paged_kv_state.set(state)
try:
state.begin_forward(
qo_indptr=cu_seqlens,
paged_kv_indptr=indptr,
paged_kv_indices=block_tables,
paged_kv_last_page_len=last_page_len,
num_qo_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_size,
q_data_type=dtype,
page_size=page_size,
window_left=window_left,
)
yield
finally:
state.end_forward()
if token is not None:
prefill_with_paged_kv_state.reset(token)
def create_prefill_state(
*,
device: torch.device,
):
"""Create a prefill state."""
workspace_buffer = get_workspace(device)
return flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
workspace_buffer, kv_layout="NHD", use_cuda_graph=False
)
@contextmanager
def use_prefill_state(
*,
state: flashinfer.BatchPrefillWithRaggedKVCacheWrapper,
cu_seqlens: torch.Tensor,
num_heads: int,
num_kv_heads: int,
head_size: int,
dtype: torch.dtype,
window_left: int,
):
"""
Context manager to set the active flashinfer prefill state to the given
`state` and parameters. This state will be used by all calls to the
`attention` function while the context manager is active.
"""
token = prefill_state.set(state)
try:
state.begin_forward(
qo_indptr=cu_seqlens,
kv_indptr=cu_seqlens,
num_qo_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_size,
q_data_type=dtype,
window_left=window_left,
)
yield
finally:
state.end_forward()
if token is not None:
prefill_state.reset(token)
def create_decode_state(
*,
device: torch.device,
num_heads: int,
num_kv_heads: int,
):
"""Create a decode state."""
workspace_buffer = get_workspace(device)
num_groups = num_heads // num_kv_heads
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout="NHD",
use_cuda_graph=False,
# Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
use_tensor_cores=num_groups not in [1, 2, 4, 8],
)
def create_decode_state_cuda_graphs(
*,
device: torch.device,
block_tables: torch.Tensor,
block_tables_ptr: torch.Tensor,
last_page_len: torch.Tensor,
num_heads: int,
num_kv_heads: int,
):
"""
Create a decode state for use with CUDA Graphs. `block_tables`,
`block_tables_ptr`, and `last_page_len` are used in CUDA Graphs and are
therefore stored as part of the state.
"""
workspace_buffer = get_workspace(device)
num_groups = num_heads // num_kv_heads
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout="NHD",
use_cuda_graph=True,
paged_kv_indices_buffer=block_tables,
paged_kv_indptr_buffer=block_tables_ptr,
paged_kv_last_page_len_buffer=last_page_len,
# Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
use_tensor_cores=num_groups not in [1, 2, 4, 8],
)
@contextmanager
def use_decode_state(
*,
state: flashinfer.BatchDecodeWithPagedKVCacheWrapper,
input_lengths: torch.Tensor,
block_tables: torch.Tensor,
num_heads: int,
num_kv_heads: int,
head_size: int,
page_size: int,
dtype: torch.dtype,
window_left: int,
):
"""
Context manager to set the active flashinfer decoding state to the given
`state` and parameters. This state will be used by all calls to the
`paged_attention` function while the context manager is active.
"""
indptr = torch.zeros(
input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32
)
# Round up to page size and then calculate the cumulative sum to get
# the indices into the block table.
torch.add(input_lengths, page_size - 1, out=indptr[1:])
indptr[1:].div_(page_size, rounding_mode="floor")
indptr[1:].cumsum_(-1)
# Get the lengths of the last page in a block.
last_page_len = torch.empty(
input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
)
torch.sub(input_lengths, 1, out=last_page_len)
last_page_len.remainder_(page_size)
last_page_len += 1
token = decode_state.set(state)
try:
state.begin_forward(
indptr=indptr,
indices=block_tables,
last_page_len=last_page_len,
num_qo_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_size,
page_size=page_size,
data_type=dtype,
q_data_type=dtype,
window_left=window_left,
)
yield
finally:
state.end_forward()
if token is not None:
decode_state.reset(token)

View File

@ -1,227 +0,0 @@
import torch
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
from typing import Optional
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from vllm_hpu_extension import ops
from vllm_hpu_extension.utils import Matmul
from habana_frameworks.torch.hpex.kernels import FusedSDPA
from vllm_hpu_extension.utils import ModuleFusedSDPA
import os
from text_generation_server.models.globals import BLOCK_SIZE
import math
SUPPORTS_WINDOWING = False
class FP8Matmul(torch.nn.Module):
def __init__(self, scale_other):
super().__init__()
self.scale_input = torch.tensor(1.0, dtype=torch.bfloat16, device="hpu")
self.scale_other = scale_other
def quant_input(self, x, scale):
return torch.ops.hpu.cast_to_fp8_v2(
x, scale, False, False, torch.float8_e4m3fn
)[0]
def matmul_fp8(
self, x, other, out_dtype, scale_input_inv=None, scale_other_inv=None
):
return torch.ops.hpu.fp8_gemm_v2(
A=x,
trans_A=False,
B=other,
trans_B=False,
D=None,
out_dtype=out_dtype,
A_scale_inv=scale_input_inv,
B_scale_inv=scale_other_inv,
bias=None,
accumulate=False,
)
def forward(self, input, other):
qinput = self.quant_input(input, self.scale_input)
qother = self.quant_input(other, self.scale_other)
output = self.matmul_fp8(
qinput,
qother,
out_dtype=torch.bfloat16,
scale_input_inv=1.0 / self.scale_input,
scale_other_inv=1.0 / self.scale_other,
)
return output
class FetchFromCache(torch.nn.Module):
def __init__(self, scale_inv):
super().__init__()
self.scale_inv = scale_inv
def forward(self, cache, blocks):
if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
out = cache[: blocks.size(0)]
else:
out = cache.index_select(0, blocks)
if out.dtype == torch.float8_e4m3fn:
out = torch.ops.hpu.cast_from_fp8(out, self.scale_inv, torch.bfloat16)
return out
def attention(
*,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: KVCache,
kv_scales: KVScales,
seqlen: Seqlen,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap: Optional[float] = None,
):
fsdpa_op = ModuleFusedSDPA(FusedSDPA)
bs = seqlen.input_lengths.shape[0]
_, head_num, head_size = query.shape
_, kv_head_num, head_size = key.shape
query = query.view(bs, -1, head_num, head_size).transpose(1, 2)
key = key.view(bs, -1, kv_head_num, head_size).transpose(1, 2)
value = value.view(bs, -1, kv_head_num, head_size).transpose(1, 2)
attn_output = fsdpa_op(
query,
key,
value,
attn_mask=seqlen.attn_mask if window_size_left != -1 else None,
dropout_p=0.0,
is_causal=causal if window_size_left == -1 else False,
scale=softmax_scale,
softmax_mode="None",
recompute_mode=None,
valid_sequence_lengths=seqlen.input_lengths if window_size_left == -1 else None,
padding_side="left",
)
attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
return attn_output
def set_block_mapping(hpu_attention_meta: HPUPagedAttentionMetadata, batch_size):
block_mapping = torch.nn.functional.one_hot(
hpu_attention_meta.block_groups, num_classes=batch_size
)
dtype = hpu_attention_meta.block_usage.dtype
device = hpu_attention_meta.block_usage.device
mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
mask = mask >= hpu_attention_meta.block_usage.unsqueeze(-1)
attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
hpu_attention_meta = hpu_attention_meta._replace(
attn_bias=attn_bias, block_mapping=block_mapping.to(dtype)
)
if hpu_attention_meta.block_groups_in_window is not None:
block_mapping = torch.nn.functional.one_hot(
hpu_attention_meta.block_groups_in_window, num_classes=batch_size
)
attn_bias = torch.log(hpu_attention_meta.slots_in_window_mask.float())
hpu_attention_meta = hpu_attention_meta._replace(
attn_bias_in_window=attn_bias,
block_mapping_in_window=block_mapping.to(dtype),
)
return hpu_attention_meta
def paged_attention(
query: torch.Tensor,
kv_cache: KVCache,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
seqlen: Seqlen,
*,
kv_scales: KVScales,
softcap: Optional[float] = None,
hpu_attention_meta: HPUPagedAttentionMetadata,
window_size_left: int = -1,
):
batch_size, head_num, head_size = query.shape
fp8_kv = kv_cache.dtype == torch.float8_e4m3fn
output = ops.flat_pa(
query=query.view(batch_size, 1, head_num * head_size),
key_cache=kv_cache.key,
value_cache=kv_cache.value,
block_list=(
hpu_attention_meta.block_list
if window_size_left == -1
else hpu_attention_meta.block_list_in_window
),
block_mapping=(
hpu_attention_meta.block_mapping
if window_size_left == -1
else hpu_attention_meta.block_mapping_in_window
),
block_bias=(
hpu_attention_meta.attn_bias
if window_size_left == -1
else hpu_attention_meta.attn_bias_in_window
),
block_groups=(
hpu_attention_meta.block_groups
if window_size_left == -1
else hpu_attention_meta.block_groups_in_window
),
block_size=BLOCK_SIZE,
scale=softmax_scale,
matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
batch2block_matmul_op=Matmul(),
block2batch_matmul_op=Matmul(),
keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu),
values_fetch_func=FetchFromCache(1.0 / kv_scales.value_scale_cpu),
)
# Reshape the output tensor.
return output.view(batch_size, head_num, head_size)
def paged_attention_mla(
query: torch.Tensor,
kv_cache: KVCache,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
seqlen: Seqlen,
*,
kv_scales: KVScales,
softcap: Optional[float] = None,
hpu_attention_meta: HPUPagedAttentionMetadata,
kv_lora_rank: int = 0,
):
batch_size, head_num, head_size = query.shape
fp8_kv = kv_cache.dtype == torch.float8_e4m3fn
output = ops.flat_pa_mla(
query=query,
key_cache=kv_cache.key,
value_cache=None,
block_list=hpu_attention_meta.block_list,
block_mapping=hpu_attention_meta.block_mapping,
block_bias=hpu_attention_meta.attn_bias,
block_groups=hpu_attention_meta.block_groups,
block_size=BLOCK_SIZE,
scale=softmax_scale,
matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
batch2block_matmul_op=Matmul(),
block2batch_matmul_op=Matmul(),
keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu),
values_fetch_func=None,
kv_lora_rank=kv_lora_rank,
)
# Reshape the output tensor.
return output.view(batch_size, head_num, -1)
__all__ = [
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",
"paged_attention_mla",
"set_block_mapping",
]

View File

@ -0,0 +1,82 @@
import intel_extension_for_pytorch as ipex
import torch
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen
from typing import Optional
SUPPORTS_WINDOWING = False
PREFILL_IN_KV_CACHE = False
def attention(
q: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap: Optional[float] = None,
):
out = torch.empty_like(q)
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
ipex.llm.functional.varlen_attention(
q.contiguous() if q.device.type == "xpu" else q,
key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
seqlen.max_q,
seqlen.max_q,
0.0,
softmax_scale,
False,
causal,
False,
None,
)
return out
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
):
ipex.llm.modules.PagedAttention.reshape_and_cache(
key, value, key_cache, value_cache, slots
)
def paged_attention(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
seqlen: Seqlen,
max_s: int,
softcap: Optional[float] = None,
):
out = torch.empty_like(query)
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
out,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
seqlen.input_lengths,
BLOCK_SIZE,
max_s,
None,
)
return out

View File

@ -1,205 +0,0 @@
from typing import Tuple
from dataclasses import dataclass, field
import torch
from text_generation_server.models.globals import BLOCK_SIZE
from text_generation_server.utils.weights import Weights
@dataclass
class KVScales:
"""
Key-value scales for FP8 KV cache.
This data class stores key and value scales both as a GPU tensor and
as a GPU float. This inconvenience is necessary because some functions
(e.g. scaling kernels) take scales as a GPU tensor, whereas others
(e.g. flashinfer) take scales as a CPU scalar.
"""
key_scale: torch.Tensor
value_scale: torch.Tensor
key_scale_cpu: float = field(init=False)
value_scale_cpu: float = field(init=False)
def __post_init__(self):
if self.key_scale.numel() != 1 or self.value_scale.numel() != 1:
raise ValueError("Key and value scales must be scalar tensors.")
self.key_scale_cpu = self.key_scale.item()
self.value_scale_cpu = self.value_scale.item()
class KVCache:
"""
Key-value cache for attention layers.
"""
kv_cache: Tuple[torch.Tensor, torch.Tensor]
def __init__(
self,
*,
num_blocks: int,
num_heads: int,
head_size: int,
dtype: torch.dtype,
device: torch.device,
):
"""Construct the key-value cache for a layer."""
## TODO FP8 kv cache support
if dtype is torch.float8_e5m2:
raise ValueError("torch.float8_e5m2 is not supported in hpu. ")
self.kv_cache = (
torch.zeros(
(num_blocks * BLOCK_SIZE, num_heads, head_size),
dtype=dtype,
device=device,
),
torch.zeros(
(num_blocks * BLOCK_SIZE, num_heads, head_size),
dtype=dtype,
device=device,
),
)
@property
def dtype(self):
"""Get the data type of the cache."""
return self.kv_cache[0].dtype
@property
def key(self):
"""Get the key cache."""
return self.kv_cache[0]
@property
def value(self):
"""Get the value cache."""
return self.kv_cache[1]
def store(
self,
*,
key: torch.Tensor,
value: torch.Tensor,
slots: torch.Tensor,
kv_scales: KVScales,
):
"""Store the key and value at the given slots."""
## TODO FP8 kv cache support
key_cache = self.kv_cache[0]
value_cache = self.kv_cache[1]
paged_reshape_and_cache(
key,
value,
key_cache,
value_cache,
slots,
kv_scales.key_scale,
kv_scales.value_scale,
)
class KVCompressCache(KVCache):
"""
Key-value cache for attention layers.
"""
kv_cache: torch.Tensor
def __init__(
self,
*,
num_blocks: int,
head_size: int,
dtype: torch.dtype,
device: torch.device,
):
"""Construct the key-value cache for a layer."""
## TODO FP8 kv cache support
if dtype is torch.float8_e5m2:
raise ValueError("torch.float8_e5m2 is not supported in hpu. ")
self.kv_cache = torch.zeros(
(num_blocks * BLOCK_SIZE, 1, head_size),
dtype=dtype,
device=device,
)
@property
def dtype(self):
"""Get the data type of the cache."""
return self.kv_cache.dtype
@property
def key(self):
"""Get the key cache."""
return self.kv_cache
@property
def value(self):
"""Get the value cache."""
return self.kv_cache
def store(
self,
*,
key: torch.Tensor,
value: torch.Tensor,
slots: torch.Tensor,
kv_scales: KVScales,
):
"""Store the key and value at the given slots."""
## TODO FP8 kv cache support
if self.kv_cache.dtype == torch.float8_e4m3fn:
key = torch.ops.hpu.cast_to_fp8_v2(
key, kv_scales.key_scale, False, False, torch.float8_e4m3fn
)[0]
self.kv_cache.index_copy_(0, slots, key)
def paged_reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
):
if key_cache.dtype == torch.float8_e4m3fn:
key = torch.ops.hpu.cast_to_fp8_v2(
key, k_scale, False, False, torch.float8_e4m3fn
)[0]
value = torch.ops.hpu.cast_to_fp8_v2(
value, v_scale, False, False, torch.float8_e4m3fn
)[0]
key_cache.index_copy_(0, slots, key)
value_cache.index_copy_(0, slots, value)
def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
"""Load KV cache scales."""
key_scale = torch.tensor(1.0, dtype=torch.float32, device=weights.device)
value_scale = key_scale
if weights.has_tensor(f"{prefix}.k_scale") and weights.has_tensor(
f"{prefix}.v_scale"
):
key_scale = weights.get_tensor(f"{prefix}.k_scale", to_dtype=False).float()
value_scale = weights.get_tensor(f"{prefix}.v_scale", to_dtype=False).float()
elif weights.has_tensor(f"{prefix}.kv_scale"):
# Fall back to older more coarse-grained scale when available.
key_scale = weights.get_tensor(f"{prefix}.kv_scale").float()
value_scale = key_scale
return KVScales(key_scale=key_scale, value_scale=value_scale)

View File

@ -0,0 +1,308 @@
import os
from typing import Optional
import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import ATTENTION
from text_generation_server.layers.attention import Seqlen
from text_generation_server.utils.log import log_master
from loguru import logger
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
_PARTITION_SIZE_V1V2 = 512
_PARTITION_SIZE_CUSTOM = 256
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
ENGINE = "triton" if use_triton else "ck"
PREFILL_IN_KV_CACHE = False
use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"
try:
if use_rocm_custom_paged_attn:
from vllm._custom_C import paged_attention_custom
except ImportError as e:
log_master(
logger.info,
f"Custom Paged Attention not available. Complete error: {e}",
)
use_rocm_custom_paged_attn = False
try:
import vllm._custom_ops as ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
)
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
):
if ATTENTION == "flashdecoding":
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
def paged_attention(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
seqlen: Seqlen,
max_s: int,
softcap: Optional[float] = None,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
# Copyright 2023 The vLLM team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
if softcap is not None:
raise RuntimeError("Paged attention doesn't support softcapping")
# value_cache => [num_blocks, num_heads, head_size, block_size]
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
num_kv_heads = key_cache.shape[1]
gqa_ratio = num_heads // num_kv_heads
use_custom = (
use_rocm_custom_paged_attn
and (query.dtype == torch.half or query.dtype == torch.bfloat16)
and (head_size == 128 or head_size == 64)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_s <= 32768
)
if not use_custom:
_PARTITION_SIZE = _PARTITION_SIZE_V1V2
else:
_PARTITION_SIZE = _PARTITION_SIZE_CUSTOM
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
input_lengths = seqlen.input_lengths
out = torch.empty_like(query)
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
import vllm._custom_ops as ops
use_v1 = (
max_s <= 8192
and (max_num_partitions == 1 or num_seqs * num_heads > 512)
and not use_custom
)
if use_v1:
ops.paged_attention_v1(
out,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=out.dtype,
device=out.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=out.device,
)
max_logits = torch.empty_like(exp_sums)
if not use_custom:
ops.paged_attention_v2(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
else:
paged_attention_custom(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
)
return out
if ENGINE != "triton":
try:
import flash_attn_2_cuda
log_master(
logger.info,
"ROCm: using Flash Attention 2 Composable Kernel implementation.",
)
except ImportError as e:
if major >= 8:
architecture_suffix = f"-{SYSTEM}"
raise ImportError(
"Flash Attention V2 is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
)
elif is_sm75:
raise ImportError(
"Flash Attention is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
"or install flash attention with `cd server && make install install-flash-attention`"
) from e
else:
for idx in range(torch.cuda.device_count()):
name = torch.cuda.get_device_name(idx)
if "MI210" not in name and "MI250" not in name:
raise ImportError(
f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
)
raise ImportError(
f"AMD GPU with ROCm capability {major} {minor} is not supported"
) from e
SUPPORTS_WINDOWING = False
if ENGINE == "ck":
def attention(
q,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap: float = 0.0,
):
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
out = torch.empty_like(q)
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
return flash_attn_2_cuda.varlen_fwd(
q,
key_cache,
value_cache,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
None,
None,
None,
None,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,
causal,
window_size_left,
0,
softcap,
False,
None,
)[0]
elif ENGINE == "triton":
from .flash_attn_triton import triton_attention
def attention(
q,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap: Optional[float] = None,
):
if softcap is not None:
raise NotImplementedError("softcap is only available with CK flash attn")
out = torch.empty_like(q)
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
output, _ = triton_attention(
q,
key_cache,
value_cache,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
seqlen.max_q,
seqlen.max_k,
causal,
softmax_scale,
)
return output
else:
raise RuntimeError(f"Unknown attention engine {ENGINE}")

View File

@ -1,3 +0,0 @@
from .hpu import WQLinear
__all__ = ["WQLinear"]

View File

@ -1,134 +0,0 @@
from typing import Optional
import torch
import torch.nn as nn
try:
import habana_frameworks.torch.hpu # noqa: F401
convert_from_uint4 = torch.ops.hpu.convert_from_uint4
except Exception as e:
hpu_import_exception = e
def error_raiser_hpu(*args, **kwargs):
raise ValueError(
f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}"
)
convert_from_uint4 = error_raiser_hpu
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):
shifts = torch.arange(0, 32, bits, device=qzeros.device)
# unpacking columnwise
iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
torch.int8 # smallest dtype available
)
iweights = iweights.view(iweights.shape[0], -1)
# unpacking columnwise
if qzeros is not None:
izeros = torch.bitwise_right_shift(
qzeros[:, :, None], shifts[None, None, :]
).to(
torch.int8 # smallest dtype available
)
izeros = izeros.view(izeros.shape[0], -1)
else:
izeros = qzeros
return iweights, izeros
def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
reverse_order_tensor = torch.arange(
iweights.shape[-1],
dtype=torch.int32,
device=izeros.device,
)
reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
reverse_order_tensor = reverse_order_tensor.view(-1)
if izeros is not None:
izeros = izeros[:, reverse_order_tensor]
iweights = iweights[:, reverse_order_tensor]
return iweights, izeros
def unpack_weight_and_zeros(qweight, qzeros, bits):
# Unpack the qweight and qzeros tensors
iweight, izeros = unpack_awq(qweight, qzeros, bits)
# Reverse the order of the iweight and izeros tensors
iweight, izeros = reverse_awq_order(iweight, izeros, bits)
# overflow checks
iweight = torch.bitwise_and(iweight, (2**bits) - 1)
izeros = torch.bitwise_and(izeros, (2**bits) - 1)
return iweight, izeros
def pack_tensor(input, bits=4):
normal = input.to(torch.int32)
q = torch.zeros(
(normal.shape[0], normal.shape[1] // 32 * bits),
dtype=torch.int32,
device=input.device,
)
i = 0
col = 0
while col < q.shape[1]:
for j in range(i, i + (32 // bits)):
q[:, col] |= normal[:, j] << (bits * (j - i))
i += 32 // bits
col += 1
q = q.to(torch.int32)
return q
class WQLinear(nn.Module):
def __init__(
self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor]
):
super().__init__()
if w_bit not in [4]:
raise NotImplementedError("Only 4-bit are supported for now.")
self.in_features = qweight.shape[0]
self.out_features = qweight.shape[1] * 32 // w_bit
self.w_bit = w_bit
self.group_size = group_size if group_size != -1 else self.in_features
# quick sanity check (make sure aligment)
assert self.in_features % self.group_size == 0
assert self.out_features % (32 // self.w_bit) == 0
self.qweight = qweight
self.qzeros = qzeros
self.scales = scales
self.bias = bias
self._preprocessing()
def _preprocessing(self):
device = self.qweight.device
weight, zeros = unpack_weight_and_zeros(
self.qweight.cpu(), self.qzeros.cpu(), self.w_bit
)
self.qweight = pack_tensor(weight).to(device)
self.qzeros = pack_tensor(zeros).to(device)
@torch.no_grad()
def forward(self, x):
out_shape = x.shape[:-1] + (self.out_features,)
x = x.reshape(-1, x.shape[-1])
weights = convert_from_uint4(self.qweight, self.scales, self.qzeros, x.dtype)
outputs = torch.matmul(x, weights)
outputs = outputs + self.bias if self.bias is not None else outputs
outputs = outputs.reshape(out_shape)
return outputs

View File

@ -0,0 +1,49 @@
# Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py
from typing import Optional
import torch
import torch.nn as nn
import awq_inference_engine # with CUDA kernels
# class ScaledActivation(nn.Module):
# def __init__(self, module, scales):
# super().__init__()
# self.act = module
# self.scales = nn.Parameter(scales.data)
#
# def forward(self, x):
# return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
class WQLinear(nn.Module):
def __init__(
self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor]
):
super().__init__()
if w_bit not in [4]:
raise NotImplementedError("Only 4-bit are supported for now.")
self.in_features = qweight.shape[0]
self.out_features = qweight.shape[1] * 32 // w_bit
self.w_bit = w_bit
self.group_size = group_size if group_size != -1 else self.in_features
# quick sanity check (make sure aligment)
assert self.in_features % self.group_size == 0
assert self.out_features % (32 // self.w_bit) == 0
self.qweight = qweight
self.qzeros = qzeros
self.scales = scales
self.bias = bias
@torch.no_grad()
def forward(self, x):
out_shape = x.shape[:-1] + (self.out_features,)
out = awq_inference_engine.gemm_forward_cuda(
x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
)
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)

View File

@ -1,3 +0,0 @@
from .loader import CompressedTensorsLoader
__all__ = ["CompressedTensorsLoader"]

View File

@ -1,169 +0,0 @@
from typing import Any, Dict, List, Union
from compressed_tensors import QuantizationConfig, QuantizationStatus
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization import (
QuantizationScheme,
QuantizationType,
find_name_or_class_matches,
)
from loguru import logger
from pydantic import ValidationError
from torch import nn
from text_generation_server.layers.compressed_tensors.w8an_fp import W8ANFpLoader
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
UnquantizedWeight,
Weights,
WeightsLoader,
)
# compressed-tensors can match modules as quantization targets. However,
# they need to be objects rather than classes or class names. Since we
# need to match `Linear` targets, make an instance that can be re-used.
_EMPTY_LINEAR: nn.Module = nn.Linear(0, 0)
class CompressedTensorsLoader(WeightsLoader):
"""Loader for checkpoints stored in the compressed-tensors format."""
def __init__(self, config: Dict[str, Any]):
quantization_config_raw = config.get("quantization_config")
if quantization_config_raw is None:
# `compression_config` was renamed to `quantization_config`; support
# retained for backward compatibility.
quantization_config_raw = config.get("compression_config")
if quantization_config_raw is None:
raise ValueError(
"Checkpoint does not have compressed-tensors configuration"
)
try:
quantization_config = QuantizationConfig.model_validate(
quantization_config_raw
)
except ValidationError as e:
raise ValueError("Cannot parse compressed-tensors configuration") from e
if quantization_config.quantization_status not in (
QuantizationStatus.COMPRESSED,
QuantizationStatus.FROZEN,
):
raise ValueError(
f"Model quantization was not finished, status was: {quantization_config.quantization_status}"
)
self.ignore = (
quantization_config.ignore if quantization_config.ignore is not None else []
)
self.loaders = self._get_target_loaders(quantization_config)
for target, loader in self.loaders.items():
log_once(
logger.info,
f"Using {loader} for compressed-tensors target '{target}'",
)
def get_weights(self, weights: Weights, prefix: str):
loader = self._lookup_loader(prefix)
return loader.get_weights(weights, prefix)
def get_weights_col_packed(
self,
weights: "Weights",
prefix: str,
block_sizes: Union[int, List[int]],
):
loader = self._lookup_loader(prefix)
return loader.get_weights_col_packed(weights, prefix, block_sizes)
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
loader = self._lookup_loader(prefixes[0])
return loader.get_multi_weights_col(weights, prefixes, dim)
def get_multi_weights(self, weights: Weights, prefixes: List[str], dim: int):
loader = self._lookup_loader(prefixes[0])
return loader.get_multi_weights(weights, prefixes, dim)
def get_weights_row(self, weights: Weights, prefix: str):
loader = self._lookup_loader(prefix)
return loader.get_weights_row(weights, prefix)
def _get_target_loaders(
self, quantization_config: QuantizationConfig
) -> Dict[str, WeightsLoader]:
"""
A compressed-tensors checkpoint can use different quantizations
for different targets. This method returns a dictionary with a
loader per target.
"""
loaders: Dict[str, WeightsLoader] = {}
format = quantization_config.format
for group_name, group in quantization_config.config_groups.items():
# The group configuration can be a string, but does that ever
# happen in a serialized quantization config?
assert isinstance(group, QuantizationScheme)
loader = self._create_loader_for_group(format, group_name, group)
# A quantized parameter group can have multiple targets, add the
# loader for all the targets.
for target in group.targets:
if target in loaders:
raise ValueError(
f"Target '{target} has multiple configured loaders'"
)
loaders[target] = loader
return loaders
def _create_loader_for_group(
self, format: str, group_name: str, group: QuantizationScheme
) -> WeightsLoader:
"""
Find and create a loader for the group with the given quantization
scheme.
"""
# NOTE: we ignore group.output_activations because we don't support
# output quantization yet.
input_activations = group.input_activations
weights = group.weights
if (
format
in {
CompressionFormat.float_quantized.value,
CompressionFormat.naive_quantized.value,
}
and weights is not None
and weights.type == QuantizationType.FLOAT
and weights.num_bits == 8
):
# FP W8A8 or W8A16.
return W8ANFpLoader(input_activations=input_activations, weights=weights)
else:
raise ValueError(
f"Group '{group_name}' has unsupported compressed-tensors configurtion"
)
def _lookup_loader(self, prefix: str) -> WeightsLoader:
"""
Look up the loader to use for a given parameter name (prefix).
"""
if len(find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.ignore)) > 0:
return DefaultWeightsLoader(UnquantizedWeight)
# We currently only handle linear layers, so unconditionally pass
# a `Linear` instance.
targets = find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.loaders.keys())
if len(targets) == 0:
raise ValueError(
f"Cannot find compressed-tensors target for prefix: {prefix}"
)
return self.loaders[targets[0]]

View File

@ -1,253 +0,0 @@
from typing import List, Optional, Union
import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationType
from text_generation_server.layers.fp8 import (
Fp8Weight,
_load_scalar_or_matrix_scale,
requantize_with_max_scale,
)
from text_generation_server.utils.weights import Weights, WeightsLoader
class W8ANFpLoader(WeightsLoader):
"""
Loader for W8A8/W8A16 FP compressed-tensors parameters.
"""
def __init__(
self,
*,
input_activations: Optional[QuantizationArgs],
weights: QuantizationArgs,
):
assert weights.type == QuantizationType.FLOAT and weights.num_bits == 8
# We ignore the `strategy` option which sets the scales to be
# per-tensor, per-channel or per-token. What scales are supported
# is dependent on the kernels used (e.g. cutlass can do tokenwise,
# Torch cannot, and FP8-Marlin does not quantize inputs at all).
# So, instead we try to use the best-possible configuration.
self.load_weight_scale = not weights.dynamic
self.load_input_scale = (
input_activations is not None and not input_activations.dynamic
)
self.force_w8a16 = (
input_activations is not None and input_activations.num_bits == 16
)
def __str__(self) -> str:
def scale_to_str(scale):
return "static" if scale else "dynamic"
quantization_type = f"W8A{16 if self.force_w8a16 else 8}"
return f"{self.__class__.__name__} ({quantization_type}, weight: {scale_to_str(self.load_weight_scale)}, input: {scale_to_str(self.load_input_scale)})"
def get_weights(self, weights: "Weights", prefix: str):
w = weights.get_tensor(f"{prefix}.weight")
weight_scale = None
if self.load_weight_scale:
weight_scale = (
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
.reshape(-1)
.expand(w.shape[0])
)
logical_widths = [w.shape[0]]
w, weight_scale = requantize_with_max_scale(
w,
weight_scale.unsqueeze(-1).to(weights.device),
logical_widths,
weights.dtype,
)
input_scale = None
if self.load_input_scale:
input_scale = weights.get_tensor(
f"{prefix}.input_scale", to_dtype=False
).reshape(-1)
return Fp8Weight(
weight=w,
weight_scale=weight_scale,
input_scale=input_scale,
dtype=weights.dtype,
force_w8a16=self.force_w8a16,
)
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
w = weights.get_packed_sharded(
f"{prefix}.weight", dim=0, block_sizes=block_sizes
)
weight_scale = None
if self.load_weight_scale:
weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
if weight_scale.numel() > 1:
weight_scale = weights.get_packed_sharded(
f"{prefix}.weight_scale",
dim=0,
block_sizes=block_sizes,
to_dtype=False,
)
weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
logical_widths = [w.shape[0]]
w, weight_scale = requantize_with_max_scale(
w,
weight_scale.unsqueeze(-1).to(weights.device),
logical_widths,
weights.dtype,
)
input_scale = None
if self.load_input_scale:
input_scale = weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
if input_scale.numel() > 1:
input_scale = weights.get_packed_sharded(
f"{prefix}.input_scale",
dim=0,
block_sizes=block_sizes,
to_dtype=False,
)
input_scale = input_scale.reshape(-1).max()
return Fp8Weight(
weight=w,
weight_scale=weight_scale,
input_scale=input_scale,
dtype=weights.dtype,
force_w8a16=self.force_w8a16,
)
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
# FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
w = [
weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
]
shapes = [x.shape for x in w]
# Concat then send to the device
w = torch.cat(w, dim=dim).to(weights.device)
weight_scale = None
if self.load_weight_scale:
weight_scale = [
_load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
for p, shape in zip(prefixes, shapes)
]
weight_scale = torch.cat(weight_scale, dim=0).reshape(-1)
logical_widths = [x[0] for x in shapes]
w, weight_scale = requantize_with_max_scale(
w,
weight_scale.unsqueeze(-1).to(weights.device),
logical_widths,
weights.dtype,
)
input_scale = None
if self.load_input_scale:
input_scale = [
_load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
for p, shape in zip(prefixes, shapes)
if weights.has_tensor(f"{p}.input_scale")
]
assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
input_scale = (
torch.cat(input_scale, dim=0).reshape(-1).max()
if len(input_scale) != 0
else None
)
return Fp8Weight(
weight=w,
weight_scale=weight_scale,
input_scale=input_scale,
dtype=weights.dtype,
force_w8a16=self.force_w8a16,
)
def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
# FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
w = [weights.get_tensor(f"{p}.weight", to_device=False) for p in prefixes]
shapes = [x.shape for x in w]
# Concat then send to the device
w = torch.cat(w, dim=dim).to(weights.device)
weight_scale = None
if self.load_weight_scale:
weight_scale = [
weights.get_tensor(f"{p}.weight_scale", to_dtype=False)
.reshape(-1)
.expand(shape[0])
for p, shape in zip(prefixes, shapes)
]
weight_scale = torch.cat(weight_scale, dim=0).reshape(-1)
logical_widths = [x[0] for x in shapes]
w, weight_scale = requantize_with_max_scale(
w,
weight_scale.unsqueeze(-1).to(weights.device),
logical_widths,
weights.dtype,
)
input_scale = None
if self.load_input_scale:
input_scale = [
weights.get_tensor(f"{p}.input_scale", to_dtype=False)
.reshape(-1)
.expand(shape[0])
for p, shape in zip(prefixes, shapes)
if weights.has_tensor(f"{p}.input_scale")
]
assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
input_scale = (
torch.cat(input_scale, dim=0).reshape(-1).max()
if len(input_scale) != 0
else None
)
return Fp8Weight(
weight=w,
weight_scale=weight_scale,
input_scale=input_scale,
dtype=weights.dtype,
force_w8a16=self.force_w8a16,
)
def get_weights_row(self, weights: "Weights", prefix: str):
w = weights.get_sharded(f"{prefix}.weight", dim=1)
weight_scale = None
if self.load_weight_scale:
weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
logical_widths = [w.shape[0]]
w, weight_scale = requantize_with_max_scale(
w,
weight_scale.unsqueeze(-1).to(weights.device),
logical_widths,
weights.dtype,
)
input_scale = None
if self.load_input_scale:
input_scale = weights.get_tensor(
f"{prefix}.input_scale", to_dtype=False
).reshape(-1)
return Fp8Weight(
weight=w,
weight_scale=weight_scale,
input_scale=input_scale,
dtype=weights.dtype,
force_w8a16=self.force_w8a16,
)

View File

@ -0,0 +1,43 @@
from dataclasses import dataclass
import torch
from EETQ import quant_weights, w8_a16_gemm
from text_generation_server.utils.weights import UnquantizedWeight
@dataclass
class EETQWeight(UnquantizedWeight):
weight: torch.Tensor
def get_linear(self, bias: torch.Tensor):
try:
from text_generation_server.layers.eetq import EETQLinear
return EETQLinear(self.weight, bias)
except ImportError:
raise ImportError(
"Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
)
class EETQLinear(torch.nn.Module):
def __init__(
self,
weight,
bias,
) -> None:
super().__init__()
device = weight.device
if weight.dtype != torch.float16:
weight = weight.to(dtype=torch.float16)
weight = torch.t(weight).contiguous().cpu()
weight, scale = quant_weights(weight, torch.int8, False)
self.weight = weight.cuda(device)
self.scale = scale.cuda(device)
self.bias = bias.cuda(device) if bias is not None else None
def forward(self, input: torch.Tensor) -> torch.Tensor:
output = w8_a16_gemm(input, self.weight, self.scale)
output = output + self.bias if self.bias is not None else output
return output

View File

@ -1,293 +1,100 @@
from dataclasses import dataclass
from typing import Optional, Tuple, Type, Union, List
import torch
from dataclasses import dataclass
from typing import Optional, Union, List
from loguru import logger
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import (
Weight,
WeightsLoader,
UnquantizedWeight,
Weights,
)
from vllm_hpu_extension.ops import scaled_fp8_quant
from vllm_hpu_extension.scales import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2
quant_dtype: torch.dtype = torch.float8_e4m3fn
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
if is_hpu_gaudi2():
FP8_MAX = torch.finfo(torch.float8_e4m3fnuz).max
from text_generation_server.utils.log import log_master, log_once
import importlib.util
def pad_weight(weight, block_size):
"""Pads a matrix to make its dimensions multiples of block_size."""
M, N = weight.shape[-2:]
block_size_m, block_size_n = block_size
pad_M = (block_size_m - M % block_size_m) % block_size_m
pad_N = (block_size_n - N % block_size_n) % block_size_n
if pad_M == 0 and pad_N == 0:
return weight, M, N # No padding needed
padded_weight = torch.nn.functional.pad(
weight, (0, pad_N, 0, pad_M), mode="constant", value=0
)
return padded_weight, M, N # Return original dimensions for unpadding
FBGEMM_MM_AVAILABLE = False
FBGEMM_DYN_AVAILABLE = False
def unpad_weight(weight, original_M, original_N, keep_first_dim=False):
"""Removes padding from the matrix to restore its original shape."""
if (weight.shape[-2] == original_M) and (weight.shape[-1] == original_N):
return weight
if keep_first_dim:
return weight[:, :original_M, :original_N]
else:
return weight[:original_M, :original_N]
def is_fbgemm_gpu_available():
try:
return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None
except ModuleNotFoundError:
return False
def pad_block_fp8_weight_naive(weight, weight_scale, block_size):
assert len(block_size) == 2
block_size_m, block_size_n = block_size
weight_scale_m, weight_scale_n = weight_scale.shape[-2:]
weight, orig_M, orig_N = pad_weight(weight, block_size)
M, N = weight.shape[-2:]
assert weight_scale_m == M // block_size_m
assert weight_scale_n == N // block_size_n
return weight, orig_M, orig_N
if is_fbgemm_gpu_available():
if SYSTEM == "cuda":
major, _ = torch.cuda.get_device_capability()
FBGEMM_MM_AVAILABLE = major == 9
FBGEMM_DYN_AVAILABLE = major >= 8
else:
log_master(logger.warning, "FBGEMM fp8 kernels are not installed.")
def dynamic_quant(data, single_scale=False):
if single_scale:
scale = ((torch.abs(data)).max() + 1e-8) / FP8_MAX
else:
scale = ((torch.abs(data)).max(dim=-1).values + 1e-8) / FP8_MAX
scale = scale.unsqueeze(-1)
data_fp8 = torch.ops.hpu.cast_to_fp8_v2(
data, 1.0 / scale, False, False, torch.float8_e4m3fn
)[0]
return data_fp8, scale.float()
def dequant_block_fp8_weight_naive(
weight,
weight_scale,
block_size,
dtype=torch.bfloat16,
original_M=None,
original_N=None,
do_unpad=False,
):
if weight_scale is None:
return weight
assert len(block_size) == 2
weight_shape_len = len(weight.shape)
block_size_m, block_size_n = block_size
# mul scale
if weight_shape_len == 2:
weight_scale_m, weight_scale_n = weight_scale.shape
weight_scale = weight_scale.view(weight_scale_m, 1, weight_scale_n, 1)
weight = weight.view(weight_scale_m, block_size_m, weight_scale_n, block_size_n)
if is_hpu_gaudi2():
fake_weight = weight.cpu().to(dtype).to(weight.device)
dequant_weight = fake_weight * weight_scale.to(dtype)
else:
dequant_weight = weight.to(dtype) * weight_scale.to(dtype)
dequant_weight = dequant_weight.view(
weight_scale_m * block_size_m, weight_scale_n * block_size_n
)
keep_first_dim = False
elif weight_shape_len == 3:
fd, weight_scale_m, weight_scale_n = weight_scale.shape
weight_scale = weight_scale.view(fd, weight_scale_m, 1, weight_scale_n, 1)
weight = weight.view(
fd, weight_scale_m, block_size_m, weight_scale_n, block_size_n
)
if is_hpu_gaudi2():
fake_weight = weight.cpu().to(dtype).to(weight.device)
dequant_weight = fake_weight * weight_scale.to(dtype)
else:
dequant_weight = weight.to(dtype) * weight_scale.to(dtype)
dequant_weight = dequant_weight.view(
fd, weight_scale_m * block_size_m, weight_scale_n * block_size_n
)
keep_first_dim = True
else:
raise ValueError("Only support original weight shape is either 2 or 3")
if do_unpad:
dequant_weight = unpad_weight(
dequant_weight, original_M, original_N, keep_first_dim=keep_first_dim
)
return dequant_weight
def apply_block_fp8_linear_hpu_dynamic(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
x_fp8, x_scale = dynamic_quant(input_2d)
output = torch.ops.hpu.fp8_gemm_v2(
x_fp8,
False,
weight,
True,
None,
torch.bfloat16,
x_scale,
weight_scale,
None,
False,
)
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
def get_fp8_linear() -> torch.nn.Module:
"""
Return an FP8 linear `Module` that is compatible with the current system.
"""
if SYSTEM == "cuda":
major, _ = torch.cuda.get_device_capability()
if major == 8:
from text_generation_server.layers.marlin import GPTQMarlinFP8Linear
return GPTQMarlinFP8Linear
# On other systems let Torch decide if the hardware supports FP8.
return Fp8Linear
def normalize_e4m3fn_to_native_float8(
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
return weight, weight_scale, input_scale
def per_tensor_dequantize(
tensor: torch.Tensor,
inv_scale: Union[float, torch.Tensor],
dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
device = tensor.device
dtype = torch.bfloat16
if is_hpu_gaudi2():
# dequant on cpu to avoid nan on gaudi2
tensor = tensor.to("cpu")
fake_qweight = tensor.to(dtype).to(device)
dq_weight = fake_qweight * inv_scale
return dq_weight
def requantize_with_max_scale(
weight: torch.Tensor,
weight_scale: torch.Tensor,
logical_widths: int,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Max scale to be used for requanitzation.
max_w_scale = weight_scale.max()
if is_hpu_gaudi2():
max_w_scale = max_w_scale * get_hpu_gaudi2_scale_factor()
start = 0
for idx, logical_width in enumerate(logical_widths):
end = start + logical_width
weight_dq = per_tensor_dequantize(
weight[start:end, :], weight_scale[start:end, :], dtype
)
weight[start:end, :], max_w_scale_normalized = fp8_quantize(
weight_dq, max_w_scale
)
start = end
return weight, max_w_scale_normalized
def fp8_quantize(
weight: torch.Tensor,
scale: Optional[torch.Tensor] = None,
scale_upper_bound: Optional[torch.Tensor] = None,
qdtype: torch.dtype = torch.float8_e4m3fn,
scalar: bool = False,
weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
):
"""
This function returns a reciprocal of the scale, so that a tensor can be unscaled
by multiplying it with the returned scale. If a scale is given through the `scale`
argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
be used without modification).
"""
shape = weight.shape
qweight, scale = scaled_fp8_quant(
weight.reshape(-1, shape[-1]),
scale=scale,
scale_ub=scale_upper_bound,
# TODO: don't do this when we have to use the Torch kernel.
use_per_token_if_dynamic=not scalar,
)
if FBGEMM_DYN_AVAILABLE and not scalar:
qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
)
return qweight, scale
return qweight.reshape(shape), scale
# weight, scale = quant_weights(weight, torch.int8, False)
finfo = torch.finfo(qdtype)
# Calculate the scale as dtype max divided by absmax
scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
# scale and clamp the tensor to bring it to
# the representative range of float8 data type
# (as default cast is unsaturated)
qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
# Return both float8 data and the inverse scale (as float),
# as both required as inputs to torch._scaled_mm
qweight = qweight.to(qdtype)
scale = scale.float().reciprocal()
return qweight, scale
class HybridFP8UnquantLoader(WeightsLoader):
"""Weight loader that loads FP8 and unquantized Torch tensors."""
def __init__(
self,
activation_scale_ub: Optional[float],
to_fp8: bool,
weight_block_size: Optional[List[int]] = None,
):
def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool):
self.activation_scale_ub = activation_scale_ub
self.to_fp8 = to_fp8
self.weight_block_size = weight_block_size
def get_weights(self, weights: "Weights", prefix: str):
w = weights.get_tensor(f"{prefix}.weight")
if w.dtype == torch.float8_e4m3fn:
if self.weight_block_size is not None:
scale = weights.get_tensor(f"{prefix}.weight_scale_inv")
return Fp8Weight(
weight=w,
weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
weight_block_size=self.weight_block_size,
)
# FP8 branch
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
scale = scale.reshape(-1).expand(w.shape[0])
logical_widths = [w.shape[0]]
w, scale = requantize_with_max_scale(
w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
scale = (
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
.reshape(-1)
.expand(w.shape[0])
)
input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"):
input_scale = (
weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
.reshape(-1)
.max()
)
return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
@ -309,7 +116,6 @@ class HybridFP8UnquantLoader(WeightsLoader):
if w.dtype == torch.float8_e4m3fn:
# FP8 branch
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
if scale.numel() > 1:
scale = weights.get_packed_sharded(
f"{prefix}.weight_scale",
@ -318,29 +124,10 @@ class HybridFP8UnquantLoader(WeightsLoader):
to_dtype=False,
)
scale = scale.reshape(-1).expand(w.shape[0])
logical_widths = [w.shape[0]]
w, scale = requantize_with_max_scale(
w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
)
input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"):
input_scale = weights.get_tensor(
f"{prefix}.input_scale", to_dtype=False
)
if input_scale.numel() > 1:
input_scale = weights.get_packed_sharded(
f"{prefix}.input_scale",
dim=0,
block_sizes=block_sizes,
to_dtype=False,
)
input_scale = input_scale.reshape(-1).max()
return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
@ -361,110 +148,15 @@ class HybridFP8UnquantLoader(WeightsLoader):
# FP8 branch
if w.dtype == torch.float8_e4m3fn:
if self.weight_block_size is not None:
scale = [
weights.get_sharded(f"{p}.weight_scale_inv", dim=0, to_device=False)
for p in prefixes
]
scale = torch.cat(scale, dim=dim)
scale = scale.to(weights.device)
return Fp8Weight(
weight=w,
weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
weight_block_size=self.weight_block_size,
)
scale = [
_load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
for p, shape in zip(prefixes, shapes)
]
scale = torch.cat(scale, dim=0).reshape(-1)
logical_widths = [x[0] for x in shapes]
w, scale = requantize_with_max_scale(
w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
)
input_scale = [
_load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
for p, shape in zip(prefixes, shapes)
if weights.has_tensor(f"{p}.input_scale")
]
assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
input_scale = (
torch.cat(input_scale, dim=0).reshape(-1).max()
if len(input_scale) != 0
else None
)
return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
if self.to_fp8:
return Fp8Weight(weight=w, dtype=weights.dtype)
return UnquantizedWeight(w)
def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
# FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
w = [weights.get_tensor(f"{p}.weight", to_device=False) for p in prefixes]
shapes = [x.shape for x in w]
# Concat then send to the device
w = torch.cat(w, dim=dim).to(weights.device)
# FP8 branch
if w.dtype == torch.float8_e4m3fn:
if self.weight_block_size is not None:
scale = [
weights.get_tensor(f"{p}.weight_scale_inv", to_device=False)
for p in prefixes
]
scale = torch.cat(scale, dim=dim)
scale = scale.to(weights.device)
return Fp8Weight(
weight=w,
weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
weight_block_size=self.weight_block_size,
)
scale = [
weights.get_tensor(f"{p}.weight_scale", to_dtype=False)
.reshape(-1)
.expand(shape[0])
for p, shape in zip(prefixes, shapes)
]
scale = torch.cat(scale, dim=0).reshape(-1)
logical_widths = [x[0] for x in shapes]
w, scale = requantize_with_max_scale(
w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
)
input_scale = [
weights.get_tensor(f"{p}.input_scale", to_dtype=False).reshape(-1)
for p in prefixes
if weights.has_tensor(f"{p}.input_scale")
]
assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
input_scale = (
torch.cat(input_scale, dim=0).reshape(-1).max()
if len(input_scale) != 0
else None
)
return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
@ -477,40 +169,14 @@ class HybridFP8UnquantLoader(WeightsLoader):
w = weights.get_sharded(f"{prefix}.weight", dim=1)
# FP8 branch
if w.dtype == torch.float8_e4m3fn:
if self.weight_block_size is not None:
# XXX: Yes the weights is named scale_inv, but corresponds to scale it seems.
scale = weights.get_sharded(f"{prefix}.weight_scale_inv", dim=1)
return Fp8Weight(
weight=w,
weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
weight_block_size=self.weight_block_size,
)
scale = (
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
.reshape(-1)
.expand(w.shape[0])
)
logical_widths = [w.shape[0]]
w, scale = requantize_with_max_scale(
w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
)
input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"):
input_scale = (
weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
.reshape(-1)
.max()
)
return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
@ -525,131 +191,95 @@ class Fp8Weight(Weight):
weight: torch.Tensor
dtype: torch.dtype
weight_scale: Optional[torch.Tensor] = None
input_scale: Optional[torch.Tensor] = None
activation_scale_ub: Optional[float] = None
force_w8a16: bool = False
weight_block_size: Optional[List[int]] = None
def get_linear(self, bias: torch.Tensor):
if self.weight_scale is None:
return get_fp8_linear(force_w8a16=self.force_w8a16).from_unquant(
self.weight, bias, self.dtype
)
return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
# This is not checked by the fbgemm kernels, but they require contiguous
# memory. Can be non-contiguous when we e.g. expand from scalars.
self.weight_scale = self.weight_scale.contiguous()
return get_fp8_linear(force_w8a16=self.force_w8a16).from_fp8(
weight=self.weight,
scale=self.weight_scale,
dtype=self.dtype,
bias=bias,
input_scale=self.input_scale,
scale_upper_bound=self.activation_scale_ub,
weight_block_size=self.weight_block_size,
return get_fp8_linear().from_fp8(
self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
)
class Fp8Linear(torch.nn.Module):
_device_identity_cache = {}
def __init__(
self,
qweight: torch.Tensor,
scale: torch.Tensor,
dtype: torch.dtype,
bias: Optional[torch.Tensor] = None,
input_scale: Optional[torch.Tensor] = None,
scale_upper_bound: Optional[float] = None,
weight_block_size: Optional[List[int]] = None,
qweight,
scale,
scale_upper_bound,
bias,
dtype,
) -> None:
super().__init__()
if FBGEMM_MM_AVAILABLE:
log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
self.dtype = dtype
self.qweight = qweight
self.scale = scale.float()
self.input_scale = input_scale.float() if input_scale is not None else None
self.weight_block_size = weight_block_size
self.scale_upper_bound = scale_upper_bound
self.scale = scale
self.scale_upper_bound = (
torch.tensor(
[scale_upper_bound], dtype=torch.float32, device=qweight.device
)
if scale_upper_bound is not None
else None
)
self.bias = bias if bias is not None else None
@classmethod
def from_unquant(cls, weight, bias, dtype):
qweight, scale = fp8_quantize(weight, scalar=True)
qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
return cls(
qweight=qweight,
scale=scale,
dtype=dtype,
bias=bias,
input_scale=None,
scale_upper_bound=None,
qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
)
@classmethod
def from_fp8(
cls,
weight: torch.Tensor,
scale: torch.Tensor,
dtype: torch.dtype,
bias: Optional[torch.Tensor] = None,
**kwargs,
) -> "Fp8Linear":
input_scale = kwargs.get("input_scale", None)
scale_upper_bound = kwargs.get("scale_upper_bound", None)
weight_block_size = kwargs.get("weight_block_size", None)
if weight_block_size is not None:
weight, orig_M, orig_N = pad_block_fp8_weight_naive(
weight, scale, weight_block_size
)
weight, scale = dynamic_quant(
dequant_block_fp8_weight_naive(
weight,
scale,
weight_block_size,
original_M=orig_M,
original_N=orig_N,
do_unpad=True,
)
)
scale = scale.squeeze(-1)
def from_fp8(cls, weight, scale, input_scale, bias, dtype):
if FBGEMM_DYN_AVAILABLE:
# fbgemm needs float32 scales.
scale = scale.float()
return cls(
qweight=weight,
scale=scale,
input_scale=input_scale,
scale_upper_bound=scale_upper_bound,
scale_upper_bound=input_scale,
bias=bias,
dtype=dtype,
weight_block_size=weight_block_size,
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.weight_block_size is not None or self.input_scale is None:
return apply_block_fp8_linear_hpu_dynamic(
input, self.qweight, self.scale, self.input_scale, self.bias
if FBGEMM_MM_AVAILABLE:
qinput, scale = fp8_quantize(
input, scale_upper_bound=self.scale_upper_bound
)
x_fp8 = torch.ops.hpu.cast_to_fp8_v2(
input, 1.0 / self.input_scale, False, False, torch.float8_e4m3fn
)[0]
return torch.ops.hpu.fp8_gemm_v2(
A=x_fp8,
trans_A=False,
B=self.qweight,
trans_B=True,
D=None,
out_dtype=input.dtype,
A_scale_inv=self.input_scale,
B_scale_inv=self.scale,
y = torch.ops.fbgemm.f8f8bf16_rowwise(
qinput,
self.qweight,
scale,
self.scale,
use_fast_accum=True,
bias=self.bias,
)
return y.to(self.dtype)
qinput, scale = fp8_quantize(input, scalar=True)
output, _ = torch._scaled_mm(
qinput,
self.qweight.t(),
out_dtype=self.dtype,
scale_a=scale,
scale_b=self.scale,
bias=self.bias,
accumulate=False,
)
return output
def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
scale = weights.get_tensor(prefix, to_dtype=False)
if scale.numel() > 1:
scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
return scale.reshape(-1).expand(shape[0])

View File

@ -1,18 +1,12 @@
import os
from dataclasses import dataclass
from typing import List, Optional, Union
import torch
from loguru import logger
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import (
Weight,
Weights,
WeightsLoader,
DefaultWeightsLoader,
)
from .hpu import QuantLinear
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
@dataclass
@ -36,8 +30,13 @@ class GPTQWeight(Weight):
def get_linear(self, bias: torch.Tensor):
if self.use_awq_kernel:
if SYSTEM == "rocm":
raise NotImplementedError(
"AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead "
"to use Exllama/GPTQ kernels for AWQ inference."
)
try:
from text_generation_server.layers.awq.quantize import WQLinear
from text_generation_server.layers.awq.quantize.qmodule import WQLinear
return WQLinear(
w_bit=self.bits,
@ -51,7 +50,18 @@ class GPTQWeight(Weight):
raise NotImplementedError(
"You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
)
elif self.use_exllama:
try:
from text_generation_server.layers.gptq import ExllamaQuantLinear
except ImportError:
raise NotImplementedError(
"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
)
return ExllamaQuantLinear(self, bias)
else:
from text_generation_server.layers.gptq.quant_linear import QuantLinear
return QuantLinear(
self.qweight,
self.qzeros,
@ -77,7 +87,6 @@ class GPTQWeightsLoader(WeightsLoader):
quant_method: str,
quantize: str,
sym: bool,
modules_to_not_convert: List[str],
):
self.bits = bits
self.desc_act = desc_act
@ -85,12 +94,6 @@ class GPTQWeightsLoader(WeightsLoader):
self.quant_method = quant_method
self.quantize = quantize
self.sym = sym
self.modules_to_not_convert = modules_to_not_convert
def is_layer_skipped_quantization(
self, prefix: str, modules_to_not_convert: List[str]
):
return any(module_name in prefix for module_name in modules_to_not_convert)
def get_weights(self, weights: Weights, prefix: str):
self._get_gptq_params(weights)
@ -103,9 +106,6 @@ class GPTQWeightsLoader(WeightsLoader):
log_once(logger.warning, "Disabling exllama because desc_act=True")
use_exllama = False
if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
return DefaultWeightsLoader.get_weights(weights, prefix)
try:
qweight = weights.get_tensor(f"{prefix}.qweight")
except RuntimeError:
@ -118,6 +118,23 @@ class GPTQWeightsLoader(WeightsLoader):
else:
g_idx = None
from text_generation_server.layers.gptq import (
HAS_EXLLAMA,
CAN_EXLLAMA,
GPTQWeight,
)
if use_exllama:
if not HAS_EXLLAMA:
if CAN_EXLLAMA:
log_once(
logger.warning,
"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
)
use_exllama = False
else:
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
qzeros = weights.get_tensor(f"{prefix}.qzeros")
scales = weights.get_tensor(f"{prefix}.scales")
@ -160,10 +177,6 @@ class GPTQWeightsLoader(WeightsLoader):
prefix: str,
block_sizes: Union[int, List[int]],
):
if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
return DefaultWeightsLoader.get_weights_col_packed(
weights, prefix, block_sizes
)
try:
qweight = weights.get_packed_sharded(
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
@ -215,8 +228,6 @@ class GPTQWeightsLoader(WeightsLoader):
)
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
return DefaultWeightsLoader.get_multi_weights_col(weights, prefixes, dim)
try:
qweight = torch.cat(
[weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
@ -236,7 +247,14 @@ class GPTQWeightsLoader(WeightsLoader):
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
)
use_exllama = self.bits == 4 and self.quantize == "gptq" and not self.desc_act
from text_generation_server.layers.gptq import HAS_EXLLAMA
use_exllama = (
self.bits == 4
and HAS_EXLLAMA
and self.quantize == "gptq"
and not self.desc_act
)
if self.quantize == "gptq" and self.quant_method == "gptq":
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
@ -276,74 +294,13 @@ class GPTQWeightsLoader(WeightsLoader):
use_exllama=use_exllama,
)
def get_multi_weights(self, weights: Weights, prefixes: List[str], dim: int):
if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
return DefaultWeightsLoader.get_multi_weights(weights, prefixes, dim)
try:
qweight = torch.cat(
[weights.get_tensor(f"{p}.qweight") for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
)
scales = torch.cat([weights.get_tensor(f"{p}.scales") for p in prefixes], dim=1)
self._get_gptq_params(weights)
qzeros = torch.cat([weights.get_tensor(f"{p}.qzeros") for p in prefixes], dim=1)
use_exllama = self.bits == 4 and self.quantize == "gptq" and not self.desc_act
if self.quantize == "gptq" and self.quant_method == "gptq":
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
elif self.quantize == "gptq" and self.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
if use_exllama:
g_idx = None
else:
g_idx = (
torch.arange(
qweight.shape[0] * (32 // self.bits),
device=qweight.device,
)
).to(dtype=torch.int32)
else:
g_idx = None
return GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=self.bits,
groupsize=self.groupsize,
use_awq_kernel=self.quantize == "awq",
use_exllama=use_exllama,
)
def get_weights_row(self, weights: Weights, prefix: str):
self._get_gptq_params(weights)
use_exllama = True
desc_act = self.desc_act
if self.bits != 4:
use_exllama = False
if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
return DefaultWeightsLoader.get_weights_row(weights, prefix)
if self.desc_act:
log_once(logger.warning, "Disabling exllama because desc_act=True")
use_exllama = False
@ -364,8 +321,7 @@ class GPTQWeightsLoader(WeightsLoader):
if g_idx is not None:
if (
not torch.equal(
# Remove g_idx[0] to adapt the check with TP>1.
(g_idx - g_idx[0]).cpu(),
g_idx.cpu(),
torch.tensor(
[i // self.groupsize for i in range(g_idx.shape[0])],
dtype=torch.int32,
@ -376,22 +332,34 @@ class GPTQWeightsLoader(WeightsLoader):
# Exllama implementation does not support row tensor parallelism with act-order, as
# it would require to reorder input activations that are split unto several GPUs
use_exllama = False
desc_act = True
from text_generation_server.layers.gptq import (
CAN_EXLLAMA,
HAS_EXLLAMA,
GPTQWeight,
)
if not desc_act and self.groupsize != -1:
if use_exllama:
if not HAS_EXLLAMA:
if CAN_EXLLAMA:
log_once(
logger.warning,
"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
)
use_exllama = False
else:
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
if use_exllama and self.groupsize != -1:
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
if g_idx is not None:
# qzeros, scales sharded, and g_idx must be adjusted accordingly
g_idx = g_idx - g_idx[0]
else:
qzeros = weights.get_tensor(f"{prefix}.qzeros")
scales = weights.get_tensor(f"{prefix}.scales")
if use_exllama and g_idx is not None:
g_idx = g_idx - g_idx[0]
if self.quantize == "gptq" and self.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
@ -424,7 +392,7 @@ class GPTQWeightsLoader(WeightsLoader):
)
def _get_gptq_params(self, weights: Weights):
if weights.has_tensor("gptq_bits") and weights.has_tensor("gptq_groupsize"):
if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"):
self.bits = weights.get_tensor("gptq_bits").item()
self.groupsize = weights.get_tensor("gptq_groupsize").item()
self.desc_act = False
@ -432,7 +400,41 @@ class GPTQWeightsLoader(WeightsLoader):
# before the `gptq_sym` setting tensor was added.
self.sym = (
weights.get_tensor("gptq_sym").item()
if weights.has_tensor("gptq_sym")
if weights._has_tensor("gptq_sym")
else False
)
self.quant_method = "gptq"
# Needs to be at the end because circular import.
try:
major, _minor = torch.cuda.get_device_capability()
except Exception:
major = 1
HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm"
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
if os.getenv("DISABLE_EXLLAMA") == "True":
HAS_EXLLAMA = False
elif CAN_EXLLAMA:
try:
if V2:
from text_generation_server.layers.gptq.exllamav2 import (
QuantLinear as ExllamaQuantLinear, # noqa: F401
create_exllama_buffers, # noqa: F401
set_device, # noqa: F401
)
HAS_EXLLAMA = "2"
else:
from text_generation_server.layers.gptq.exllama import (
Ex4bitLinear as ExllamaQuantLinear, # noqa: F401
create_exllama_buffers, # noqa: F401
set_device, # noqa: F401
)
HAS_EXLLAMA = "1"
except ImportError:
pass

View File

@ -0,0 +1,261 @@
# https://github.com/fpgaminer/GPTQ-triton
"""
Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100.
"""
import builtins
import math
import time
from typing import Dict
import triton
class Autotuner(triton.KernelInterface):
def __init__(
self,
fn,
arg_names,
configs,
key,
reset_to_zero,
prune_configs_by: Dict = None,
nearest_power_of_two: bool = False,
):
"""
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time
'top_k': number of configs to bench
'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs.
'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results
"""
if not configs:
self.configs = [triton.Config({}, num_warps=4, num_stages=2)]
else:
self.configs = configs
self.key_idx = [arg_names.index(k) for k in key]
self.nearest_power_of_two = nearest_power_of_two
self.cache = {}
# hook to reset all required tensor to zeros before relaunching a kernel
self.hook = lambda args: 0
if reset_to_zero is not None:
self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
def _hook(args):
for i in self.reset_idx:
args[i].zero_()
self.hook = _hook
self.arg_names = arg_names
# prune configs
if prune_configs_by:
perf_model, top_k = (
prune_configs_by["perf_model"],
prune_configs_by["top_k"],
)
if "early_config_prune" in prune_configs_by:
early_config_prune = prune_configs_by["early_config_prune"]
else:
perf_model, top_k, early_config_prune = None, None, None
self.perf_model, self.configs_top_k = perf_model, top_k
self.early_config_prune = early_config_prune
self.fn = fn
def _bench(self, *args, config, **meta):
# check for conflicts, i.e. meta-parameters both provided
# as kwargs and by the autotuner
conflicts = meta.keys() & config.kwargs.keys()
if conflicts:
raise ValueError(
f"Conflicting meta-parameters: {', '.join(conflicts)}."
" Make sure that you don't re-define auto-tuned symbols."
)
# augment meta-parameters with tunable ones
current = dict(meta, **config.kwargs)
def kernel_call():
if config.pre_hook:
config.pre_hook(self.nargs)
self.hook(args)
self.fn.run(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
**current,
)
try:
# In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses
# PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default
return triton.testing.do_bench(
kernel_call, quantiles=(0.5, 0.2, 0.8), rep=40
)
except triton.OutOfResources:
return [float("inf"), float("inf"), float("inf")]
def run(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args))
if len(self.configs) > 1:
key = tuple(args[i] for i in self.key_idx)
# This reduces the amount of autotuning by rounding the keys to the nearest power of two
# In my testing this gives decent results, and greatly reduces the amount of tuning required
if self.nearest_power_of_two:
key = tuple([2 ** int(math.log2(x) + 0.5) for x in key])
if key not in self.cache:
# prune configs
pruned_configs = self.prune_configs(kwargs)
bench_start = time.time()
timings = {
config: self._bench(*args, config=config, **kwargs)
for config in pruned_configs
}
bench_end = time.time()
self.bench_time = bench_end - bench_start
self.cache[key] = builtins.min(timings, key=timings.get)
self.hook(args)
self.configs_timings = timings
config = self.cache[key]
else:
config = self.configs[0]
self.best_config = config
if config.pre_hook is not None:
config.pre_hook(self.nargs)
return self.fn.run(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
**kwargs,
**config.kwargs,
)
def prune_configs(self, kwargs):
pruned_configs = self.configs
if self.early_config_prune:
pruned_configs = self.early_config_prune(self.configs, self.nargs)
if self.perf_model:
top_k = self.configs_top_k
if isinstance(top_k, float) and top_k <= 1.0:
top_k = int(len(self.configs) * top_k)
if len(pruned_configs) > top_k:
est_timing = {
config: self.perf_model(
**self.nargs,
**kwargs,
**config.kwargs,
num_stages=config.num_stages,
num_warps=config.num_warps,
)
for config in pruned_configs
}
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[
:top_k
]
return pruned_configs
def warmup(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args))
for config in self.prune_configs(kwargs):
self.fn.warmup(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
**kwargs,
**config.kwargs,
)
self.nargs = None
def autotune(
configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False
):
"""
Decorator for auto-tuning a :code:`triton.jit`'d function.
.. highlight:: python
.. code-block:: python
@triton.autotune(configs=[
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
],
key=['x_size'] # the two above configs will be evaluated anytime
# the value of x_size changes
)
@triton.jit
def kernel(x_ptr, x_size, **META):
BLOCK_SIZE = META['BLOCK_SIZE']
:note: When all the configurations are evaluated, the kernel will run multiple time.
This means that whatever value the kernel updates will be updated multiple times.
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
reset the value of the provided tensor to `zero` before running any configuration.
:param configs: a list of :code:`triton.Config` objects
:type configs: list[triton.Config]
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
:type key: list[str]
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time
'top_k': number of configs to bench
'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs.
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
:type reset_to_zero: list[str]
"""
def decorator(fn):
return Autotuner(
fn,
fn.arg_names,
configs,
key,
reset_to_zero,
prune_configs_by,
nearest_power_of_two,
)
return decorator
def matmul248_kernel_config_pruner(configs, nargs):
"""
The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller.
"""
m = max(2 ** int(math.ceil(math.log2(nargs["M"]))), 16)
n = max(2 ** int(math.ceil(math.log2(nargs["N"]))), 16)
k = max(2 ** int(math.ceil(math.log2(nargs["K"]))), 16)
used = set()
for config in configs:
block_size_m = min(m, config.kwargs["BLOCK_SIZE_M"])
block_size_n = min(n, config.kwargs["BLOCK_SIZE_N"])
block_size_k = min(k, config.kwargs["BLOCK_SIZE_K"])
group_size_m = config.kwargs["GROUP_SIZE_M"]
if (
block_size_m,
block_size_n,
block_size_k,
group_size_m,
config.num_stages,
config.num_warps,
) in used:
continue
used.add(
(
block_size_m,
block_size_n,
block_size_k,
group_size_m,
config.num_stages,
config.num_warps,
)
)
yield triton.Config(
{
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
"GROUP_SIZE_M": group_size_m,
},
num_stages=config.num_stages,
num_warps=config.num_warps,
)

View File

@ -0,0 +1,134 @@
from text_generation_server.layers.gptq import GPTQWeight
import torch
from exllama_kernels import make_q4, q4_matmul, prepare_buffers, set_tuning_params
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
def ext_make_q4(qweight, qzeros, scales, g_idx, device):
"""Construct Q4Matrix, return handle"""
return make_q4(
qweight, qzeros, scales, g_idx if g_idx is not None else none_tensor, device
)
def ext_q4_matmul(x, q4, q4_width):
"""Matrix multiplication, returns x @ q4"""
outshape = x.shape[:-1] + (q4_width,)
x = x.view(-1, x.shape[-1])
output = torch.empty((x.shape[0], q4_width), dtype=torch.float16, device=x.device)
q4_matmul(x, q4, output)
return output.view(outshape)
MAX_DQ = 1
MAX_INNER = 1
ACT_ORDER = False
DEVICE = None
TEMP_STATE = None
TEMP_DQ = None
def set_device(device):
global DEVICE
DEVICE = device
def create_exllama_buffers(max_total_tokens: int):
global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ
assert DEVICE is not None, "call set_device first"
if not ACT_ORDER:
max_total_tokens = 1
# This temp_state buffer is required to reorder X in the act-order case.
temp_state = torch.zeros(
(max_total_tokens, MAX_INNER), dtype=torch.float16, device=DEVICE
)
temp_dq = torch.zeros((1, MAX_DQ), dtype=torch.float16, device=DEVICE)
# This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
prepare_buffers(DEVICE, temp_state, temp_dq)
matmul_recons_thd = 8
matmul_fused_remap = False
matmul_no_half2 = False
set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
TEMP_STATE, TEMP_DQ = temp_state, temp_dq
class Ex4bitLinear(torch.nn.Module):
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
def __init__(self, weight: GPTQWeight, bias):
super().__init__()
global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE
assert weight.bits == 4
self.device = weight.qweight.device
self.qweight = weight.qweight
self.qzeros = weight.qzeros
self.scales = weight.scales
self.g_idx = weight.g_idx.cpu() if weight.g_idx is not None else None
self.bias = bias if bias is not None else None
if self.g_idx is not None and (
(self.g_idx == 0).all()
or torch.equal(
weight.g_idx.cpu(),
torch.tensor(
[i // weight.groupsize for i in range(weight.g_idx.shape[0])],
dtype=torch.int32,
),
)
):
self.empty_g_idx = True
self.g_idx = None
assert self.device.type == "cuda"
assert self.device.index is not None
self.q4 = ext_make_q4(
self.qweight, self.qzeros, self.scales, self.g_idx, self.device.index
)
self.height = weight.qweight.shape[0] * 8
self.width = weight.qweight.shape[1]
# Infer groupsize from height of qzeros
self.groupsize = None
if self.qzeros.shape[0] > 1:
self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])
if self.groupsize is not None:
assert weight.groupsize == self.groupsize
# Handle act-order matrix
if self.g_idx is not None:
if self.groupsize is None:
raise ValueError("Found group index but no groupsize. What do?")
self.act_order = True
else:
self.act_order = False
DEVICE = self.qweight.device
MAX_DQ = max(MAX_DQ, self.qweight.numel() * 8)
if self.act_order:
MAX_INNER = max(MAX_INNER, self.height, self.width)
ACT_ORDER = True
def forward(self, x):
out = ext_q4_matmul(x, self.q4, self.width)
if self.bias is not None:
out.add_(self.bias)
return out

View File

@ -0,0 +1,267 @@
# Adapted from turboderp exllama: https://github.com/turboderp/exllamav2
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
from loguru import logger
from text_generation_server.layers.exl2 import Exl2Weight
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.utils.log import log_master
try:
from exllamav2.ext import exllamav2_ext
make_q_matrix = exllamav2_ext.make_q_matrix
gemm_half_q_half = exllamav2_ext.gemm_half_q_half
except ImportError:
log_master(logger.warning, "exllamav2_kernels not installed.")
raise
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
@dataclass
class _ExtraTensors:
"""Additional generated quantizer tensors."""
q_group_map: Optional[torch.Tensor] = None
q_invperm: Optional[torch.Tensor] = None
q_perm: Optional[torch.Tensor] = None
def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
"""Matrix multiplication, returns x @ q4"""
output_shape = x.shape[:-1] + (q4_width,)
x = x.view(-1, x.shape[-1])
output = torch.empty((x.shape[0], q4_width), dtype=torch.half, device=x.device)
gemm_half_q_half(x, q_handle, output, force_cuda)
return output.view(output_shape)
def make_group_map(q_groups: torch.Tensor, num_qrows: int):
gr = q_groups.tolist()
group_map = []
num_groups = len(gr) // 2
for i in range(num_groups):
bits = gr[i * 2]
if i < num_groups - 1:
qrows = gr[i * 2 + 3] - gr[i * 2 + 1]
else:
qrows = num_qrows - gr[i * 2 + 1]
rows = qrows * 32 // bits
for j in range(rows):
group_map += [i]
group_map += [rows - j]
return torch.tensor(group_map, dtype=torch.short, device=q_groups.device)
# Create Q matrix
def ext_make_q_matrix(
w: Exl2Weight | GPTQWeight,
extra: _ExtraTensors,
temp_dq,
key: Optional[str] = None,
):
"""
Create Q matrix
"""
# max_dq_size = 512*(1024**2)
# max_dq_rows = max_dq_size // out_features[0]
max_dq_rows = 0
# EXL2
if isinstance(w, Exl2Weight):
extra.q_group_map = make_group_map(w.q_groups, w.q_weight.shape[0])
extra.q_perm = torch.argsort(w.q_invperm).short()
return make_q_matrix(
w.q_weight,
extra.q_perm,
w.q_invperm,
w.q_scale,
w.q_scale_max,
w.q_groups,
extra.q_group_map,
none_tensor, # zeros
none_tensor, # scales
none_tensor, # g_idx
none_tensor, # bias
temp_dq,
max_dq_rows,
)
# GPTQ
elif isinstance(w, GPTQWeight):
if w.scales.dtype == torch.float:
w.scales = w.scales.half()
# GPTQ with g_idx (act_order)
if w.g_idx is not None and not (w.g_idx == 0).all().item():
extra.q_perm = torch.empty(
(w.qweight.shape[0] * 8,),
dtype=torch.short,
device=w.qweight.device,
)
extra.q_invperm = torch.empty_like(extra.q_perm)
# make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx.
return make_q_matrix(
w.qweight,
extra.q_perm,
extra.q_invperm,
none_tensor, # q_scale
none_tensor, # q_scale_max
none_tensor, # q_groups
none_tensor, # q_group_map
w.qzeros,
w.scales,
w.g_idx.cpu(),
none_tensor, # bias
temp_dq,
max_dq_rows,
)
# GPTQ without g_idx
else:
return make_q_matrix(
w.qweight,
none_tensor, # q_perm
none_tensor, # q_invperm
none_tensor, # q_scale
none_tensor, # q_scale_max
none_tensor, # q_groups
none_tensor, # q_group_map
w.qzeros,
w.scales,
none_tensor, # g_idx
none_tensor, # bias
temp_dq,
max_dq_rows,
)
else:
RuntimeError("Cannot create handle")
DEVICE = None
LAYERS = []
def set_device(device):
global DEVICE
DEVICE = device
def create_exllama_buffers(max_total_tokens: int):
global LAYERS, DEVICE
# No need to initialize scratch space if there are no layers
# that use ExLLamav2.
if len(LAYERS) == 0:
return
# Find the size of the scratch space.
scratch_bytes = max(
layer.scratch_space_fixed(max_input_len=max_total_tokens, max_batch_size=1)
for layer in LAYERS
)
temp_dq = ExLlamaV2DeviceTensors(DEVICE, scratch_bytes)
for layer in LAYERS:
layer.post_init(temp_dq)
class QuantLinear(nn.Module):
QUANT_TYPE = "exllamav2"
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
def __init__(
self,
weight: Exl2Weight | GPTQWeight,
bias: torch.Tensor,
):
super().__init__()
self.q_handle = None
self.q_tensors = weight
self.extra_tensors = _ExtraTensors()
if isinstance(weight, Exl2Weight):
self.infeatures = weight.q_invperm.shape[0]
self.outfeatures = weight.q_weight.shape[1]
elif isinstance(weight, GPTQWeight):
if weight.bits != 4:
raise ValueError(
f"Exllamav2 kernel supports only bits=4, requested bits={weight.bits}. Something is wrong in the model initialization."
)
self.infeatures = weight.qweight.shape[0] // weight.bits * 32
self.outfeatures = weight.qweight.shape[1]
self.padding = -self.outfeatures % 32
self.outfeatures = self.outfeatures + self.padding
self.device = weight.device
self.bias = bias if bias is not None else None
global LAYERS
LAYERS.append(self)
def post_init(self, temp_dq):
device = self.q_tensors.device
assert device.type == "cuda"
assert device.index is not None
temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
# We NEED to keep a pointer on Python side, otherwise the garbage collector will mess with us,
# and `Memory access fault by GPU node-2` will EAT you.
self.temp_dq = temp_dq
self.q_handle = ext_make_q_matrix(self.q_tensors, self.extra_tensors, temp_dq)
def forward(self, x, force_cuda=False):
output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda)
if self.bias is not None:
output.add_(self.bias)
return output
def temp_dq_size(self):
return self.infeatures * self.outfeatures * 2 + 128
def temp_fwd_size(self, max_input_len, max_batch_size):
return self.outfeatures * max_input_len * max_batch_size * 4 + 128
def scratch_space_fixed(self, max_input_len, max_batch_size):
return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size)
class ExLlamaV2DeviceTensors:
device_idx: int
scratch_bytes: int
scratch_idx: int
scratch: torch.tensor = None
def __init__(self, device, scratch_bytes):
self.device = device
self.scratch_bytes = scratch_bytes
def prepare(self):
self.scratch = torch.empty(
(self.scratch_bytes // 2,), dtype=torch.half, device=self.device
)
def get_scratch_slice(self, size_bytes):
if self.scratch is None:
self.prepare()
size_bytes = ((size_bytes + 127) // 128) * 128
size_half = size_bytes // 2
scratch_slice = self.scratch.narrow(0, 0, size_half)
return scratch_slice

View File

@ -1,186 +0,0 @@
import math
import numpy as np
import torch
import torch.nn as nn
try:
convert_from_uint4 = torch.ops.hpu.convert_from_uint4
except Exception as e:
hpu_import_exception = e
def error_raiser_hpu(*args, **kwargs):
raise ValueError(
f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}"
)
convert_from_uint4 = error_raiser_hpu
def pack_tensor(input, bits=4):
normal = input.to(torch.int32)
q = torch.zeros((normal.shape[0], normal.shape[1] // 32 * bits), dtype=torch.int32)
i = 0
col = 0
while col < q.shape[1]:
for j in range(i, i + (32 // bits)):
q[:, col] |= normal[:, j] << (bits * (j - i))
i += 32 // bits
col += 1
q = q.to(torch.int32)
return q
class QuantLinear(nn.Module):
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
super().__init__()
self.register_buffer("qweight", qweight)
self.register_buffer("qzeros", qzeros)
self.register_buffer("scales", scales)
self.register_buffer("g_idx", g_idx)
if bias is not None:
self.register_buffer("bias", bias)
else:
self.bias = None
if bits not in [4]:
raise NotImplementedError("Only 4 bits are supported.")
self.bits = bits
self.maxq = 2**self.bits - 1
self.groupsize = groupsize
self.outfeatures = qweight.shape[1]
self.infeatures = qweight.shape[0] * 32 // bits
self.wf = torch.tensor(
list(range(0, 32, self.bits)), dtype=torch.int32
).unsqueeze(0)
self._preprocessing()
def unpack_zeros_from_cuda_old_format(self):
zeros = torch.bitwise_right_shift(
torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits),
self.wf.unsqueeze(0),
).to(torch.int16 if self.bits == 8 else torch.int8)
zeros = zeros + 1
zeros = torch.bitwise_and(zeros, (2**self.bits) - 1).to(
self.scales.dtype
) # NOTE: It appears that casting here after the `zeros = zeros + 1` is important.
zeros = zeros.reshape(-1, zeros.shape[1] * zeros.shape[2])
return zeros
def unpack_weight_from_cuda_old_format(self):
weight = torch.bitwise_right_shift(
torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1),
self.wf.unsqueeze(-1),
).to(torch.int16 if self.bits == 8 else torch.int8)
weight = torch.bitwise_and(weight, (2**self.bits) - 1)
weight = weight.reshape((weight.shape[0] * weight.shape[1], weight.shape[2]))
return weight
def _preprocessing(self):
orig_device = self.qweight.device
self.qweight = self.qweight.cpu()
weight = self.unpack_weight_from_cuda_old_format()
new_qweight = pack_tensor(weight)
self.qweight = new_qweight.to(orig_device)
# TODO: Support group indexing and remove the check
columns = self.qweight.shape[0]
g_idx_trivial = [i // self.groupsize for i in range(columns)]
g_idx_trivial = torch.tensor(
g_idx_trivial, dtype=torch.int32, device=self.g_idx.device
)
assert torch.equal(
self.g_idx, g_idx_trivial
), "Non-trivial tensor g_idx is not supported"
self.qzeros = self.qzeros.cpu()
zeros = self.unpack_zeros_from_cuda_old_format()
new_qzeros = pack_tensor(zeros)
self.qzeros = new_qzeros.to(orig_device)
@classmethod
def new(cls, bits, groupsize, infeatures, outfeatures, bias):
if bits not in [4]:
raise NotImplementedError("Only 4 bits are supported.")
qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
qzeros = torch.zeros(
(math.ceil(infeatures / groupsize), outfeatures // 32 * bits),
dtype=torch.int32,
)
scales = torch.zeros(
(math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16
)
g_idx = torch.tensor(
[i // groupsize for i in range(infeatures)], dtype=torch.int32
)
if bias:
bias = torch.zeros((outfeatures), dtype=torch.float16)
else:
bias = None
return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
def pack(self, linear, scales, zeros, g_idx=None):
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
self.scales = scales.clone().half()
if linear.bias is not None:
self.bias = linear.bias.clone().half()
intweight = []
for idx in range(self.infeatures):
intweight.append(
torch.round(
(linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
/ self.scales[self.g_idx[idx]]
).to(torch.int)[:, None]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)
qweight = np.zeros(
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
)
i = 0
row = 0
while row < qweight.shape[0]:
if self.bits in [4]:
for j in range(i, i + (32 // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
i += 32 // self.bits
row += 1
else:
raise NotImplementedError("Only 4 bits are supported.")
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)
zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros(
(zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
)
i = 0
col = 0
while col < qzeros.shape[1]:
if self.bits in [4]:
for j in range(i, i + (32 // self.bits)):
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
i += 32 // self.bits
col += 1
else:
raise NotImplementedError("Only 4 bits are supported.")
qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)
def forward(self, x):
out_shape = x.shape[:-1] + (self.outfeatures,)
x = x.reshape(-1, x.shape[-1])
weight = convert_from_uint4(self.qweight, self.scales, self.qzeros, x.dtype)
out = torch.matmul(x, weight)
out = out.reshape(out_shape)
out = out + self.bias if self.bias is not None else out
return out

View File

@ -0,0 +1,359 @@
import math
import numpy as np
import torch
import torch.nn as nn
from torch.cuda.amp import custom_fwd
import triton
import triton.language as tl
from . import custom_autotune
# code based https://github.com/fpgaminer/GPTQ-triton
@custom_autotune.autotune(
configs=[
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=2,
num_warps=8,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
},
num_stages=3,
num_warps=8,
),
triton.Config(
{
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
},
num_stages=2,
num_warps=4,
),
],
key=["M", "N", "K"],
nearest_power_of_two=True,
prune_configs_by={
"early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
"perf_model": None,
"top_k": None,
},
)
@triton.jit
def matmul_248_kernel(
a_ptr,
b_ptr,
c_ptr,
scales_ptr,
zeros_ptr,
g_ptr,
M,
N,
K,
bits,
maxq,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_scales,
stride_zeros,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""
Compute the matrix multiplication C = A x B.
A is of shape (M, K) float16
B is of shape (K//8, N) int32
C is of shape (M, N) float16
scales is of shape (G, N) float16
zeros is of shape (G, N) float16
g_ptr is of shape (K) int32
"""
infearure_per_bits = 32 // bits
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (
offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
a_mask = offs_am[:, None] < M
# b_ptrs is set up such that it repeats elements along the K axis 8 times
b_ptrs = b_ptr + (
(offs_k[:, None] // infearure_per_bits) * stride_bk
+ offs_bn[None, :] * stride_bn
) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
g_ptrs = g_ptr + offs_k
# shifter is used to extract the N bits of each element in the 32-bit word from B
scales_ptrs = scales_ptr + offs_bn[None, :]
zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)
shifter = (offs_k % infearure_per_bits) * bits
zeros_shifter = (offs_bn % infearure_per_bits) * bits
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, num_pid_k):
g_idx = tl.load(g_ptrs)
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
scales = tl.load(
scales_ptrs + g_idx[:, None] * stride_scales
) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = tl.load(
zeros_ptrs + g_idx[:, None] * stride_zeros
) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = (zeros + 1) & maxq # eventually avoid overflow
a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
# Now we need to unpack b (which is N-bit values) into 32-bit values
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
b = (b - zeros) * scales # Scale and shift
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K
b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
g_ptrs += BLOCK_SIZE_K
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
with torch.cuda.device(input.device):
output = torch.empty(
(input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16
)
def grid(META):
return (
triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"])
* triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]),
)
matmul_248_kernel[grid](
input,
qweight,
output,
scales,
qzeros,
g_idx,
input.shape[0],
qweight.shape[1],
input.shape[1],
bits,
maxq,
input.stride(0),
input.stride(1),
qweight.stride(0),
qweight.stride(1),
output.stride(0),
output.stride(1),
scales.stride(0),
qzeros.stride(0),
)
return output
class QuantLinearFunction(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq)
return output
class QuantLinear(nn.Module):
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
super().__init__()
self.register_buffer("qweight", qweight)
self.register_buffer("qzeros", qzeros)
self.register_buffer("scales", scales)
self.register_buffer("g_idx", g_idx)
if bias is not None:
self.register_buffer("bias", bias)
else:
self.bias = None
if bits not in [2, 4, 8]:
raise NotImplementedError("Only 2,4,8 bits are supported.")
self.bits = bits
self.maxq = 2**self.bits - 1
self.groupsize = groupsize
self.outfeatures = qweight.shape[1]
self.infeatures = qweight.shape[0] * 32 // bits
@classmethod
def new(cls, bits, groupsize, infeatures, outfeatures, bias):
if bits not in [2, 4, 8]:
raise NotImplementedError("Only 2,4,8 bits are supported.")
qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
qzeros = torch.zeros(
(math.ceil(infeatures / groupsize), outfeatures // 32 * bits),
dtype=torch.int32,
)
scales = torch.zeros(
(math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16
)
g_idx = torch.tensor(
[i // groupsize for i in range(infeatures)], dtype=torch.int32
)
if bias:
bias = torch.zeros((outfeatures), dtype=torch.float16)
else:
bias = None
return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
def pack(self, linear, scales, zeros, g_idx=None):
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
self.scales = scales.clone().half()
if linear.bias is not None:
self.bias = linear.bias.clone().half()
intweight = []
for idx in range(self.infeatures):
intweight.append(
torch.round(
(linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
/ self.scales[self.g_idx[idx]]
).to(torch.int)[:, None]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)
qweight = np.zeros(
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
)
i = 0
row = 0
while row < qweight.shape[0]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
i += 32 // self.bits
row += 1
else:
raise NotImplementedError("Only 2,4,8 bits are supported.")
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)
zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros(
(zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
)
i = 0
col = 0
while col < qzeros.shape[1]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
i += 32 // self.bits
col += 1
else:
raise NotImplementedError("Only 2,4,8 bits are supported.")
qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)
def forward(self, x):
out_shape = x.shape[:-1] + (self.outfeatures,)
out = QuantLinearFunction.apply(
x.reshape(-1, x.shape[-1]),
self.qweight,
self.scales,
self.qzeros,
self.g_idx,
self.bits,
self.maxq,
)
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)

View File

@ -12,7 +12,7 @@ from huggingface_hub import HfApi
from accelerate import init_empty_weights
from text_generation_server.utils import initialize_torch_distributed, Weights
from text_generation_server.utils.hub import weight_files
from text_generation_server.layers.gptq import QuantLinear
from text_generation_server.layers.gptq.quant_linear import QuantLinear
from loguru import logger
from typing import Optional
from text_generation_server.layers.gptq.utils import torch_snr_error
@ -956,24 +956,15 @@ def quantize(
pack(model, quantizers, bits, groupsize)
from safetensors.torch import save_file
from huggingface_hub import split_torch_state_dict_into_shards
from transformers.modeling_utils import shard_checkpoint
state_dict = model.state_dict()
state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
max_shard_size = "10GB"
state_dict_split = split_torch_state_dict_into_shards(
state_dict,
filename_pattern="model.safetensors",
max_shard_size=max_shard_size,
shards, index = shard_checkpoint(
state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors"
)
index = None
if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
shards = state_dict_split.filename_to_tensors
os.makedirs(output_dir, exist_ok=True)
for shard_file, shard in shards.items():
save_file(

View File

@ -1,6 +1,9 @@
import torch
from torch import nn
from accelerate import init_empty_weights
from text_generation_server.utils.import_utils import (
SYSTEM,
)
# Monkey patching
@ -30,14 +33,69 @@ def load_layer_norm_no_bias(cls, prefix, weights, eps):
torch.nn.LayerNorm.load = load_layer_norm
torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias
if SYSTEM == "cuda":
import dropout_layer_norm
class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None):
if residual is not None:
hidden_states += residual
residual = hidden_states
class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None):
if hidden_states.shape[-1] > 8192:
if residual is not None:
hidden_states += residual
residual = hidden_states
return super().forward(hidden_states), residual
return super(FastLayerNorm, self).forward(hidden_states), residual
else:
(
normed_hidden_states,
residual,
*rest,
) = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
self.bias,
None,
None,
None,
None,
0.0,
self.eps,
1.0,
0,
None,
False,
False,
)
if residual is None:
residual = hidden_states
return normed_hidden_states, residual
elif SYSTEM == "rocm":
from vllm._C import ops
class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None):
if residual is not None:
hidden_states += residual
residual = hidden_states
return super().forward(hidden_states), residual
elif SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex
class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None):
out = ipex.llm.functional.add_layer_norm(
residual,
hidden_states,
self.weight,
self.bias,
self.eps,
residual is not None,
)
return out, residual if residual is not None else hidden_states
class FastRMSNorm(nn.Module):
@ -53,10 +111,74 @@ class FastRMSNorm(nn.Module):
return cls(weight, eps)
def forward(self, hidden_states, residual=None):
if residual is not None:
hidden_states += residual
residual = hidden_states
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(self.weight.dtype), residual
if SYSTEM == "ipex":
out = ipex.llm.functional.add_rms_norm(
residual,
hidden_states,
self.weight,
None,
self.variance_epsilon,
residual is not None,
)
return out, residual if residual is not None else hidden_states
elif hidden_states.shape[-1] > 8192:
if residual is not None:
hidden_states += residual
residual = hidden_states
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(
variance + self.variance_epsilon
)
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states, residual
elif SYSTEM == "cuda":
# faster post attention rms norm
(
normed_hidden_states,
res,
*rest,
) = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
None,
None,
None,
None,
None,
0.0,
self.variance_epsilon,
1.0,
0,
None,
False,
True, # Activate RMSNorm
)
if res is None:
res = hidden_states
return normed_hidden_states, res
elif SYSTEM == "rocm":
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
if residual is not None:
hidden_states += residual
residual = hidden_states
out = torch.empty_like(hidden_states)
ops.rms_norm(
out,
hidden_states,
self.weight.data,
self.variance_epsilon,
)
return out, residual
else:
raise ValueError(
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
)

View File

@ -1,5 +1,21 @@
import torch
from text_generation_server.utils.import_utils import SYSTEM
from torch.nn import functional as F
import os
if SYSTEM == "rocm":
ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in (
"true",
"1",
)
if ROCM_USE_SKINNY_GEMM:
try:
from vllm import _custom_C
except Exception as e:
raise ImportError(
f"Could not load `vllm._custom_C` for ROCm skinny gemm. Full error: {e}"
)
class FastLinear(torch.nn.Module):
@ -28,11 +44,83 @@ class FastLinear(torch.nn.Module):
return F.linear(input, self.weight, self.bias)
class FastLinearROCm(torch.nn.Module):
def __init__(
self,
weight,
bias,
) -> None:
super().__init__()
self.weight = torch.nn.Parameter(weight)
if bias is not None:
self.bias = torch.nn.Parameter(bias)
else:
self.bias = None
self.cu_count = torch.cuda.get_device_properties(
device="cuda"
).multi_processor_count
self.use_skinny_gemm = (
ROCM_USE_SKINNY_GEMM
and "gfx1" not in torch.cuda.get_device_properties("cuda").gcnArchName
)
@classmethod
def load(cls, config, prefix: str, weights, bias: bool):
weight = weights.get_tensor(f"{prefix}.weight")
if bias:
bias = weights.get_tensor(f"{prefix}.bias")
else:
bias = None
return cls(weight, bias)
def forward(self, inp: torch.Tensor) -> torch.Tensor:
weight = self.weight
bias = self.bias
if (
self.use_skinny_gemm
and inp.dtype == torch.float16
and inp.shape[-1] % 8 == 0
):
batched = False
inp_shape = inp.shape
if inp.dim() == 3:
inp = inp.view(-1, inp_shape[-1])
batched = True
m, n, k = weight.shape[0], inp_shape[0], inp_shape[1]
if m > 8 and n <= 4:
out = torch.empty(
inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
)
_custom_C.wvSpltK(weight, inp, out, n, self.cu_count)
elif m % 4 == 0 and n == 1 and k <= 8192:
out = torch.empty(
inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
)
_custom_C.LLMM1(weight, inp, out, 4)
else:
out = F.linear(inp, weight)
if batched:
out.view(*inp_shape[:-1], out.shape[-1])
if bias is not None:
out = out + bias
return out
return F.linear(inp, self.weight, self.bias)
def get_linear(weight, bias):
# Weights that are loaded through methods that are not
# quantization-aware are still bare tensors. We may want
# to change this in the future.
if isinstance(weight, torch.Tensor):
return FastLinear(weight, bias)
if SYSTEM == "rocm":
return FastLinearROCm(weight, bias)
else:
return FastLinear(weight, bias)
return weight.get_linear(bias)

View File

@ -0,0 +1,15 @@
from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear
from text_generation_server.layers.marlin.gptq import (
GPTQMarlinWeightsLoader,
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
from text_generation_server.layers.marlin.marlin import MarlinWeightsLoader
__all__ = [
"GPTQMarlinFP8Linear",
"GPTQMarlinWeightsLoader",
"MarlinWeightsLoader",
"can_use_gptq_marlin",
"repack_gptq_for_marlin",
]

View File

@ -0,0 +1,140 @@
from typing import Optional
import torch
import torch.nn as nn
from loguru import logger
from text_generation_server.layers.fp8 import fp8_quantize
from text_generation_server.layers.marlin.gptq import _check_valid_shape
from text_generation_server.layers.marlin.util import (
_check_marlin_kernels,
permute_scales,
)
from text_generation_server.utils.log import log_once
try:
import marlin_kernels
except ImportError:
marlin_kernels = None
MARLIN_TILE_SIZE = 16
class GPTQMarlinFP8Linear(nn.Module):
"""
FP8 GPTQ-Marlin linear layer.
"""
def __init__(
self,
qweight: torch.Tensor,
scales: torch.Tensor,
bias: Optional[torch.Tensor],
) -> None:
super().__init__()
_check_marlin_kernels()
assert marlin_kernels is not None
log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
scales = scales.unsqueeze(0)
if scales.shape[1] == 1:
out_features, in_features = qweight.shape
scales = scales.repeat(1, out_features)
qweight, scales = repack_fp8_for_marlin(qweight, scales)
in_features = qweight.shape[0] * MARLIN_TILE_SIZE
out_features = scales.shape[1]
_check_valid_shape(in_features=in_features, out_features=out_features)
self.qweight = qweight
self.scales = scales
self.bias = bias if bias is not None else None
self.workspace = torch.zeros(
out_features // 64 * 16, dtype=torch.int, device=qweight.device
)
@classmethod
def from_unquant(cls, weight, bias, dtype):
qweight, scales = fp8_quantize(weight)
return cls(qweight=qweight, scales=scales.to(dtype), bias=bias)
@classmethod
def from_fp8(cls, weight, scale, _input_scale, bias, dtype):
return cls(qweight=weight, scales=scale.to(dtype), bias=bias)
def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None
A_flat = A.view(-1, A.shape[-1])
C = marlin_kernels.fp8_marlin_gemm(
A_flat,
self.qweight,
self.scales,
self.workspace,
8,
A_flat.shape[0],
self.scales.shape[1],
A_flat.shape[1],
)
C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))
if self.bias is not None:
C += self.bias
return C
def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
"""
Repack FP8 weights to gptq format (packed int32 elements).
"""
assert fp8_tensor.dtype == torch.float8_e4m3fn
if fp8_tensor.shape[0] % 4 != 0:
raise ValueError(
f"Leading tensor dimension is not divisable by 4: {fp8_tensor.shape[0]}"
)
# Reshape to prepare for packing
reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
# Convert fp8 to uint8 (byte) representation
byte_tensor = reshaped.view(torch.uint8)
# Pack 4 uint8 values into one int32
packed = torch.zeros(
fp8_tensor.shape[0] // 4,
fp8_tensor.shape[1],
dtype=torch.int32,
device=fp8_tensor.device,
)
for i in range(4):
packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8)
return packed
def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor):
"""
Repack FP8 tensor for GPTQ-Marlin.
"""
out_features, in_features = weight.shape
# Torch linear layers weights with shape [out_features, in_features],
# GPTQ-quantized weights use [in_feateres/pack_factor, in_features],
# so transpose before packing.
qweight = pack_fp8_as_int32(weight.t())
perm = torch.empty(0, dtype=torch.int, device=qweight.device)
repacked = marlin_kernels.gptq_marlin_repack(
qweight, perm, in_features, out_features, 8
)
scales = permute_scales(scales)
return repacked, scales

View File

@ -0,0 +1,464 @@
from dataclasses import dataclass
from typing import List, Optional, Union
import numpy
import torch
import torch.nn as nn
from loguru import logger
from text_generation_server.layers.marlin.util import (
_check_marlin_kernels,
marlin_zero_points,
permute_scales,
unpack_cols,
)
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
try:
import marlin_kernels
except ImportError:
marlin_kernels = None
try:
major, _minor = torch.cuda.get_device_capability()
has_sm_8_0 = major >= 8
except Exception:
has_sm_8_0 = False
GPTQ_MARLIN_BITS = [4, 8]
GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
MARLIN_TILE_SIZE = 16
def can_use_gptq_marlin(
*, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool
) -> bool:
return (
SYSTEM == "cuda"
and marlin_kernels is not None
and has_sm_8_0
and quantize in {"awq", "gptq"}
and quant_method in {"awq", "gptq"}
and bits in GPTQ_MARLIN_BITS
and groupsize in GPTQ_MARLIN_GROUP_SIZES
# We only suppord asymmetric quantization for AWQ.
and (sym or quant_method == "awq")
)
class GPTQMarlinWeightsLoader(WeightsLoader):
"""
Loader for using GPTQ- and AWQ-quantized weights with Marlin kernels.
"""
def __init__(
self,
*,
bits: int,
desc_act: bool,
groupsize: int,
quant_method: str,
quantize: str,
sym: bool,
):
self.bits = bits
self.desc_act = desc_act
self.groupsize = groupsize
self.quant_method = quant_method
self.quantize = quantize
self.sym = sym
def get_weights(self, weights: Weights, prefix: str):
log_once(logger.info, "Using GPTQ-Marlin kernels")
try:
qweight = weights.get_tensor(f"{prefix}.qweight")
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
if not self.sym:
qzeros = weights.get_tensor(f"{prefix}.qzeros")
else:
qzeros = None
if self.quant_method == "awq":
g_idx = None
else:
g_idx = weights.get_tensor(f"{prefix}.g_idx")
scales = weights.get_tensor(f"{prefix}.scales")
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
qzeros=qzeros,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method=self.quant_method,
sym=self.sym,
sharded_infeatures=False,
)
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
try:
qweight = weights.get_packed_sharded(
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized."
)
scales = weights.get_packed_sharded(
f"{prefix}.scales", dim=1, block_sizes=block_sizes
)
scales = scales.to(dtype=weights.dtype)
if not self.sym:
qzeros = weights.get_packed_sharded(
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
)
else:
qzeros = None
if self.quant_method == "awq":
g_idx = None
else:
g_idx = weights.get_tensor(f"{prefix}.g_idx")
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
qzeros=qzeros,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method=self.quant_method,
sym=self.sym,
sharded_infeatures=False,
)
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
try:
qweight = torch.cat(
[weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
)
scales = torch.cat(
[weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
)
if not self.sym:
qzeros = torch.cat(
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
)
else:
qzeros = None
if self.quant_method == "awq":
g_idx = None
else:
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
qzeros=qzeros,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method=self.quant_method,
sym=self.sym,
sharded_infeatures=False,
)
def get_weights_row(self, weights: Weights, prefix: str):
log_once(logger.info, "Using GPTQ-Marlin kernels")
try:
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
if not self.sym:
if self.desc_act or self.groupsize == -1:
qzeros = weights.get_tensor(f"{prefix}.qzeros")
else:
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
else:
qzeros = None
if self.quant_method == "awq":
g_idx = None
else:
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
if self.desc_act or self.groupsize == -1:
scales = weights.get_tensor(f"{prefix}.scales")
else:
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
sharded_in_features = weights.process_group.size() > 1
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
qzeros=qzeros,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method=self.quant_method,
sym=self.sym,
sharded_infeatures=sharded_in_features,
)
def _get_gptq_params(self, weights: Weights):
if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"):
self.bits = weights.get_tensor("gptq_bits").item()
self.groupsize = weights.get_tensor("gptq_groupsize").item()
self.desc_act = False
# `server quantize` used asymmetric quantization unconditionally
# before the `gptq_sym` setting tensor was added.
self.sym = (
weights.get_tensor("gptq_sym").item()
if weights._has_tensor("gptq_sym")
else False
)
self.quant_method = "gptq"
@dataclass
class GPTQMarlinWeight(Weight):
"""
Repacked GPTQ Marlin weights.
"""
qweight: torch.Tensor
qzeros: torch.Tensor
scales: torch.Tensor
g_idx: torch.Tensor
perm: torch.Tensor
bits: int
is_full_k: bool
def __post_init__(self):
assert self.qweight.dtype == torch.int32
assert self.scales.dtype == torch.float16
assert self.g_idx.dtype == torch.int32
assert self.perm.dtype == torch.int32
def get_linear(self, bias: torch.Tensor):
return GPTQMarlinLinear(
weight=self,
bias=bias,
)
def repack_gptq_for_marlin(
*,
qweight: torch.Tensor,
qzeros: Optional[torch.Tensor],
scales: torch.Tensor,
g_idx: Optional[torch.Tensor],
bits: int,
desc_act: bool,
groupsize: int,
quant_method: str,
sym: bool,
sharded_infeatures: bool,
) -> GPTQMarlinWeight:
"""Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels."""
_check_marlin_kernels()
assert marlin_kernels is not None
if bits not in GPTQ_MARLIN_BITS:
supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS)
raise RuntimeError(
f"Repacking {bits}-bit GPTQ weights as Marlin is not supported, must be one of: {supported_bits}"
)
if groupsize not in GPTQ_MARLIN_GROUP_SIZES:
supported_sizes = ", ".join(str(b) for b in GPTQ_MARLIN_GROUP_SIZES)
raise RuntimeError(
f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}"
)
if not (sym or quant_method == "awq"):
raise RuntimeError(
"Repacking GPTQ weights with asymmetric quantization as Marlin is not supported."
)
log_once(logger.info, f"Converting {quant_method} model to Marlin packing format.")
weights_per_int = 32 // bits
in_features = qweight.shape[0]
out_features = qweight.shape[1]
# AWQ uses column packing, GPTQ uses row packing
if quant_method == "awq":
out_features *= weights_per_int
else:
in_features *= weights_per_int
if in_features % groupsize != 0:
raise ValueError(
f"Number of input features ({in_features}) not divisible by group size ({groupsize})"
)
if g_idx is not None and desc_act and groupsize != -1:
perm = torch.argsort(g_idx).to(torch.int)
g_idx = g_idx[perm]
else:
perm = torch.empty(0, dtype=torch.int, device=qweight.device)
g_idx = torch.empty(0, dtype=torch.int, device=qweight.device)
if quant_method == "awq":
repacked = marlin_kernels.awq_marlin_repack(
qweight, in_features, out_features, bits
)
if qzeros is not None:
qzeros = awq_to_marlin_zero_points(
qzeros,
in_features // groupsize,
out_features,
bits,
)
else:
repacked = marlin_kernels.gptq_marlin_repack(
qweight, perm, in_features, out_features, bits
)
if qzeros is None:
qzeros = torch.empty(0, dtype=torch.int, device=qweight.device)
scales = permute_scales(scales)
is_full_k = not (desc_act and groupsize != -1 and sharded_infeatures)
return GPTQMarlinWeight(
qweight=repacked,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
perm=perm,
bits=bits,
is_full_k=is_full_k,
)
class GPTQMarlinLinear(nn.Module):
"""
Linear layer for GPTQ weights that were converted for the GPTQ-Marlin
kernels.
"""
def __init__(
self,
*,
weight: GPTQMarlinWeight,
bias: Optional[torch.Tensor],
):
super().__init__()
_check_marlin_kernels()
assert marlin_kernels is not None
in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE
out_features = weight.scales.shape[1]
_check_valid_shape(in_features=in_features, out_features=out_features)
self.bits = weight.bits
self.is_full_k = weight.is_full_k
self.qweight = weight.qweight
self.qzeros = weight.qzeros
self.scales = weight.scales
self.g_idx = weight.g_idx
self.perm = weight.perm
if bias is not None:
self.bias = bias
else:
self.bias = None
self.workspace = torch.zeros(
out_features // 64 * 16, dtype=torch.int, device=weight.qweight.device
)
def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None
A_flat = A.view(-1, A.shape[-1])
C = marlin_kernels.gptq_marlin_gemm(
A_flat,
self.qweight,
self.scales,
self.qzeros,
self.g_idx,
self.perm,
self.workspace,
self.bits,
A_flat.shape[0],
self.scales.shape[1],
A_flat.shape[1],
self.is_full_k,
self.qzeros.numel() > 0,
True,
)
C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))
if self.bias is not None:
C += self.bias
return C
def awq_to_marlin_zero_points(
q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
# AWQ zero-points are quantized and packed on the column dim.
# In addition, the values are permuted based on dequantizer.
# Here we undo both of these, and then apply marlin permutation
# and pack it back.
q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)
# Undo interleaving (use argsort(..) to get inverse perm)
if num_bits == 4:
undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
elif num_bits == 8:
undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
q_zp = q_zp.reshape((-1, size_n)).contiguous()
marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
return marlin_zp
def _check_valid_shape(in_features: int, out_features: int):
if (in_features % 128 != 0 or out_features % 64 != 0) and (
in_features % 64 != 0 or out_features % 128 != 0
):
raise ValueError(
f"The GPTQ Marlin kernel does not have a valid thread configuration for weight matrix with shape ({out_features}, {in_features})."
" The shape elements must be divisible by (128, 64) or (64, 128)."
)

View File

@ -0,0 +1,346 @@
from dataclasses import dataclass
from typing import List, Optional, Union
import torch
import torch.nn as nn
from text_generation_server.layers.marlin.util import _check_marlin_kernels
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
try:
import marlin_kernels
except ImportError:
marlin_kernels = None
class MarlinWeightsLoader(WeightsLoader):
"""Loader for Marlin-quantized weights."""
def __init__(self, *, bits: int, is_marlin_24: bool):
self.bits = bits
self.is_marlin_24 = is_marlin_24
def get_weights(self, weights: "Weights", prefix: str):
"""
Get weights at the given prefix and apply without tensor paralllism.
"""
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
if is_marlin_24:
try:
B = weights.get_tensor(f"{prefix}.B_24")
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
)
B_meta = weights.get_tensor(f"{prefix}.B_meta")
s = weights.get_tensor(f"{prefix}.s")
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
else:
try:
B = weights.get_tensor(f"{prefix}.B")
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` weight, make sure the model is already quantized."
)
s = weights.get_tensor(f"{prefix}.s")
weight = MarlinWeight(B=B, s=s)
return weight
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
if self.is_marlin_24:
B = weights.get_packed_sharded(
f"{prefix}.B_24", dim=1, block_sizes=block_sizes
)
B_meta = weights.get_packed_sharded(
f"{prefix}.B_meta", dim=1, block_sizes=block_sizes
)
s = weights.get_packed_sharded(
f"{prefix}.s", dim=1, block_sizes=block_sizes
)
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
else:
B = weights.get_packed_sharded(
f"{prefix}.B", dim=1, block_sizes=block_sizes
)
s = weights.get_packed_sharded(
f"{prefix}.s", dim=1, block_sizes=block_sizes
)
weight = MarlinWeight(B=B, s=s)
return weight
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
if self.is_marlin_24:
try:
B = torch.cat(
[weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` weight, make sure the model is already quantized"
)
B_meta = torch.cat(
[weights.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1
)
s = torch.cat(
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
)
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
else:
try:
B = torch.cat(
[weights.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` weight, make sure the model is already quantized"
)
s = torch.cat(
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
)
weight = MarlinWeight(B=B, s=s)
return weight
def get_weights_row(self, weights: Weights, prefix: str):
if self.is_marlin_24:
try:
B = weights.get_sharded(f"{prefix}.B_24", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
)
B_meta = weights.get_sharded(f"{prefix}.B_meta", dim=0)
num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0]
if num_groups == 1:
# The number of groups is 1 when groupsize == -1. share
# scales between all shards in this case.
s = weights.get_tensor(f"{prefix}.s")
else:
s = weights.get_sharded(f"{prefix}.s", dim=0)
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
else:
try:
B = weights.get_sharded(f"{prefix}.B", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` weight, make sure the model is already quantized."
)
num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0]
if num_groups == 1:
# The number of groups is 1 when groupsize == -1. share
# scales between all shards in this case.
s = weights.get_tensor(f"{prefix}.s")
else:
s = weights.get_sharded(f"{prefix}.s", dim=0)
weight = MarlinWeight(B=B, s=s)
return weight
@dataclass
class MarlinWeight(Weight):
"""
Marlin weights.
Attributes:
B (torch.Tensor): int4-quantized weights packed into int32.
s (torch.Tensor): bfloat16/float16 scales.
"""
B: torch.Tensor
s: torch.Tensor
def __post_init__(self):
assert self.B.dtype == torch.int32
assert self.s.dtype in [torch.float16, torch.bfloat16]
def get_linear(self, bias: torch.Tensor):
return MarlinLinear(weight=self, bias=bias)
class MarlinLinear(nn.Module):
def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]):
super().__init__()
_check_marlin_kernels()
assert marlin_kernels is not None
in_features = weight.B.shape[0] * MARLIN_TILE_SIZE
out_features = weight.s.shape[1]
assert (
in_features % 128 == 0
), f"Number of input features ({in_features}) not divisable by 128"
assert (
out_features % 256 == 0
), f"Number of output features ({out_features}) not divisable by 256"
groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0]
assert groupsize in {
-1,
128,
}, f"Group size must be -1 or 128, was {groupsize}"
self.B = weight.B
self.s = weight.s
if bias is not None:
self.bias = bias
else:
self.bias = None
self.workspace = torch.zeros(
out_features // 64 * 16, dtype=torch.int, device=weight.B.device
)
def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None
C = marlin_kernels.marlin_gemm(
A.view(-1, A.shape[-1]),
self.B,
self.s,
self.workspace,
A.shape[0],
self.s.shape[1],
A.shape[1],
)
C = C.reshape(A.shape[:-1] + (self.s.shape[1],))
if self.bias is not None:
C += self.bias
return C
GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MIN_THREAD_K = 128
GPTQ_MARLIN_24_MAX_PARALLEL = 64
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
MARLIN_TILE_SIZE = 16
@dataclass
class GPTQMarlin24Weight:
"""
GPTQ-Marlin 2:4 weights.
Attributes:
B (torch.Tensor): int4-quantized weights packed into int32.
B_meta (torch.Tensor): metadata for 2:4 sparsity.
s (torch.Tensor): float16 scales.
bits: quantized weight size.
"""
B: torch.Tensor
B_meta: torch.Tensor
s: torch.Tensor
bits: int
def __post_init__(self):
assert self.B.dtype == torch.int32
assert self.B_meta.dtype == torch.int16
assert self.s.dtype == torch.float16
def get_linear(self, bias: torch.Tensor):
return GPTQMarlin24Linear(
weight=self,
bias=bias,
)
class GPTQMarlin24Linear(nn.Module):
def __init__(self, *, weight: GPTQMarlin24Weight, bias: Optional[torch.Tensor]):
super().__init__()
_check_marlin_kernels()
assert marlin_kernels is not None
if weight.bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS:
supported_bits = ", ".join(
str(b) for b in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
)
raise RuntimeError(
f"{weight.bits}-bit GPTQ Sparse 2:4 Marlin is not supported, must be one of: {supported_bits}"
)
in_features = weight.B.shape[0] * MARLIN_TILE_SIZE * 2
out_features = weight.s.shape[1]
groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0]
if groupsize not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
supported_sizes = ", ".join(
str(b) for b in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
)
raise RuntimeError(
f"Group size {groupsize} is not supported, must be one of: {supported_sizes}"
)
self.bits = weight.bits
weights_per_int32 = 32 // self.bits
assert (
out_features % GPTQ_MARLIN_24_MIN_THREAD_N == 0
), f"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_N} threads"
assert (
out_features % weights_per_int32 == 0
), f"Number of output features ({out_features}) not divisable by weights per int32 ({weights_per_int32})"
assert (
in_features % GPTQ_MARLIN_24_MIN_THREAD_K == 0
), f"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_K} threads"
if groupsize != -1 and in_features % groupsize != 0:
raise ValueError(
f"Number of input features ({in_features}) not divisable by group size ({groupsize})"
)
self.B = weight.B
self.B_meta = weight.B_meta
self.s = weight.s
if bias is not None:
self.bias = bias
else:
self.bias = None
self.workspace = torch.zeros(
(out_features // GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL,
dtype=torch.int,
device=weight.B.device,
)
def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None
C = marlin_kernels.gptq_marlin_24_gemm(
A.view(-1, A.shape[-1]),
self.B,
self.B_meta,
self.s,
self.workspace,
self.bits,
A.shape[0],
self.s.shape[1],
A.shape[1],
)
C = C.reshape(A.shape[:-1] + (self.s.shape[1],))
if self.bias is not None:
C += self.bias
return C

View File

@ -0,0 +1,141 @@
import functools
from typing import List, Tuple
import numpy
import torch
from text_generation_server.utils.import_utils import SYSTEM
try:
import marlin_kernels
except ImportError:
marlin_kernels = None
try:
major, _minor = torch.cuda.get_device_capability()
has_sm_8_0 = major >= 8
except Exception:
has_sm_8_0 = False
def _check_marlin_kernels():
if not (SYSTEM == "cuda" and has_sm_8_0):
raise NotImplementedError(
"Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later."
)
if marlin_kernels is None:
raise NotImplementedError(
"marlin is not installed, install it with: pip install server/marlin"
)
# https://github.com/IST-DASLab/marlin/blob/2f6d7c10e124b3c5fa29ff8d77d568bd7af3274c/marlin/__init__.py#L40C1-L68C54
@functools.cache
def get_perms() -> Tuple[List[int], List[int]]:
scale_perm = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single = []
for i in range(4):
scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return scale_perm, scale_perm_single
def permute_scales(scales: torch.Tensor):
scale_perm, scale_perm_single = get_perms()
out_features = scales.shape[1]
if scales.shape[0] == 1:
scales = scales.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
else:
scales = scales.reshape((-1, len(scale_perm)))[:, scale_perm]
return scales.reshape((-1, out_features)).contiguous()
# Functions below are from vLLM
def get_pack_factor(bits: int) -> int:
if 32 % bits != 0:
raise ValueError(f"Cannot {bits} bit values into uint32")
return 32 // bits
def pack_cols(
q_w: torch.Tensor,
num_bits: int,
size_k: int,
size_n: int,
):
assert q_w.shape == (size_k, size_n)
pack_factor = get_pack_factor(num_bits)
assert size_n % pack_factor == 0
orig_device = q_w.device
q_w = q_w.cpu().numpy().astype(numpy.uint32)
q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
for i in range(pack_factor):
q_res |= q_w[:, i::pack_factor] << num_bits * i
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
q_res = q_res.contiguous()
return q_res
def unpack_cols(
packed_q_w: torch.Tensor,
num_bits: int,
size_k: int,
size_n: int,
):
pack_factor = get_pack_factor(num_bits)
assert size_n % pack_factor == 0
assert packed_q_w.shape == (
size_k,
size_n // pack_factor,
), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
packed_q_w.shape, size_k, size_n, pack_factor
)
orig_device = packed_q_w.device
packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
mask = (1 << num_bits) - 1
for i in range(pack_factor):
vals = packed_q_w_cpu & mask
packed_q_w_cpu >>= num_bits
q_res[:, i::pack_factor] = vals
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
q_res = q_res.contiguous()
return q_res
def marlin_zero_points(
zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
scale_perm, _ = get_perms()
# Permute zero-points in a similar way to scales, but do not use the
# "single" permutation, since zero-points are applied on every MMA
zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]
# Interleave column dim (for the dequantize code) and pack it to int32
if num_bits == 4:
interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
elif num_bits == 8:
interleave = numpy.array([0, 2, 1, 3])
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
zp = zp.reshape((-1, size_n)).contiguous()
zp = pack_cols(zp, num_bits, size_k, size_n)
return zp

View File

@ -10,8 +10,13 @@ from text_generation_server.layers import (
TensorParallelRowLinear,
)
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
from text_generation_server.layers.marlin import GPTQMarlinWeightsLoader
from text_generation_server.layers.moe.gptq_marlin import (
GPTQMarlinSparseMoELayer,
can_use_marlin_moe_gemm,
)
from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
from text_generation_server.layers.moe.fp8 import FP8SparseMoELayer
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
@ -19,7 +24,12 @@ from text_generation_server.utils.weights import (
UnquantizedWeight,
)
from .fused_moe import fused_topk, grouped_topk
if SYSTEM == "rocm":
from .fused_moe_rocm import grouped_topk
from vllm.model_executor.layers.fused_moe import fused_topk
elif SYSTEM != "ipex":
from moe_kernels.fused_moe import fused_topk, grouped_topk
# NOTE: we are using a protocol here, because multiple inherance is not nice.
# We need `Module`, and `Module` -> some abstract class -> some concrete
@ -42,8 +52,6 @@ class MoELayer(Protocol):
up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj",
hidden_act: str = "silu",
scoring_func: Optional[str] = None,
e_score_correction_bias: Optional[float] = None,
): ...
def forward(
@ -73,14 +81,9 @@ class DenseMoELayer(nn.Module):
up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj",
hidden_act: str = "silu",
scoring_func: Optional[str] = None,
e_score_correction_bias: Optional[float] = None,
):
super().__init__()
assert scoring_func is None, "scoring func is not handled"
assert e_score_correction_bias is None, "scoring correction bias is not handled"
log_once(
logger.info,
"No fused layers are available for this model type, using (slower) dense MoE layer",
@ -196,27 +199,22 @@ class SparseMoELayer(nn.Module):
topk: int,
topk_group: Optional[int],
weights: Weights,
scoring_func: Optional[str] = "softmax",
e_score_correction_bias: Optional[float] = None,
gate_proj_name: str = "gate_proj",
up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj",
):
super().__init__()
if (
isinstance(weights.loader, DefaultWeightsLoader)
and isinstance(weights.loader.weight_class, UnquantizedWeight)
) or isinstance(weights.loader, HybridFP8UnquantLoader):
if (
isinstance(weights.loader, HybridFP8UnquantLoader)
and weights.loader.to_fp8
):
cls = FP8SparseMoELayer
else:
cls = UnquantizedSparseMoELayer
cls = UnquantizedSparseMoELayer
elif isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym:
cls = GPTQMarlinSparseMoELayer
else:
raise ValueError(
f"Unsupported weights loader: {type(weights.loader)}, sparse MoE is only supported for unquantized, AWQ, and GPTQ weights"
f"Unsupported weights loader: {weights.loader}, sparse MoE is only supported for unquantized and GPTQ weights"
)
log_once(
@ -232,8 +230,6 @@ class SparseMoELayer(nn.Module):
topk=topk,
topk_group=topk_group,
weights=weights,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
gate_proj_name=gate_proj_name,
up_proj_name=up_proj_name,
down_proj_name=down_proj_name,
@ -245,6 +241,17 @@ class SparseMoELayer(nn.Module):
@staticmethod
def is_supported(weights: Weights) -> bool:
return (
isinstance(weights.loader, DefaultWeightsLoader)
and isinstance(weights.loader.weight_class, UnquantizedWeight)
) or isinstance(weights.loader, HybridFP8UnquantLoader)
(
isinstance(weights.loader, DefaultWeightsLoader)
and isinstance(weights.loader.weight_class, UnquantizedWeight)
)
or isinstance(weights.loader, HybridFP8UnquantLoader)
or (
isinstance(weights.loader, GPTQMarlinWeightsLoader)
and can_use_marlin_moe_gemm(
quant_method=weights.loader.quant_method,
quantize=weights.loader.quantize,
sym=weights.loader.sym,
)
)
)

View File

@ -1,268 +0,0 @@
from typing import Optional
import torch
import torch.nn as nn
import os
from text_generation_server.utils.weights import Weights
from text_generation_server.layers.fp8 import (
Fp8Weight,
fp8_quantize,
quant_dtype,
normalize_e4m3fn_to_native_float8,
dynamic_quant,
dequant_block_fp8_weight_naive,
)
from text_generation_server.layers.moe.fused_moe import select_experts
import habana_frameworks.torch as htorch
class FP8SparseMoELayer(nn.Module):
def __init__(
self,
*,
n_expert_group: Optional[int],
n_experts: int,
prefix: str,
renormalize: bool,
topk: int,
topk_group: Optional[int],
weights: Weights,
scoring_func: Optional[str] = "softmax",
e_score_correction_bias: Optional[float] = None,
gate_proj_name: str = "gate_proj",
up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj",
):
super().__init__()
assert (n_expert_group is None) == (
topk_group is None
), "n_expert_group and topk_group must both be None or have some value"
self.n_expert_group = n_expert_group
self.topk = topk
self.topk_group = topk_group
self.renormalize = renormalize
self.weight_block_size = weights.weights_loader.weight_block_size
self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias
self.world_size = weights.process_group.size()
self.rank = weights.process_group.rank()
self.ep_rank = self.rank
self.use_ep = os.getenv("USE_EXPERT_PARALLEL", "true").lower() == "true"
if self.use_ep:
n_experts = (n_experts + self.world_size - 1) // self.world_size
self.ep_offset = self.ep_rank * n_experts
else:
self.ep_offset = 0
(
self.gate_up_proj,
self.gate_up_proj_weight_scale,
self.gate_up_proj_input_scale,
) = _load_expert_multi_weights_col(
prefix=prefix,
n_experts=n_experts,
gate_proj_name=gate_proj_name,
up_proj_name=up_proj_name,
weights=weights,
use_ep=self.use_ep,
ep_offset=self.ep_offset,
)
self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = (
_load_expert_weights_row(
prefix=prefix,
n_experts=n_experts,
name=down_proj_name,
weights=weights,
use_ep=self.use_ep,
ep_offset=self.ep_offset,
)
)
if self.weight_block_size is not None:
self.gate_up_proj, self.gate_up_proj_weight_scale = dynamic_quant(
dequant_block_fp8_weight_naive(
self.gate_up_proj,
self.gate_up_proj_weight_scale,
self.weight_block_size,
)
)
self.down_proj, self.down_proj_weight_scale = dynamic_quant(
dequant_block_fp8_weight_naive(
self.down_proj, self.down_proj_weight_scale, self.weight_block_size
)
)
self.gate_up_proj_weight_scale, self.down_proj_weight_scale = (
self.gate_up_proj_weight_scale.squeeze(-1),
self.down_proj_weight_scale.squeeze(-1),
)
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=gating_output,
use_grouped_topk=self.n_expert_group is not None,
top_k=self.topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.n_expert_group,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
)
total_num_experts = gating_output.size(-1)
x_fp8, x_scale = dynamic_quant(x, single_scale=True)
if self.use_ep:
moe_n_slice = 1
n_expert_slice = (
total_num_experts + self.world_size - 1
) // self.world_size
else:
moe_n_slice = 1
n_expert_slice = (total_num_experts + moe_n_slice - 1) // moe_n_slice
for i in range(moe_n_slice):
min_expert = i * n_expert_slice
max_expert = min((i + 1) * n_expert_slice, total_num_experts)
w13_list_slice = [
self.gate_up_proj[j, ...] for j in range(min_expert, max_expert)
]
w2_list_slice = [
self.down_proj[j, ...] for j in range(min_expert, max_expert)
]
w13_weight_scale = [
self.gate_up_proj_weight_scale[j, ...]
for j in range(min_expert, max_expert)
]
w2_weight_scale = [
self.down_proj_weight_scale[j, ...]
for j in range(min_expert, max_expert)
]
current_hidden_states = torch.ops.hpu.mixture_of_experts(
hidden_states=x_fp8,
expert_routing_table=topk_ids.to(torch.int64),
router_weights=topk_weights.to(x.dtype),
w12=w13_list_slice,
w3=w2_list_slice,
d_scale_hidden_states=x_scale,
d_scale_w12=w13_weight_scale,
d_scale_w3=w2_weight_scale,
permuted_weights=True,
activation="silu",
experts_min=min_expert + self.ep_offset,
experts_max=max_expert + self.ep_offset - 1,
)
htorch.core.mark_step()
if i == 0:
final_hidden_states = current_hidden_states
else:
final_hidden_states.add_(current_hidden_states)
return final_hidden_states
def _load_expert_weights(
get_weight_fn,
*,
prefix: str,
n_experts: int,
name: str,
weights: Weights,
ep_offset: int = 0,
) -> torch.Tensor:
all_weight = None
all_weight_scales = None
max_input_scale = None
for i in range(n_experts):
weight = get_weight_fn(prefix, i + ep_offset, name, weights)
assert isinstance(weight, Fp8Weight)
if all_weight is None:
all_weight = torch.empty(
(n_experts,) + weight.weight.shape,
dtype=quant_dtype,
device=weight.weight.device,
)
if all_weight_scales is None:
all_weight_scales = torch.empty(
(n_experts,) + weight.weight_scale.shape,
dtype=torch.float32,
device=weight.weight.device,
)
if weight.weight.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz}:
all_weight[i], all_weight_scales[i], current_input_scale = (
normalize_e4m3fn_to_native_float8(
weight.weight, weight.weight_scale, weight.input_scale
)
)
if current_input_scale is not None:
if max_input_scale is None or current_input_scale > max_input_scale:
max_input_scale = current_input_scale
else:
all_weight[i], all_weight_scales[i] = fp8_quantize(
weight.weight, scalar=True
)
assert all_weight is not None
return all_weight, all_weight_scales, max_input_scale
def _load_expert_multi_weights_col(
*,
prefix: str,
n_experts: int,
gate_proj_name: str,
up_proj_name: str,
weights: Weights,
use_ep: bool = False,
ep_offset: int = 0,
) -> torch.Tensor:
def get_weight_fn_sharded(prefix, i, name, weights):
return weights.get_multi_weights_col(
[f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
)
def get_weight_fn(prefix, i, name, weights):
return weights.get_multi_weights(
[f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
)
return _load_expert_weights(
get_weight_fn if use_ep else get_weight_fn_sharded,
prefix=prefix,
n_experts=n_experts,
name=None,
weights=weights,
ep_offset=ep_offset if use_ep else 0,
)
def _load_expert_weights_row(
*,
prefix: str,
n_experts: int,
name: str,
weights: Weights,
use_ep: bool = False,
ep_offset: int = 0,
) -> torch.Tensor:
def get_weight_fn_sharded(prefix, i, name, weights):
return weights.get_weights_row(f"{prefix}.{i}.{name}")
def get_weight_fn(prefix, i, name, weights):
return weights.get_weights(f"{prefix}.{i}.{name}")
return _load_expert_weights(
get_weight_fn if use_ep else get_weight_fn_sharded,
prefix=prefix,
n_experts=n_experts,
name=name,
weights=weights,
ep_offset=ep_offset if use_ep else 0,
)

View File

@ -1,131 +0,0 @@
# coding=utf-8
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, Optional
import torch
def grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
gating_output = gating_output.float()
if e_score_correction_bias is not None:
e_score_correction_bias = e_score_correction_bias.float()
if scoring_func == "softmax":
scores = torch.softmax(gating_output, dim=-1)
elif scoring_func == "sigmoid":
scores = gating_output.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
num_token = scores.shape[0]
if e_score_correction_bias is not None:
# Store original scores before applying correction bias. We use biased
# scores for expert selection but original scores for routing weights
original_scores = scores
scores = scores + e_score_correction_bias.unsqueeze(0)
group_scores = (
scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
)
else:
group_scores = (
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
) # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
1
] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = (
group_mask.unsqueeze(-1)
.expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
.reshape(num_token, -1)
) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e]
if e_score_correction_bias is not None:
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
# Use original unbiased scores for the routing weights
topk_weights = original_scores.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
def fused_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
topk_weights = torch.nn.functional.softmax(
gating_output, dim=1, dtype=torch.float32
)
topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
if renormalize:
topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
def select_experts(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
):
# DeekSeekv2 uses grouped_top_k
if use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
topk_weights, topk_ids = grouped_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
else:
topk_weights, topk_ids = fused_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
)
return topk_weights, topk_ids

View File

@ -0,0 +1,52 @@
# coding=utf-8
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch
import torch.distributed
# TODO: Remove the functions once moe_kernel are built for ROCM
def grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
scores = torch.softmax(gating_output, dim=-1)
num_token = scores.shape[0]
group_scores = (
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
) # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
1
] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = (
group_mask.unsqueeze(-1)
.expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
.reshape(num_token, -1)
) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids

View File

@ -0,0 +1,215 @@
from dataclasses import dataclass
from typing import List, Optional
import torch
import torch.nn as nn
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import Weights
from text_generation_server.layers.marlin.gptq import (
GPTQMarlinWeight,
GPTQMarlinWeightsLoader,
)
if SYSTEM == "cuda":
from moe_kernels.fused_marlin_moe import fused_marlin_moe
else:
fused_marlin_moe = None
try:
major, _minor = torch.cuda.get_device_capability()
has_sm_8_0 = major >= 8
except Exception:
has_sm_8_0 = False
def can_use_marlin_moe_gemm(
*,
quant_method: str,
quantize: str,
sym: bool,
):
return (
SYSTEM == "cuda"
and fused_marlin_moe is not None
and has_sm_8_0
and quantize == "gptq"
and quant_method == "gptq"
and sym
)
@dataclass
class GPTQMarlinMoEWeight:
qweight: torch.Tensor
qzeros: torch.Tensor
scales: torch.Tensor
g_idx: torch.Tensor
perm: torch.Tensor
is_full_k: bool
class GPTQMarlinSparseMoELayer(nn.Module):
"""
MoE layer that uses a fused GPTQ-Marlin kernel.
"""
def __init__(
self,
*,
n_expert_group: Optional[int],
n_experts: int,
prefix: str,
renormalize: bool,
topk: int,
topk_group: Optional[int],
weights: Weights,
gate_proj_name: str = "gate_proj",
up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj",
):
super().__init__()
if not (
isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym
):
raise ValueError(
f"Unsupported weights loader: {weights.loader}, only GPTQMarlinWeightsLoader with symmetric quantization is supported"
)
assert (n_expert_group is None) == (
topk_group is None
), "n_expert_group and topk_group must both be None or have some value"
self.n_expert_group = n_expert_group
self.topk = topk
self.topk_group = topk_group
self.renormalize = renormalize
self.gate_up_proj = _load_expert_multi_weights_col(
prefix=prefix,
n_experts=n_experts,
names=[gate_proj_name, up_proj_name],
weights=weights,
)
self.down_proj = _load_expert_weights_row(
prefix=prefix, n_experts=n_experts, name=down_proj_name, weights=weights
)
self.bits = weights.loader.bits
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
return fused_marlin_moe(
x,
w1=self.gate_up_proj.qweight,
w2=self.down_proj.qweight,
g_idx1=self.gate_up_proj.g_idx,
g_idx2=self.down_proj.g_idx,
perm1=self.gate_up_proj.perm,
perm2=self.down_proj.perm,
w1_scale=self.gate_up_proj.scales,
w2_scale=self.down_proj.scales,
is_full_k1=self.gate_up_proj.is_full_k,
is_full_k2=self.down_proj.is_full_k,
gating_output=gating_output,
topk=self.topk,
renormalize=self.renormalize,
use_grouped_topk=self.n_expert_group is not None,
num_expert_group=self.n_expert_group,
topk_group=self.topk_group,
num_bits=self.bits,
)
def _load_expert_multi_weights_col(
*,
prefix: str,
n_experts: int,
names: List[str],
weights: Weights,
) -> GPTQMarlinMoEWeight:
moe_weight = None
for i in range(n_experts):
weight = weights.get_multi_weights_col(
[f"{prefix}.{i}.{name}" for name in names], 0
)
assert isinstance(weight, GPTQMarlinWeight)
moe_weight = _pack_weight(
n_experts=n_experts, expert=i, weight=weight, moe_weight=moe_weight
)
assert moe_weight is not None
return moe_weight
def _load_expert_weights_row(
*,
prefix: str,
n_experts: int,
name: str,
weights: Weights,
) -> GPTQMarlinMoEWeight:
moe_weight = None
for i in range(n_experts):
weight = weights.get_weights_row(
f"{prefix}.{i}.{name}",
)
assert isinstance(weight, GPTQMarlinWeight)
moe_weight = _pack_weight(
n_experts=n_experts, expert=i, weight=weight, moe_weight=moe_weight
)
assert moe_weight is not None
return moe_weight
def _pack_weight(
*,
n_experts: int,
expert: int,
moe_weight: Optional[GPTQMarlinMoEWeight],
weight: GPTQMarlinWeight,
) -> GPTQMarlinMoEWeight:
if moe_weight is None:
qweight = torch.empty(
(n_experts,) + weight.qweight.shape,
dtype=weight.qweight.dtype,
device=weight.qweight.device,
)
qzeros = torch.empty(
(n_experts,) + weight.qzeros.shape,
dtype=weight.qzeros.dtype,
device=weight.qzeros.device,
)
scales = torch.empty(
(n_experts,) + weight.scales.shape,
dtype=weight.scales.dtype,
device=weight.scales.device,
)
g_idx = torch.empty(
(n_experts,) + weight.g_idx.shape,
dtype=weight.g_idx.dtype,
device=weight.g_idx.device,
)
perm = torch.empty(
(n_experts,) + weight.perm.shape,
dtype=weight.perm.dtype,
device=weight.perm.device,
)
moe_weight = GPTQMarlinMoEWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
perm=perm,
is_full_k=weight.is_full_k,
)
moe_weight.qweight[expert] = weight.qweight
moe_weight.qzeros[expert] = weight.qzeros
moe_weight.scales[expert] = weight.scales
moe_weight.g_idx[expert] = weight.g_idx
moe_weight.perm[expert] = weight.perm
return moe_weight

View File

@ -3,10 +3,13 @@ from typing import Optional
import torch
import torch.nn as nn
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import UnquantizedWeight, Weights
from vllm_hpu_extension.ops import VllmMixtureOfExpertsOp
import habana_frameworks.torch as htorch
import torch.nn.functional as F
if SYSTEM == "rocm":
from vllm.model_executor.layers.fused_moe import fused_moe
elif SYSTEM != "ipex":
from moe_kernels.fused_moe import fused_moe
class UnquantizedSparseMoELayer(nn.Module):
@ -20,8 +23,6 @@ class UnquantizedSparseMoELayer(nn.Module):
topk: int,
topk_group: Optional[int],
weights: Weights,
scoring_func: Optional[str] = "softmax",
e_score_correction_bias: Optional[float] = None,
gate_proj_name: str = "gate_proj",
up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj",
@ -36,9 +37,6 @@ class UnquantizedSparseMoELayer(nn.Module):
self.topk = topk
self.topk_group = topk_group
self.renormalize = renormalize
self.weight_block_size = weights.weights_loader.weight_block_size
self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias
self.gate_up_proj = _load_expert_multi_weights_col(
prefix=prefix,
@ -55,30 +53,31 @@ class UnquantizedSparseMoELayer(nn.Module):
weights=weights,
)
self.MoeOp = VllmMixtureOfExpertsOp(n_experts, 0, n_experts - 1)
for i in range(n_experts):
self.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
self.MoeOp.w2_list[i].set_weight(self.down_proj[i])
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
htorch.core.mark_step()
routing_weights = F.softmax(gating_output, dim=1, dtype=torch.float32)
routing_weights, selected_experts = torch.topk(
routing_weights, self.topk, dim=-1
)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(x.dtype)
if SYSTEM == "rocm":
return fused_moe(
x,
self.gate_up_proj,
self.down_proj,
gating_output,
self.topk,
renormalize=self.renormalize,
inplace=True,
)
final_hidden_states = self.MoeOp(
hidden_states=x,
expert_routing_table=selected_experts,
router_weights=routing_weights,
permuted_weights=True,
activation="silu",
return fused_moe(
x,
w1=self.gate_up_proj,
w2=self.down_proj,
gating_output=gating_output,
topk=self.topk,
renormalize=self.renormalize,
inplace=True,
use_grouped_topk=self.n_expert_group is not None,
num_expert_group=self.n_expert_group,
topk_group=self.topk_group,
)
return final_hidden_states.view(-1, x.shape[1])
def _load_expert_multi_weights_col(
*,

View File

@ -2,10 +2,14 @@ import os
import math
import torch
from torch import nn
from habana_frameworks.torch.hpex.kernels import (
RotaryPosEmbeddingMode,
apply_rotary_pos_emb,
)
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "cuda":
import rotary_emb
elif SYSTEM == "rocm":
from vllm._C import ops
elif SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex
def _create_inv_freq(dim, base, device):
@ -26,7 +30,7 @@ def _get_rope_config(config):
class PositionRotaryEmbedding(nn.Module):
def __init__(self, inv_freq, scaling_factor, max_position_embeddings):
def __init__(self, inv_freq, scaling_factor):
super().__init__()
self.inv_freq = inv_freq
self._seq_len_cached = 0
@ -36,9 +40,6 @@ class PositionRotaryEmbedding(nn.Module):
self._sin_k_cached = None
self.scaling_factor = scaling_factor
self.dynamic_args = None
self._update_cos_sin_cache(
torch.float32, inv_freq.device, max_position_embeddings
)
def forward(
self,
@ -47,41 +48,40 @@ class PositionRotaryEmbedding(nn.Module):
cos: torch.Tensor,
sin: torch.Tensor,
):
num_tokens = query.shape[0]
head_size = query.shape[-1]
# HPU RoPE kernel requires hidden dimension for cos and sin to be equal
# to query hidden dimension, so the original tensors need to be
# expanded
# GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
# and expansion of cos/sin tensors via concatenation
rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
cos = torch.cat((cos, cos), dim=-1)
sin = torch.cat((sin, sin), dim=-1)
rotary_dim = cos.shape[-1]
query_shape = query.shape
query = query.view(num_tokens, -1, head_size)
query_rot = query[..., :rotary_dim]
query_pass = query[..., rotary_dim:]
query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape))
# Such controlflows may add some overhead.
if SYSTEM == "cuda":
rotary_dim = cos.shape[-1]
q1 = query[..., :rotary_dim]
q2 = query[..., rotary_dim : 2 * rotary_dim]
key_shape = key.shape
key = key.view(num_tokens, -1, head_size)
key_rot = key[..., :rotary_dim]
key_pass = key[..., rotary_dim:]
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape))
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
k1 = key[..., :rotary_dim]
k2 = key[..., rotary_dim : 2 * rotary_dim]
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
elif SYSTEM == "rocm":
# NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
# Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773
head_size = query.shape[-1]
# Inplace operation, updating query and key.
ops.rotary_embedding(query, key, head_size, cos, sin, True)
elif SYSTEM == "ipex":
ipex.llm.functional.rotary_embedding(
query, key, sin, cos, query.size(-1), True
)
else:
raise ValueError(
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
)
@classmethod
def static(cls, config, dim, base, device):
inv_freq = _create_inv_freq(dim, base, device)
scaling_factor = None
rope_scaling = _get_rope_config(config)
if not hasattr(config, "max_position_embeddings") and hasattr(
config, "max_seq_len"
):
# handling for dbrx
config.max_position_embeddings = config.max_seq_len
if rope_scaling is not None:
# `rope_type` is now standard in transformers, but some existing models
# have `type` instead.
@ -89,17 +89,6 @@ class PositionRotaryEmbedding(nn.Module):
if rope_type == "linear":
pass
elif rope_type == "default":
pass
elif rope_type == "mrope":
mrope_section = rope_scaling["mrope_section"]
if mrope_section is not None:
return RotaryPositionEmbeddingMultimodalSections(
inv_freq,
scaling_factor,
mrope_section,
config.max_position_embeddings,
)
elif rope_type == "dynamic":
scaling_factor = rope_scaling["factor"]
return DynamicPositionRotaryEmbedding(
@ -120,7 +109,7 @@ class PositionRotaryEmbedding(nn.Module):
],
)
return cls(inv_freq, scaling_factor, config.max_position_embeddings)
return cls(inv_freq, scaling_factor)
elif rope_type == "yarn":
scaling_factor = rope_scaling["factor"]
@ -196,13 +185,12 @@ class PositionRotaryEmbedding(nn.Module):
long_inv_freq=long_inv_freq,
scaling_factor=scaling_factor,
original_max_position_embeddings=original_max_position_embeddings,
max_position_embeddings=config.max_position_embeddings,
)
else:
raise NotImplementedError(
f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
)
return cls(inv_freq, scaling_factor, config.max_position_embeddings)
return cls(inv_freq, scaling_factor)
@classmethod
def load(cls, config, prefix, weights):
@ -248,7 +236,7 @@ class PositionRotaryEmbedding(nn.Module):
raise NotImplementedError(
f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
)
return cls(inv_freq, scaling_factor, config.max_position_embeddings)
return cls(inv_freq, scaling_factor)
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
@ -269,7 +257,17 @@ class PositionRotaryEmbedding(nn.Module):
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
def get_cos_sin(self, position_ids: torch.Tensor):
def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype):
"""
Return cos and sin for the asked position ids
"""
if SYSTEM == "rocm":
# For RoCm, we always use float cos/sin to avoid a cast.
# For NVIDIA, for some reason, the flash-attn rotary kernel requires cos/sin and query/key to be of same dtype: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary.cpp#L26
# But later on goes and cast cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal.
dtype = torch.float32
self._update_cos_sin_cache(dtype, position_ids.device, max_s)
cos = torch.index_select(self._cos_cached, 0, position_ids)
sin = torch.index_select(self._sin_cached, 0, position_ids)
@ -285,7 +283,6 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
long_inv_freq,
scaling_factor,
original_max_position_embeddings,
max_position_embeddings,
):
super(PositionRotaryEmbedding, self).__init__()
self.short_inv_freq = short_inv_freq
@ -298,9 +295,6 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
self._cos_k_cached = None
self._sin_k_cached = None
self.dynamic_args = None
self._update_cos_sin_cache(
torch.float32, short_inv_freq.device, max_position_embeddings
)
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
@ -354,9 +348,6 @@ class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
self._cos_k_cached = None
self._sin_k_cached = None
self.dynamic_args = None
self._update_cos_sin_cache(
torch.float32, short_inv_freq.device, max_position_embeddings
)
def _update_cos_sin_cache(self, dtype, device, seqlen):
if (
@ -392,7 +383,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
inv_freq = _create_inv_freq(dim, base, device)
super().__init__(inv_freq, scaling_factor, max_position_embeddings)
super().__init__(inv_freq, scaling_factor)
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
@ -470,6 +461,7 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
mscale_all_dim: float,
):
inv_freq = _create_inv_freq(dim, base, device)
super().__init__(inv_freq, scaling_factor)
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
@ -484,7 +476,6 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
/ get_mscale(self.scaling_factor, mscale_all_dim)
* self.attn_factor
) # Get n-d magnitude scaling corrected for interpolation
super().__init__(inv_freq, scaling_factor, max_position_embeddings)
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
@ -555,50 +546,3 @@ def apply_llama3_scaling(
new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq)
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
def __init__(
self,
inv_freq: torch.Tensor,
scaling_factor: float,
sections: list,
max_position_embeddings,
):
self.sections = sections
self._cos_cached = None
self._sin_cached = None
self.section_indices = (
torch.arange(len(self.sections))
.repeat_interleave(torch.tensor(self.sections))
.view(1, 1, -1)
.to(inv_freq.device)
)
super().__init__(inv_freq, scaling_factor, max_position_embeddings)
def _update_cos_sin_cache(
self, dtype: torch.dtype, device: torch.device, seqlen: int
):
# always cache the cos/sin for the full sequence length to avoid
# recomputing if the sequence length is smaller than the cached one
if (
seqlen > self._seq_len_cached
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
self._sections = self.section_indices.expand(seqlen, -1, -1)
def get_cos_sin(
self,
position_ids: torch.Tensor,
):
slen = position_ids.shape[0]
cos = self._cos_cached[position_ids].gather(1, self._sections[:slen])
sin = self._sin_cached[position_ids].gather(1, self._sections[:slen])
return cos, sin

View File

@ -2,8 +2,10 @@ import torch
from torch.nn import functional as F
from typing import Iterable, List
from text_generation_server.layers.linear import get_linear, FastLinear
from text_generation_server.utils.import_utils import SYSTEM
import habana_frameworks.torch as htorch
if SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex
class LayerConcat(torch.nn.Module):
@ -88,10 +90,14 @@ class TensorParallelHead(SuperLayer):
local_out = gather_input.T
torch.mm(input, self.linear.weight.T, out=local_out)
htorch.core.mark_step()
torch.distributed.all_gather_into_tensor(
world_out, gather_input, group=self.process_group
)
if SYSTEM == "ipex":
ipex.distributed.all_gather_into_tensor(
world_out, gather_input, group=self.process_group
)
else:
torch.distributed.all_gather_into_tensor(
world_out, gather_input, group=self.process_group
)
if input.shape[0] == 1:
return world_out
@ -101,9 +107,10 @@ class TensorParallelHead(SuperLayer):
world_output = [
torch.empty_like(output) for _ in range(self.process_group.size())
]
htorch.core.mark_step()
torch.distributed.all_gather(world_output, output, group=self.process_group)
if SYSTEM == "ipex":
ipex.distributed.all_gather(world_output, output, group=self.process_group)
else:
torch.distributed.all_gather(world_output, output, group=self.process_group)
world_output = torch.cat(world_output, dim=-1)
return world_output
@ -195,11 +202,10 @@ class TensorParallelRowLinear(SuperLayer):
def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
out = super().forward(input)
if self.process_group.size() > 1 and reduce:
# FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
# occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
# (which is required for tensor parallel HPUGraph inference)
htorch.core.mark_step()
torch.distributed.all_reduce(out, group=self.process_group)
if SYSTEM == "ipex":
ipex.distributed.all_reduce(out, group=self.process_group)
else:
torch.distributed.all_reduce(out, group=self.process_group)
return out
@ -236,9 +242,8 @@ class TensorParallelEmbedding(torch.nn.Module):
)
out = torch.nn.functional.embedding(input, self.weight)
if self.reduce and self.process_group.size() > 1:
# FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
# occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
# (which is required for tensor parallel HPUGraph inference)
htorch.core.mark_step()
torch.distributed.all_reduce(out, group=self.process_group)
if SYSTEM == "ipex":
ipex.distributed.all_reduce(out, group=self.process_group)
else:
torch.distributed.all_reduce(out, group=self.process_group)
return out

View File

@ -1,24 +1,32 @@
# ruff: noqa: F821
# the above line disables the `undefined-name` rule for the model type variables
import torch
import os
from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download, HfApi
from typing import Optional
from pathlib import Path
from typing import List, Dict
import enum
# Needed to properly setup habana_frameworks
from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
PhiMoEConfig,
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.bloom import BLOOM
from text_generation_server.models.starcoder import StarCoder
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
# from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
from text_generation_server.models.custom_modeling.llava_next import (
LlavaNextForConditionalGeneration,
)
# from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
# from text_generation_server.models.custom_modeling.mllama import (
# MllamaForConditionalGeneration,
# )
from text_generation_server.utils.adapter import (
AdapterParameters,
build_layer_weight_lookup,
@ -27,331 +35,10 @@ from text_generation_server.utils.adapter import (
)
from text_generation_server.adapters.lora import LoraWeights
from text_generation_server.utils.log import log_master
__all__ = [
"Model",
"CausalLM",
"Seq2SeqLM",
"get_model_with_lora_adapters",
]
VLM_BATCH_TYPES = set()
FLASH_ATTENTION = True
try:
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.flash_vlm_causal_lm import FlashVlmCausalLM
from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLM
from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
FlashDeepseekV2ForCausalLM,
DeepseekV2Config,
)
from text_generation_server.models.custom_modeling.flash_deepseek_v3_modeling import (
FlashDeepseekV3ForCausalLM,
DeepseekV3Config,
)
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_llama4_modeling import (
Llama4ForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
FlashCohereForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
FlashGemmaForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
FlashGemma2ForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_gemma3_modeling import (
Gemma3ForConditionalGeneration,
FlashGemma3ForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
FlashDbrxForCausalLM,
DbrxConfig,
)
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
RWConfig,
FlashRWForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
FlashGPTNeoXForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
PaliGemmaForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
FlashPhiForCausalLM,
)
from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
from text_generation_server.models.custom_modeling.flash_mllama import (
FlashMllamaForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.flash_llava_next import (
FlashLlavaNextForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
FlashSantacoderForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
FlashStarcoder2ForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
Qwen2ForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_qwen3_modeling import (
Qwen3ForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_qwen3_moe_modeling import (
Qwen3MoeForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
FlashMixtralForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
FlashGPT2ForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_gptj_modeling import (
FlashGPTJForCausalLM,
)
from text_generation_server.models.custom_modeling.idefics2 import (
Idefics2ForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.idefics3 import (
Idefics3ForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.qwen2_vl import (
Qwen2VLForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.qwen2_5_vl import (
Qwen2_5VLForConditionalGeneration,
Qwen2_5_VLConfig,
Qwen2_5_VLProcessor,
)
from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
SUPPORTS_WINDOWING = False
FLASH_ATTENTION = False
VLM_BATCH_TYPES = set()
if FLASH_ATTENTION:
__all__.append(FlashCausalLM)
from text_generation_server.models.flash_vlm_causal_lm import (
FlashVlmCausalLMBatch,
)
VLM_BATCH_TYPES = {
FlashVlmCausalLMBatch,
FlashMllamaCausalLMBatch,
}
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
__all__.append(VLM_BATCH_TYPES)
class ModelType(enum.Enum):
DEEPSEEK_V2 = {
"type": "deepseek_v2",
"name": "Deepseek V2",
"url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
}
DEEPSEEK_V3 = {
"type": "deepseek_v3",
"name": "Deepseek V3",
"url": "https://huggingface.co/deepseek-ai/DeepSeek-V3",
}
IDEFICS2 = {
"type": "idefics2",
"name": "Idefics 2",
"url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
"multimodal": True,
}
IDEFICS3 = {
"type": "idefics3",
"name": "Idefics 3",
"url": "https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3",
"multimodal": True,
}
LLAVA_NEXT = {
"type": "llava_next",
"name": "Llava Next (1.6)",
"url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
"multimodal": True,
}
LLAMA = {
"type": "llama",
"name": "Llama",
"url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
}
LLAMA4 = {
"type": "llama4",
"name": "Llama4",
"url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
}
PHI3 = {
"type": "phi3",
"name": "Phi 3",
"url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
}
GRANITE = {
"type": "granite",
"name": "Granite",
"url": "https://huggingface.co/ibm-granite/granite-3.0-8b-instruct",
}
GEMMA = {
"type": "gemma",
"name": "Gemma",
"url": "https://huggingface.co/google/gemma-7b",
}
PALIGEMMA = {
"type": "paligemma",
"name": "PaliGemma",
"url": "https://huggingface.co/google/paligemma-3b-pt-224",
}
GEMMA2 = {
"type": "gemma2",
"name": "Gemma2",
"url": "https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315",
}
GEMMA3 = {
"type": "gemma3",
"name": "Gemma3",
"url": "https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d",
}
GEMMA3_TEXT = {
"type": "gemma3_text",
"name": "Gemma3 Text",
"url": "https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d",
}
COHERE = {
"type": "cohere",
"name": "Cohere",
"url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
}
DBRX = {
"type": "dbrx",
"name": "Dbrx",
"url": "https://huggingface.co/databricks/dbrx-instruct",
}
MAMBA = {
"type": "mamba",
"name": "Mamba",
"url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
}
MISTRAL = {
"type": "mistral",
"name": "Mistral",
"url": "https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407",
}
MIXTRAL = {
"type": "mixtral",
"name": "Mixtral",
"url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
}
GPT_BIGCODE = {
"type": "gpt_bigcode",
"name": "Gpt Bigcode",
"url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
}
PHI = {
"type": "phi",
"name": "Phi",
"url": "https://huggingface.co/microsoft/phi-1_5",
}
PHI_MOE = {
"type": "phimoe",
"name": "PhiMoe",
"url": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct",
}
BAICHUAN = {
"type": "baichuan",
"name": "Baichuan",
"url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
}
FALCON = {
"type": "falcon",
"name": "Falcon",
"url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
}
STARCODER2 = {
"type": "starcoder2",
"name": "StarCoder 2",
"url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
}
QWEN2 = {
"type": "qwen2",
"name": "Qwen 2",
"url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
}
QWEN2_VL = {
"type": "qwen2_vl",
"name": "Qwen 2 VL",
"url": "https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d",
}
QWEN2_5_VL = {
"type": "qwen2_5_vl",
"name": "Qwen 2.5 VL",
"url": "https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e",
}
QWEN3 = {
"type": "qwen3",
"name": "Qwen 3",
"url": "https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f",
}
QWEN3_MOE = {
"type": "qwen3_moe",
"name": "Qwen 3 Moe",
"url": "https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f",
}
GALACTICA = {
"type": "galactica",
"name": "Galactica",
"url": "https://huggingface.co/facebook/galactica-120b",
}
SANTACODER = {
"type": "santacoder",
"name": "SantaCoder",
"url": "https://huggingface.co/bigcode/santacoder",
}
GPT2 = {
"type": "gpt2",
"name": "Gpt2",
"url": "https://huggingface.co/openai-community/gpt2",
}
GPT_NEOX = {
"type": "gpt_neox",
"name": "Gpt Neox",
"url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
}
GPTJ = {
"type": "gptj",
"name": "Gptj",
"url": "https://huggingface.co/EleutherAI/gpt-j-6b",
}
MLLAMA = {
"type": "mllama",
"name": "Mllama",
"url": "https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct",
"multimodal": True,
}
__GLOBALS = locals()
for data in ModelType:
__GLOBALS[data.name] = data.value["type"]
SDP_ON_BF16 = int(os.environ.get("SDP_ON_BF16", 0))
# Disable gradients
torch.set_grad_enabled(False)
@ -364,11 +51,10 @@ def get_model(
quantize: Optional[str],
speculate: Optional[int],
dtype: Optional[torch.dtype],
kv_cache_dtype: Optional[str],
trust_remote_code: bool,
max_input_tokens: int,
) -> Model:
global FLASH_ATTENTION
adapt_transformers_to_gaudi()
if speculate is not None:
set_speculate(speculate)
@ -490,456 +176,38 @@ def get_model(
model_type = config_dict["model_type"]
if kv_cache_dtype == "fp8_e4m3fn":
kv_cache_dtype = torch.float8_e4m3fn
elif kv_cache_dtype == "fp8_e5m2":
kv_cache_dtype = torch.float8_e5m2
else:
kv_cache_dtype = dtype
if model_type == "gpt_bigcode":
return StarCoder(model_id=model_id, revision=revision, dtype=dtype)
if FLASH_ATTENTION:
if model_type == DEEPSEEK_V2:
head_size = max(
config_dict.get("qk_nope_dim", 128)
+ config_dict.get("qk_rope_dim", 64),
config_dict.get("v_head_dim", 128),
)
return FlashCausalLM(
model_id=model_id,
model_class=FlashDeepseekV2ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
default_dtype=torch.bfloat16,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=DeepseekV2Config,
head_size=head_size,
)
elif model_type == DEEPSEEK_V3:
head_size = max(
config_dict.get("qk_nope_dim", 128)
+ config_dict.get("qk_rope_dim", 64),
config_dict.get("v_head_dim", 128),
)
return FlashCausalLM(
model_id=model_id,
model_class=FlashDeepseekV3ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
default_dtype=torch.bfloat16,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=DeepseekV3Config,
head_size=head_size,
)
if model_type == "bloom":
return BLOOM(
model_id=model_id,
revision=revision,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif (
model_type == GPT_BIGCODE
or model_type == GPT2
and model_id.startswith("bigcode/")
):
return FlashCausalLM(
model_id=model_id,
model_class=FlashSantacoderForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
aliases={"transformer.wte.weight": ["lm_head.weight"]},
num_kv_heads=1,
)
elif model_type == GPT2:
return FlashCausalLM(
model_id=model_id,
model_class=FlashGPT2ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == GPTJ:
return FlashCausalLM(
model_id=model_id,
model_class=FlashGPTJForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == GPT_NEOX:
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
GPTNeoXConfig,
)
if model_type == "llava_next":
return VlmCausalLM(
model_class=LlavaNextForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=None,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
return FlashCausalLM(
model_id=model_id,
model_class=FlashGPTNeoXForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=GPTNeoXConfig,
)
elif model_type == PHI:
return FlashCausalLM(
model_id=model_id,
model_class=FlashPhiForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == PHI_MOE:
return FlashCausalLM(
model_id=model_id,
model_class=FlashLlamaForCausalLM,
config_class=PhiMoEConfig,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == LLAMA or model_type == PHI3 or model_type == GRANITE:
return FlashCausalLM(
model_id=model_id,
model_class=FlashLlamaForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == LLAMA4:
print(f"Llama4 model detected: {model_id}")
return FlashVlmCausalLM(
model_id=model_id,
model_class=Llama4ForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
support_chunking=False,
)
elif model_type == BAICHUAN:
return FlashCausalLM(
model_id=model_id,
model_class=FlashLlamaForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == GEMMA:
return FlashCausalLM(
model_id=model_id,
model_class=FlashGemmaForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == GEMMA2:
return FlashCausalLM(
model_id=model_id,
model_class=FlashGemma2ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == GEMMA3:
return FlashVlmCausalLM(
model_id=model_id,
model_class=Gemma3ForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
support_chunking=False,
)
elif model_type == GEMMA3_TEXT:
return FlashCausalLM(
model_id=model_id,
model_class=FlashGemma3ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == COHERE:
return FlashCausalLM(
model_id=model_id,
model_class=FlashCohereForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == DBRX:
return FlashCausalLM(
model_id=model_id,
model_class=FlashDbrxForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# Dbrx works better in bfloat16.
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=DbrxConfig,
)
elif (
model_type in ["RefinedWeb", "RefinedWebModel", FALCON]
and not sharded
and not config_dict.get("alibi", False)
):
return FlashCausalLM(
model_id=model_id,
model_class=FlashRWForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
aliases={
"lm_head.weight": ["transformer.word_embeddings.weight"],
"transformer.word_embeddings.weight": ["lm_head.weight"],
},
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=RWConfig,
)
elif model_type == MISTRAL:
return FlashCausalLM(
model_id=model_id,
model_class=FlashMistralForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == MIXTRAL:
return FlashCausalLM(
model_id=model_id,
model_class=FlashMixtralForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == STARCODER2:
return FlashCausalLM(
model_id=model_id,
model_class=FlashStarcoder2ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == QWEN2:
return FlashCausalLM(
model_id=model_id,
model_class=Qwen2ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == QWEN2_VL:
return FlashVlmCausalLM(
model_id=model_id,
model_class=Qwen2VLForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
default_dtype=torch.bfloat16,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
# TODO: Fix bug in rust image_text_replacement implementation
support_chunking=False,
)
elif model_type == QWEN2_5_VL:
return FlashVlmCausalLM(
model_id=model_id,
model_class=Qwen2_5VLForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
default_dtype=torch.bfloat16,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=Qwen2_5_VLConfig,
processor_class=Qwen2_5_VLProcessor,
# TODO: Fix bug in rust image_text_replacement implementation
support_chunking=False,
)
elif model_type == QWEN3:
return FlashCausalLM(
model_id=model_id,
model_class=Qwen3ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == QWEN3_MOE:
return FlashCausalLM(
model_id=model_id,
model_class=Qwen3MoeForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == MLLAMA:
return FlashMllamaCausalLM(
model_id=model_id,
model_class=FlashMllamaForConditionalGeneration,
batch_class=FlashMllamaCausalLMBatch,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
support_chunking=False,
)
elif model_type == IDEFICS2:
return FlashVlmCausalLM(
model_id=model_id,
model_class=Idefics2ForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
# XXX: Extremely important to cap resolution in order to limit
# VRAM usage.
processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
)
elif model_type == IDEFICS3:
return FlashVlmCausalLM(
model_id=model_id,
model_class=Idefics3ForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
# XXX: Extremely important to cap resolution in order to limit
# VRAM usage.
processor_kwargs={"size": {"longest_edge": 1456}},
)
elif model_type == PALIGEMMA:
return FlashVlmCausalLM(
model_id=model_id,
model_class=PaliGemmaForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == LLAVA_NEXT:
return FlashVlmCausalLM(
model_class=FlashLlavaNextForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
)
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
raise ValueError(f"Unsupported model type {model_type}")
@ -954,7 +222,6 @@ def get_model_with_lora_adapters(
quantize: Optional[str],
speculate: Optional[int],
dtype: Optional[torch.dtype],
kv_cache_dtype: Optional[str],
trust_remote_code: bool,
max_input_tokens: int,
adapter_to_index: Dict[str, int],
@ -968,7 +235,6 @@ def get_model_with_lora_adapters(
quantize,
speculate,
dtype,
kv_cache_dtype,
trust_remote_code,
max_input_tokens,
)

View File

@ -0,0 +1,52 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
import torch
from typing import Optional, Type
from transformers import PreTrainedTokenizerBase
from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
class BloomCausalLMBatch(CausalLMBatch):
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "CausalLMBatch":
batch = super().from_pb(
pb=pb,
tokenizer=tokenizer,
dtype=dtype,
device=device,
)
batch.keys_head_dim_last = False
return batch
class BLOOM(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
super(BLOOM, self).__init__(
model_id=model_id,
revision=revision,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@property
def batch_type(self) -> Type[CausalLMBatch]:
return BloomCausalLMBatch

File diff suppressed because it is too large Load Diff

View File

@ -377,7 +377,7 @@ class BloomAttention(nn.Module):
past_value.view(-1, *past_value.shape[-2:]),
)
if CUSTOM_KERNELS_ENABLED and attention_mask.shape[-1] < 4096:
if CUSTOM_KERNELS_ENABLED:
assert self.training is False, "Only foward pass was implemented"
assert (
attention_mask.shape[-1] < 4096
@ -580,7 +580,7 @@ class BloomPreTrainedModel(PreTrainedModel):
@staticmethod
def _convert_to_bloom_cache(
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]],
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
"""
Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))

View File

@ -28,11 +28,10 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
reshape_and_cache,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@ -40,6 +39,7 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
@ -47,12 +47,11 @@ from text_generation_server.layers.rotary import (
PositionRotaryEmbedding,
)
from text_generation_server.utils.weights import UnquantizedWeight
from habana_frameworks.torch.hpex.kernels import (
RotaryPosEmbeddingMode,
apply_rotary_pos_emb,
)
import habana_frameworks.torch as htorch
if SYSTEM == "cuda":
import dropout_layer_norm
else:
dropout_layer_norm = None
class CohereRotary(PositionRotaryEmbedding):
@ -64,25 +63,38 @@ class CohereRotary(PositionRotaryEmbedding):
sin: torch.Tensor,
):
# Such controlflows may add some overhead.
num_tokens = query.shape[0]
head_size = query.shape[-1]
rope_mode = RotaryPosEmbeddingMode.PAIRWISE
sin = torch.repeat_interleave(sin, 2, dim=-1)
cos = torch.repeat_interleave(cos, 2, dim=-1)
rotary_dim = cos.shape[-1]
query_shape = query.shape
query = query.view(num_tokens, -1, head_size)
query_rot = query[..., :rotary_dim]
query_pass = query[..., rotary_dim:]
query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape))
if SYSTEM == "cuda":
import rotary_emb
key_shape = key.shape
key = key.view(num_tokens, -1, head_size)
key_rot = key[..., :rotary_dim]
key_pass = key[..., rotary_dim:]
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape))
q1 = query[..., ::2]
q2 = query[..., 1::2]
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
k1 = key[..., ::2]
k2 = key[..., 1::2]
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
elif SYSTEM == "rocm":
from vllm._C import ops
# NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
# Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773
head_size = query.shape[-1]
# Inplace operation, updating query and key.
ops.rotary_embedding(query, key, head_size, cos, sin, False)
elif SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex
ipex.llm.functional.rotary_embedding(
query, key, sin, cos, query.size(-1), False
)
else:
raise ValueError(
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
)
class CohereLayerNorm(nn.Module):
@ -95,18 +107,49 @@ class CohereLayerNorm(nn.Module):
self.eps = eps
def forward(self, hidden_states):
hidden_states = hidden_states.reshape(
if hidden_states.shape[-1] > 8192 or SYSTEM != "cuda":
hidden_states = hidden_states.reshape(
-1, self.weight.shape[0], self.weight.shape[1]
)
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
mean = hidden_states.mean(-1, keepdim=True)
hidden_states_minus_mean = hidden_states - mean
variance = hidden_states_minus_mean.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states_minus_mean * torch.rsqrt(variance + self.eps)
hidden_states = self.weight.to(torch.float32) * hidden_states
hidden_states = hidden_states.view(-1, self.weight.shape[1])
return hidden_states.to(input_dtype)
(
hidden_states,
*rest,
) = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
None,
self.ones,
None,
None,
None,
None,
None,
0.0,
self.eps,
1.0,
0,
None,
False,
False,
)
# Required to apply one weight matrix per head
hidden_states = hidden_states.view(
-1, self.weight.shape[0], self.weight.shape[1]
)
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
mean = hidden_states.mean(-1, keepdim=True)
hidden_states_minus_mean = hidden_states - mean
variance = hidden_states_minus_mean.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states_minus_mean * torch.rsqrt(variance + self.eps)
hidden_states = self.weight.to(torch.float32) * hidden_states
hidden_states = self.weight * hidden_states
hidden_states = hidden_states.view(-1, self.weight.shape[1])
return hidden_states.to(input_dtype)
return hidden_states
def load_attention(config, prefix, weights):
@ -160,14 +203,18 @@ class FlashCohereAttention(torch.nn.Module):
prefix: str,
config,
weights,
rotary_emb,
):
super().__init__()
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
self.rotary_emb = rotary_emb
self.rotary_emb = CohereRotary.static(
config=config,
dim=self.head_size,
base=config.rope_theta,
device=weights.device,
)
self.softmax_scale = self.head_size**-0.5
@ -182,7 +229,6 @@ class FlashCohereAttention(torch.nn.Module):
)
self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.use_qk_norm = config.use_qk_norm
if self.use_qk_norm:
@ -218,9 +264,10 @@ class FlashCohereAttention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
qkv = self.query_key_value(hidden_states)
query, key, value = qkv.split(
@ -244,35 +291,30 @@ class FlashCohereAttention(torch.nn.Module):
self.rotary_emb(query, key, cos, sin)
kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
# Prefill
if cu_seqlen_prefill is not None:
# sdpa
# flash attention
attn_output = attention(
query=query,
key=key,
value=value,
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else key,
kv_cache[1] if PREFILL_IN_KV_CACHE else value,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
max_s,
)
return self.o_proj(
@ -321,14 +363,11 @@ class CohereMLP(nn.Module):
class FlashCohereLayer(nn.Module):
def __init__(self, prefix: str, layer_id, config, weights, rotary_emb):
def __init__(self, prefix: str, layer_id, config, weights):
super().__init__()
prefix = f"{prefix}.layers.{layer_id}"
self.self_attn = FlashCohereAttention(
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
rotary_emb=rotary_emb,
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
self.mlp = CohereMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
@ -347,9 +386,10 @@ class FlashCohereLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -360,9 +400,10 @@ class FlashCohereLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
mlp_output = self.mlp(normed_hidden_states)
@ -384,12 +425,6 @@ class FlashCohereModel(torch.nn.Module):
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens", weights=weights
)
rotary_emb = CohereRotary.static(
config=config,
dim=config.hidden_size // config.num_attention_heads,
base=config.rope_theta,
device=weights.device,
)
self.layers = nn.ModuleList(
[
FlashCohereLayer(
@ -397,7 +432,6 @@ class FlashCohereModel(torch.nn.Module):
layer_id,
config,
weights,
rotary_emb,
)
for layer_id in range(config.num_hidden_layers)
]
@ -418,24 +452,21 @@ class FlashCohereModel(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: torch.Tensor,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
@ -444,12 +475,11 @@ class FlashCohereModel(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual)
@ -486,9 +516,11 @@ class FlashCohereForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -497,9 +529,10 @@ class FlashCohereForCausalLM(torch.nn.Module):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -20,15 +20,17 @@ from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple, Any
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM != "ipex":
from vllm.model_executor.layers.fused_moe import fused_moe
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
reshape_and_cache,
Seqlen,
HPUPagedAttentionMetadata,
PREFILL_IN_KV_CACHE,
)
from text_generation_server.layers import (
FastLinear,
@ -44,8 +46,6 @@ from text_generation_server.layers.rotary import (
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
from vllm_hpu_extension.ops import DynamicFusedMOE
import habana_frameworks.torch as htorch
class DbrxAttentionConfig(PretrainedConfig):
@ -263,7 +263,6 @@ class DbrxAttention(torch.nn.Module):
prefix: str,
config,
weights,
rotary_emb,
):
super().__init__()
self.clip_qkv = config.attn_config.clip_qkv
@ -271,7 +270,12 @@ class DbrxAttention(torch.nn.Module):
self.hidden_size = config.d_model
self.head_size = self.hidden_size // self.num_heads
self.rotary_emb = rotary_emb
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
base=config.attn_config.rope_theta,
device=weights.device,
)
self.softmax_scale = self.head_size**-0.5
@ -286,7 +290,6 @@ class DbrxAttention(torch.nn.Module):
)
self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load(
config,
@ -306,9 +309,10 @@ class DbrxAttention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
qkv = self.query_key_value(hidden_states)
if self.clip_qkv is not None:
@ -326,35 +330,30 @@ class DbrxAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# Prefill
if cu_seqlen_prefill is not None:
# sdpa
# flash attention
attn_output = attention(
query=query,
key=kv[:, 0],
value=kv[:, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
max_s,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -366,17 +365,13 @@ class DbrxNormAttentionNorm(nn.Module):
prefix: str,
config,
weights,
rotary_emb,
):
super().__init__()
self.norm_1 = FastLayerNorm.load_no_bias(
prefix=f"{prefix}.norm_1", weights=weights, eps=1e-5
)
self.self_attn = DbrxAttention(
prefix=f"{prefix}.attn",
config=config,
weights=weights,
rotary_emb=rotary_emb,
prefix=f"{prefix}.attn", config=config, weights=weights
)
self.norm_2 = FastLayerNorm.load_no_bias(
prefix=f"{prefix}.norm_2",
@ -392,9 +387,10 @@ class DbrxNormAttentionNorm(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
normed_hidden_states, res = self.norm_1(hidden_states, residual)
@ -405,9 +401,10 @@ class DbrxNormAttentionNorm(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
# faster post attention rms norm
@ -485,15 +482,18 @@ class BlockSparseMoE(nn.Module):
self.process_group = weights.process_group
self.hpu_fused_moe = DynamicFusedMOE(self.num_experts)
for i in range(self.num_experts):
self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.wv1[i])
self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.w2[i])
def forward(self, x: torch.Tensor) -> torch.Tensor:
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(x)
out = self.hpu_fused_moe(x, router_logits, self.top_k)
out = fused_moe(
x,
self.wv1,
self.w2,
router_logits,
self.top_k,
renormalize=self.moe_normalize_expert_weights,
inplace=True,
)
# Reduce sum
if self.process_group.size() > 1:
@ -601,15 +601,12 @@ class DenseMoE(nn.Module):
class DbrxLayer(nn.Module):
def __init__(self, prefix: str, layer_id, config, weights, rotary_emb):
def __init__(self, prefix: str, layer_id, config, weights):
super().__init__()
prefix = f"{prefix}.blocks.{layer_id}"
self.attn = DbrxNormAttentionNorm(
prefix=f"{prefix}.norm_attn_norm",
config=config,
weights=weights,
rotary_emb=rotary_emb,
prefix=f"{prefix}.norm_attn_norm", config=config, weights=weights
)
moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE
@ -623,9 +620,10 @@ class DbrxLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
# Self Attention
attn_output, attn_res = self.attn(
@ -635,9 +633,10 @@ class DbrxLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
moe_output = self.moe(attn_output)
@ -652,12 +651,6 @@ class DbrxModel(torch.nn.Module):
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.wte", weights=weights
)
rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=config.d_model // config.n_heads,
base=config.attn_config.rope_theta,
device=weights.device,
)
self.layers = nn.ModuleList(
[
@ -666,7 +659,6 @@ class DbrxModel(torch.nn.Module):
layer_id,
config,
weights,
rotary_emb,
)
for layer_id in range(config.n_layers)
]
@ -685,23 +677,20 @@ class DbrxModel(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(position_ids)
cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
@ -710,12 +699,11 @@ class DbrxModel(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual)
@ -744,9 +732,11 @@ class FlashDbrxForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -755,9 +745,10 @@ class FlashDbrxForCausalLM(torch.nn.Module):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -33,15 +33,20 @@ from text_generation_server.layers.attention import (
Seqlen,
attention,
paged_attention,
set_block_mapping,
HPUPagedAttentionMetadata,
reshape_and_cache,
)
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import Weights
import habana_frameworks.torch as htorch
if SYSTEM == "rocm":
try:
from vllm import _custom_C
except Exception as e:
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
class DeepseekV2Config(PretrainedConfig):
@ -156,7 +161,6 @@ class DeepseekV2Attention(torch.nn.Module):
prefix: str,
config,
weights: Weights,
rotary_emb,
):
super().__init__()
self.num_heads = config.num_attention_heads
@ -168,7 +172,13 @@ class DeepseekV2Attention(torch.nn.Module):
self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim
self.value_head_size = config.v_head_dim
self.head_pad_size = max(self.head_size, self.value_head_size)
self.rotary_emb = rotary_emb
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.qk_rope_head_dim,
base=config.rope_theta,
device=weights.device,
)
mscale = get_mscale(
self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim
@ -222,8 +232,6 @@ class DeepseekV2Attention(torch.nn.Module):
),
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.kv_a_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps
)
@ -252,10 +260,11 @@ class DeepseekV2Attention(torch.nn.Module):
cos: torch.Tensor,
sin: torch.Tensor,
cu_seqlen_prefill: torch.Tensor,
kv_cache: KVCache,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
):
if self.q_lora_rank is None:
query = self.q_proj(hidden_states)
@ -312,35 +321,30 @@ class DeepseekV2Attention(torch.nn.Module):
value, (0, self.head_pad_size - self.value_head_size), value=0
)
kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query=query,
key=key,
value=value,
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else key,
kv_cache[1] if PREFILL_IN_KV_CACHE else value,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
max_s,
)
# Remove padding.
@ -383,11 +387,27 @@ class DeepseekV2MLP(nn.Module):
self.quantize = config.quantize
def forward(self, hidden_states: torch.Tensor, reduce: bool = True):
gate_up_states = self.gate_up_proj(hidden_states)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce
)
if (
SYSTEM == "rocm"
and self.hidden_act == "silu"
and hidden_states.dtype == torch.float16
and hidden_states.shape[0] == 1
and not self.quantize
):
out = torch.empty(
hidden_states.shape[0],
self.intermediate_size,
dtype=hidden_states.dtype,
device="cuda",
)
_custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
return self.down_proj(out, reduce=reduce)
else:
gate_up_states = self.gate_up_proj(hidden_states)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce
)
class DeepseekV2MoE(nn.Module):
@ -454,7 +474,7 @@ class DeepseekV2MoE(nn.Module):
class DeepseekV2Layer(nn.Module):
def __init__(self, prefix, layer_id, config, weights, rotary_emb):
def __init__(self, prefix, layer_id, config, weights):
super().__init__()
prefix = f"{prefix}.layers.{layer_id}"
@ -462,7 +482,6 @@ class DeepseekV2Layer(nn.Module):
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
rotary_emb=rotary_emb,
)
if (
@ -501,9 +520,10 @@ class DeepseekV2Layer(nn.Module):
sin: torch.Tensor,
cu_seqlen_prefill: torch.Tensor,
kv_cache,
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
):
normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)
@ -514,9 +534,10 @@ class DeepseekV2Layer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
# faster post attention rms norm
@ -537,12 +558,6 @@ class DeepseekV2Model(torch.nn.Module):
prefix=f"{prefix}.embed_tokens", weights=weights
)
rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=config.qk_rope_head_dim,
base=config.rope_theta,
device=weights.device,
)
self.layers = nn.ModuleList(
[
DeepseekV2Layer(
@ -550,7 +565,6 @@ class DeepseekV2Model(torch.nn.Module):
layer_id,
config,
weights,
rotary_emb,
)
for layer_id in range(config.num_hidden_layers)
]
@ -569,24 +583,20 @@ class DeepseekV2Model(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
@ -595,12 +605,11 @@ class DeepseekV2Model(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual)
@ -626,9 +635,11 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -637,9 +648,10 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -1,723 +0,0 @@
# coding=utf-8
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Type
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from text_generation_server.layers import (
FastLinear,
SpeculativeHead,
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
get_linear,
Fp8Linear,
)
from text_generation_server.layers.attention import (
Seqlen,
attention,
paged_attention_mla,
set_block_mapping,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
from text_generation_server.utils.weights import Weights
import habana_frameworks.torch as htorch
def get_and_maybe_dequant_weights(layer: torch.nn.Module) -> torch.Tensor:
if isinstance(layer, Fp8Linear):
eye = torch.eye(
layer.qweight.shape[-1], dtype=torch.bfloat16, device=layer.qweight.device
)
dequant_weights = layer(eye)
del eye
# standardize to (output, input)
return dequant_weights.T
return layer.weight
class DeepseekV3Config(PretrainedConfig):
def __init__(
self,
vocab_size=102400,
hidden_size=4096,
intermediate_size=11008,
moe_intermediate_size=1407,
num_hidden_layers=30,
num_attention_heads=32,
num_key_value_heads=32,
n_shared_experts=2,
n_routed_experts=160,
ep_size=1,
routed_scaling_factor=1.0,
kv_lora_rank=512,
q_lora_rank=1536,
qk_rope_head_dim=64,
v_head_dim=128,
qk_nope_head_dim=128,
topk_method="gready",
n_group=8,
topk_group=3,
num_experts_per_tok=6,
moe_layer_freq=1,
first_k_dense_replace=0,
norm_topk_prob=False,
scoring_func="softmax",
aux_loss_alpha=0.001,
seq_aux=True,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=100000,
eos_token_id=100001,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.moe_intermediate_size = moe_intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.n_shared_experts = n_shared_experts
self.n_routed_experts = n_routed_experts
self.ep_size = ep_size
self.routed_scaling_factor = routed_scaling_factor
self.kv_lora_rank = kv_lora_rank
self.q_lora_rank = q_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.qk_nope_head_dim = qk_nope_head_dim
self.topk_method = topk_method
self.n_group = n_group
self.topk_group = topk_group
self.num_experts_per_tok = num_experts_per_tok
self.moe_layer_freq = moe_layer_freq
self.first_k_dense_replace = first_k_dense_replace
self.norm_topk_prob = norm_topk_prob
self.scoring_func = scoring_func
self.aux_loss_alpha = aux_loss_alpha
self.seq_aux = seq_aux
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
if tie_word_embeddings:
raise ValueError(
"tie_word_embeddings is not supported for Deepseek V2 models."
)
if ep_size != 1:
raise ValueError(
f"Currently only ep_size == 1 is supported for Deepseek V2 models, was {ep_size}"
)
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
class DeepseekV3Attention(torch.nn.Module):
def __init__(
self,
prefix: str,
config,
weights: Weights,
rotary_emb,
):
super().__init__()
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.kv_lora_rank = config.kv_lora_rank
self.q_lora_rank = config.q_lora_rank
self.qk_nope_head_dim = config.qk_nope_head_dim
self.qk_rope_head_dim = config.qk_rope_head_dim
self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim
self.value_head_size = config.v_head_dim
self.head_pad_size = max(self.head_size, self.value_head_size)
self.rotary_emb = rotary_emb
mscale = get_mscale(
self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim
)
self.softmax_scale = self.head_size**-0.5 * mscale * mscale
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
config.num_key_value_heads // weights.process_group.size()
)
if self.q_lora_rank is None:
self.q_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.q_proj",
weights=weights,
bias=config.attention_bias,
)
else:
self.q_a_proj = get_linear(
weight=weights.get_weights(f"{prefix}.q_a_proj"),
bias=(
weights.get_tensor(f"{prefix}.q_a_proj.bias")
if config.attention_bias
else None
),
)
self.q_a_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.q_a_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
self.q_b_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.q_b_proj",
weights=weights,
bias=config.attention_bias,
)
self.kv_a_proj_with_mqa = get_linear(
weight=weights.get_weights(f"{prefix}.kv_a_proj_with_mqa"),
bias=(
weights.get_tensor(f"{prefix}.kv_a_proj_with_mqa.bias")
if config.attention_bias
else None
),
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.kv_a_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.kv_b_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.kv_b_proj",
weights=weights,
bias=config.attention_bias,
)
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj.linear).T
kv_b_proj_weight = kv_b_proj_weight.view(
self.kv_lora_rank,
self.num_heads,
self.qk_nope_head_dim + self.value_head_size,
)
W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.value_head_size], dim=-1
)
# Convert from (L, N, V) to (N, L, V)
self.W_UV = W_UV.transpose(0, 1)
# Convert from (L, N, P) to (N, P, L)
self.W_UK_T = W_UK.permute(1, 2, 0)
def _q_proj_and_k_up_proj(self, x):
q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
q_nope, q_pe = (
q_proj(x)
.view(-1, self.num_heads, self.head_size)
.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
)
# Convert from (B, N, P) to (N, B, P)
q_nope = q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope = torch.bmm(q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
return ql_nope.transpose(0, 1), q_pe
def _v_up_proj_and_o_proj(self, x):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x = torch.bmm(x, self.W_UV)
# Convert from (N, B, V) to (B, N * V)
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.value_head_size)
return self.o_proj(x)
def forward(
self,
hidden_states: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
cu_seqlen_prefill: torch.Tensor,
kv_cache: KVCache,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
):
if self.q_lora_rank is None:
hidden_states_or_q_c = hidden_states
else:
hidden_states_or_q_c = self.q_a_layernorm(self.q_a_proj(hidden_states))[0]
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, key_pe = torch.split(
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
kv_c_normed = self.kv_a_layernorm(compressed_kv.contiguous())[0]
# Prefill
if cu_seqlen_prefill is not None:
q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
query = q_proj(hidden_states_or_q_c)
query = query.view(-1, self.num_heads, self.head_size)
query_nope, query_pe = torch.split(
query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
else:
query_nope, query_pe = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
batch_size, heads, head_dim = query_pe.shape
query_pe = (
query_pe.view(batch_size, heads, head_dim // 2, 2)
.transpose(2, 3)
.reshape(batch_size, heads, head_dim)
)
batch_size, heads, head_dim = key_pe.shape
key_pe = (
key_pe.view(batch_size, heads, head_dim // 2, 2)
.transpose(2, 3)
.reshape(batch_size, heads, head_dim)
)
self.rotary_emb(query_pe, key_pe, cos, sin)
latent_vec_k = torch.concat(
(kv_c_normed, key_pe.view(-1, self.qk_rope_head_dim)), dim=-1
)
latent_vec_k = latent_vec_k.view(-1, self.qk_rope_head_dim + self.kv_lora_rank)
latent_vec_k = latent_vec_k.unflatten(0, (slots.size(0), -1))
kv_cache.store(
key=latent_vec_k,
value=None,
slots=slots,
kv_scales=self.kv_scales,
)
if cu_seqlen_prefill is not None:
kv = self.kv_b_proj(kv_c_normed).view(
-1,
self.num_key_value_heads,
self.qk_nope_head_dim + self.value_head_size,
)
key_nope, value = torch.split(
kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
)
query[..., self.qk_nope_head_dim :] = query_pe
key = torch.empty_like(query)
key[..., : self.qk_nope_head_dim] = key_nope
key[..., self.qk_nope_head_dim :] = key_pe
# We need to pad the heads because Flash Attention does not support
# qk and v with different head sizes.
query = torch.nn.functional.pad(
query, (0, self.head_pad_size - self.head_size), value=0
)
key = torch.nn.functional.pad(
key, (0, self.head_pad_size - self.head_size), value=0
)
value = torch.nn.functional.pad(
value, (0, self.head_pad_size - self.value_head_size), value=0
)
# flash attention
attn_output = attention(
query=query,
key=key,
value=value,
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
)
attn_output = attn_output[..., : self.value_head_size]
return self.o_proj(
attn_output.reshape(-1, self.num_heads * self.value_head_size)
)
else:
# Decode
query = torch.cat([query_nope, query_pe], dim=-1)
attn_output = paged_attention_mla(
query,
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
kv_lora_rank=self.kv_lora_rank,
)
attn_output = self._v_up_proj_and_o_proj(attn_output)
return attn_output
class DeepseekV3MLP(nn.Module):
def __init__(self, prefix: str, config, weights, intermediate_size: int):
super().__init__()
self.hidden_act = config.hidden_act
if self.hidden_act != "silu":
# Bail out because MoE only supports silu.
raise NotImplementedError(
"Currently only `silu` is supported as an activation for Deepseek V2."
)
self.act = ACT2FN[self.hidden_act]
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
self.down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
)
self.intermediate_size = intermediate_size // weights.process_group.size()
# TODO: This is a hotfix to be removed & properly refactored.
self.quantize = config.quantize
def forward(self, hidden_states: torch.Tensor, reduce: bool = True):
gate_up_states = self.gate_up_proj(hidden_states)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce
)
class DeepseekV3MoE(nn.Module):
def __init__(
self,
prefix,
config: DeepseekV3Config,
moe_layer_cls: Type[MoELayer],
weights,
):
super().__init__()
self.hidden_dim = config.hidden_size
self.moe_intermediate_size = (
config.moe_intermediate_size // weights.process_group.size()
)
self.routed_scaling_factor = config.routed_scaling_factor
# Gating
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
if config.topk_method == "noaux_tc":
self.gate.e_score_correction_bias = torch.zeros(
config.n_routed_experts, device=weights.device
)
else:
self.gate.e_score_correction_bias = None
self.moe_layer = moe_layer_cls(
prefix=f"{prefix}.experts",
n_experts=config.n_routed_experts,
n_expert_group=config.n_group,
renormalize=config.norm_topk_prob,
topk=config.num_experts_per_tok,
topk_group=config.topk_group,
weights=weights,
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
)
assert isinstance(self.moe_layer, MoELayer)
if config.n_shared_experts is not None:
self.shared_experts = DeepseekV3MLP(
prefix=f"{prefix}.shared_experts",
config=config,
weights=weights,
intermediate_size=config.moe_intermediate_size
* config.n_shared_experts,
)
else:
self.shared_experts = None
self.process_group = weights.process_group
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.shared_experts is not None:
shared_output = self.shared_experts(x, reduce=False)
else:
shared_output = None
router_logits = self.gate(x)
out = self.moe_layer(x, gating_output=router_logits)
if shared_output is not None:
out = out + shared_output
# Reduce sum
if self.process_group.size() > 1:
torch.distributed.all_reduce(out, group=self.process_group)
return out.view(*x.shape)
class DeepseekV3Layer(nn.Module):
def __init__(self, prefix, layer_id, config, weights, rotary_emb):
super().__init__()
prefix = f"{prefix}.layers.{layer_id}"
self.self_attn = DeepseekV3Attention(
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
rotary_emb=rotary_emb,
)
if (
config.n_routed_experts is not None
and layer_id >= config.first_k_dense_replace
and layer_id % config.moe_layer_freq == 0
):
moe_layer_cls = (
SparseMoELayer
if SparseMoELayer.is_supported(weights)
else DenseMoELayer
)
self.mlp = DeepseekV3MoE(f"{prefix}.mlp", config, moe_layer_cls, weights)
else:
self.mlp = DeepseekV3MLP(
prefix=f"{prefix}.mlp",
config=config,
weights=weights,
intermediate_size=config.intermediate_size,
)
self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.post_attention_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
def forward(
self,
hidden_states: torch.Tensor,
residual: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
cu_seqlen_prefill: torch.Tensor,
kv_cache,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
):
normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)
# Self Attention
attn_output = self.self_attn(
normed_hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
slots,
seqlen,
hpu_attention_meta,
)
# faster post attention rms norm
normed_attn_res_output, residual = self.post_attention_layernorm(
attn_output, residual
)
output = self.mlp(normed_attn_res_output)
return output, residual
class DeepseekV3Model(torch.nn.Module):
def __init__(self, prefix: str, config, weights: Weights):
super().__init__()
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens", weights=weights
)
rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=config.qk_rope_head_dim,
base=config.rope_theta,
device=weights.device,
)
self.layers = nn.ModuleList(
[
DeepseekV3Layer(
prefix,
layer_id,
config,
weights,
rotary_emb,
)
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = FastRMSNorm.load(
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
)
self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache[i],
slots,
seqlen,
hpu_attention_meta,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class FlashDeepseekV3ForCausalLM(torch.nn.Module):
def __init__(self, prefix: str, config, weights: Weights):
super().__init__()
self.model = DeepseekV3Model(
"model" if not prefix else f"{prefix}.model", config, weights
)
self.lm_head = SpeculativeHead.load(
config,
prefix="lm_head" if not prefix else f"{prefix}.lm_head",
weights=weights,
)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden_states = self.model(
input_ids,
position_ids,
cu_seqlen_prefill,
kv_cache,
slots,
seqlen,
hpu_attention_meta,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states)
return logits, speculative_logits

View File

@ -28,9 +28,8 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
reshape_and_cache,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -41,13 +40,12 @@ from text_generation_server.layers import (
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from text_generation_server.utils.weights import UnquantizedWeight
import habana_frameworks.torch as htorch
class Gemma2Config(PretrainedConfig):
@ -166,14 +164,7 @@ def _load_gqa(config, prefix: str, weights):
class FlashGemma2Attention(torch.nn.Module):
def __init__(
self,
prefix: str,
config,
weights,
layer_id,
causal: bool,
is_sliding: bool,
rotary_emb,
self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
):
super().__init__()
self.num_heads = config.num_attention_heads
@ -183,7 +174,13 @@ class FlashGemma2Attention(torch.nn.Module):
self.window_size = config.sliding_window
else:
self.window_size = -1
self.rotary_emb = rotary_emb
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
base=config.rope_theta,
device=weights.device,
)
# self.softmax_scale = self.head_size**-0.5
self.softmax_scale = config.query_pre_attn_scalar**-0.5
@ -211,7 +208,6 @@ class FlashGemma2Attention(torch.nn.Module):
],
process_group=weights.process_group,
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
o_proj = TensorParallelRowLinear.load(
config,
@ -238,10 +234,11 @@ class FlashGemma2Attention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
adapter_data,
hpu_attention_meta,
):
qkv = self.query_key_value(hidden_states, adapter_data)
query, kv = qkv.split(
@ -256,24 +253,19 @@ class FlashGemma2Attention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# Prefill
if cu_seqlen_prefill is not None:
# sdpa
# flash attention
attn_output = attention(
query=query,
key=kv[:, 0],
value=kv[:, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
causal=self.causal,
window_size_left=self.window_size,
softcap=self.softcap,
)
@ -281,14 +273,14 @@ class FlashGemma2Attention(torch.nn.Module):
else:
attn_output = paged_attention(
query,
kv_cache,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
max_s,
softcap=self.softcap,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
window_size_left=self.window_size,
)
return self.o_proj(
@ -356,14 +348,7 @@ class Gemma2MLP(nn.Module):
class FlashGemma2Layer(nn.Module):
def __init__(
self,
prefix: str,
config,
weights,
layer_id,
causal: bool,
is_sliding: bool,
rotary_emb,
self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
):
super().__init__()
self.self_attn = FlashGemma2Attention(
@ -373,7 +358,6 @@ class FlashGemma2Layer(nn.Module):
layer_id=layer_id,
causal=causal,
is_sliding=is_sliding,
rotary_emb=rotary_emb,
)
self.mlp = Gemma2MLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
@ -406,10 +390,11 @@ class FlashGemma2Layer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
adapter_data,
hpu_attention_meta,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -420,10 +405,11 @@ class FlashGemma2Layer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
adapter_data,
hpu_attention_meta,
)
# faster post attention rms norm
@ -445,13 +431,6 @@ class FlashGemma2Model(torch.nn.Module):
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=config.head_dim,
base=config.rope_theta,
device=weights.device,
)
self.layers = nn.ModuleList(
[
FlashGemma2Layer(
@ -461,7 +440,6 @@ class FlashGemma2Model(torch.nn.Module):
layer_id=layer_id,
causal=causal,
is_sliding=layer_id % 2 == 0,
rotary_emb=rotary_emb,
)
for layer_id in range(config.num_hidden_layers)
]
@ -480,26 +458,21 @@ class FlashGemma2Model(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
adapter_data: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
@ -508,13 +481,12 @@ class FlashGemma2Model(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
max_s,
adapter_data,
hpu_attention_meta,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual)
@ -557,9 +529,11 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -569,10 +543,11 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
adapter_data,
hpu_attention_meta,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -1,755 +0,0 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed
from torch import nn
from typing import Optional, List, Tuple
import copy
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
get_linear,
#
SpeculativeHead,
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
import torch
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from text_generation_server.utils.weights import UnquantizedWeight
from transformers.activations import ACT2FN
from text_generation_server.layers.attention import (
paged_attention,
attention,
Seqlen,
set_block_mapping,
HPUPagedAttentionMetadata,
)
import habana_frameworks.torch as htorch
ATTENTION_TYPE_GLOBAL = "global"
ATTENTION_TYPE_LOCAL = "local_sliding"
class Gemma3FastRMSNorm(FastRMSNorm):
@classmethod
def load(cls, prefix: str, weights, eps=1e-6):
dtype = weights.dtype
weights.dtype = torch.float32
weight = weights.get_tensor(f"{prefix}.weight") + 1
weights.dtype = dtype
new = cls(weight, eps)
new.dtype = dtype
return new
# perform the multiplication in full precision and downcast after
def forward(self, hidden_states, residual=None):
if residual is not None:
hidden_states += residual
residual = hidden_states
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
hidden_states = hidden_states * self.weight
return hidden_states.to(self.dtype), residual
def load_attention(config, prefix: str, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
else:
return TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=False,
)
def _load_gqa(config, prefix: str, weights):
assert config.num_attention_heads % weights.process_group.size() == 0
weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
)
if isinstance(weight, UnquantizedWeight):
weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)
head_size = config.head_dim
num_heads = config.num_attention_heads // weights.process_group.size()
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
assert list(weight.weight.shape) == [
(num_heads + 2 * num_key_value_heads) * head_size,
config.hidden_size,
], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
return TensorParallelColumnLinear(get_linear(weight, bias=None))
class FlashGemma3Attention(torch.nn.Module):
def __init__(
self,
prefix: str,
config,
weights,
layer_id,
causal: bool,
is_sliding: bool,
local_rotary_emb,
global_rotary_emb,
):
super().__init__()
self.num_heads = config.num_attention_heads
self.head_size = config.head_dim
self.causal = causal
if is_sliding:
self.window_size = config.sliding_window
self.rotary_emb = local_rotary_emb
else:
self.window_size = -1
self.rotary_emb = global_rotary_emb
self.softmax_scale = (
config.query_pre_attn_scalar**-0.5
if config.query_pre_attn_scalar is not None
else None
)
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
config.num_key_value_heads // weights.process_group.size()
)
self.softcap = None # config.attn_logit_softcapping
query_key_value = load_attention(config, prefix, weights)
self.query_key_value = TensorParallelMultiAdapterLinear.load(
query_key_value,
layer_id,
["q_proj", "k_proj", "v_proj"],
sizes=[
self.head_size * config.num_attention_heads,
self.head_size * config.num_key_value_heads,
self.head_size * config.num_key_value_heads,
],
process_group=weights.process_group,
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
self.o_proj = TensorParallelAdapterRowLinear.load(
o_proj,
layer_id,
"o_proj",
process_group=weights.process_group,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
self.q_norm = Gemma3FastRMSNorm.load(
prefix=f"{prefix}.q_norm", weights=weights, eps=config.rms_norm_eps
)
self.k_norm = Gemma3FastRMSNorm.load(
prefix=f"{prefix}.k_norm", weights=weights, eps=config.rms_norm_eps
)
self.enable_gqa = self.num_heads != self.num_key_value_heads
def forward(
self,
hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
slots,
seqlen,
adapter_data,
hpu_attention_meta,
):
qkv = self.query_key_value(hidden_states, adapter_data)
query, kv = qkv.split(
[
self.head_size * self.num_heads,
2 * self.head_size * self.num_key_value_heads,
],
dim=1,
)
kv = kv.view(-1, 2, self.num_key_value_heads * self.head_size)
key = kv[:, 0]
value = kv[:, 1]
query = query.reshape(-1, self.head_size)
key = key.reshape(-1, self.head_size)
query, _ = self.q_norm(query.contiguous())
key, _ = self.k_norm(key.contiguous())
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_key_value_heads, self.head_size)
value = value.view(-1, self.num_key_value_heads, self.head_size)
self.rotary_emb(query, key, cos, sin)
kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
# sdpa
attn_output = attention(
query=query,
key=key,
value=value,
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
window_size_left=self.window_size,
softcap=self.softcap,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
seqlen,
softcap=self.softcap,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
window_size_left=self.window_size,
)
return self.o_proj(
attn_output.view(-1, self.num_heads * self.head_size), adapter_data
)
class Gemma3MLP(nn.Module):
def __init__(self, prefix, config, weights, layer_id):
super().__init__()
act = config.hidden_activation
self.act = (
ACT2FN[act]
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(
x,
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
)
# Fuse gate and up proj
gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
gate_up_proj,
layer_id,
["gate_proj", "up_proj"],
sizes=[
config.intermediate_size,
config.intermediate_size,
],
process_group=weights.process_group,
)
down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
)
self.down_proj = TensorParallelAdapterRowLinear.load(
down_proj,
layer_id,
"down_proj",
process_group=weights.process_group,
)
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
)
def forward(self, hidden_states, adapter_data):
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
)
class FlashGemma3Layer(nn.Module):
def __init__(
self,
prefix: str,
config,
weights,
layer_id,
causal: bool,
is_sliding: bool,
local_rotary_emb,
global_rotary_emb,
):
super().__init__()
self.self_attn = FlashGemma3Attention(
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
layer_id=layer_id,
causal=causal,
is_sliding=is_sliding,
local_rotary_emb=local_rotary_emb,
global_rotary_emb=global_rotary_emb,
)
self.mlp = Gemma3MLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
)
self.input_layernorm = Gemma3FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.post_attention_layernorm = Gemma3FastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
self.pre_feedforward_layernorm = Gemma3FastRMSNorm.load(
prefix=f"{prefix}.pre_feedforward_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
self.post_feedforward_layernorm = Gemma3FastRMSNorm.load(
prefix=f"{prefix}.post_feedforward_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
def forward(
self,
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
slots,
seqlen,
adapter_data,
hpu_attention_meta,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
# Self Attention
attn_output = self.self_attn(
normed_hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
slots,
seqlen,
adapter_data,
hpu_attention_meta,
)
# faster post attention rms norm
normed_attn_res_output, _ = self.post_attention_layernorm(attn_output)
normed_attn_res_output = normed_attn_res_output + res
res = normed_attn_res_output
pre_normed, _ = self.pre_feedforward_layernorm(normed_attn_res_output)
mlp_output = self.mlp(pre_normed, adapter_data)
post_hidden_states, _ = self.post_feedforward_layernorm(mlp_output)
return post_hidden_states, normed_attn_res_output
class FlashGemma3Model(torch.nn.Module):
def __init__(self, prefix: str, config, weights, causal: bool):
super().__init__()
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
local_config = copy.deepcopy(config)
local_config.rope_scaling = dict(rope_type="default")
local_rotary_emb = PositionRotaryEmbedding.static(
config=local_config,
dim=config.head_dim,
base=config.rope_local_base_freq,
device=weights.device,
)
global_rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=config.head_dim,
base=config.rope_theta,
device=weights.device,
)
self.layers = nn.ModuleList(
[
FlashGemma3Layer(
prefix=f"{prefix}.layers.{layer_id}",
config=config,
weights=weights,
layer_id=layer_id,
causal=causal,
is_sliding=bool((layer_id + 1) % config.sliding_window_pattern),
local_rotary_emb=local_rotary_emb,
global_rotary_emb=global_rotary_emb,
)
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = Gemma3FastRMSNorm.load(
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
)
self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
adapter_data: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
# Get rotary cos and sin for this forward
# Avoid to index in each layer
residual = None
for i, layer in enumerate(self.layers):
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = layer.self_attn.rotary_emb.get_cos_sin(position_ids)
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache[i],
slots,
seqlen,
adapter_data,
hpu_attention_meta,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class FlashGemma3ForCausalLM(torch.nn.Module):
def __init__(self, prefix: str, config, weights, *, causal: bool = True):
super().__init__()
embed_norm = config.hidden_size**0.5
if not prefix:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens", weights=weights
)
self.embed_tokens.weight *= embed_norm
self.model = FlashGemma3Model(
prefix=prefix, config=config, weights=weights, causal=causal
)
self.lm_head = SpeculativeHead.load(
prefix=(
f"{prefix}.embed_tokens"
if config.tie_word_embeddings
else f"{prefix}.lm_head"
),
config=config,
weights=weights,
)
# self.softcap = config.attn_logit_softcapping
# assert isinstance(self.softcap, float)
self.softcap = None
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
input_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(
input_embeds,
position_ids,
cu_seqlen_prefill,
kv_cache,
slots,
seqlen,
adapter_data,
hpu_attention_meta,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states)
return logits, speculative_logits
class Gemma3MultimodalInputProjection(torch.nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.mm_input_projection_weight = weights.get_tensor(
"multi_modal_projector.mm_input_projection_weight"
)
self.mm_soft_emb_norm = Gemma3FastRMSNorm.load(
prefix=f"{prefix}.mm_soft_emb_norm",
weights=weights,
eps=config.vision_config.layer_norm_eps,
)
self.patches_per_image = int(
config.vision_config.image_size // config.vision_config.patch_size
)
self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
self.kernel_size = self.patches_per_image // self.tokens_per_side
self.avg_pool = nn.AvgPool2d(
kernel_size=self.kernel_size, stride=self.kernel_size
)
def forward(self, vision_outputs: torch.Tensor):
batch_size, _, seq_length = vision_outputs.shape
reshaped_vision_outputs = vision_outputs.transpose(1, 2)
reshaped_vision_outputs = reshaped_vision_outputs.reshape(
batch_size, seq_length, self.patches_per_image, self.patches_per_image
)
reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
pooled_vision_outputs = pooled_vision_outputs.flatten(2)
pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)
normed_vision_outputs, _ = self.mm_soft_emb_norm(pooled_vision_outputs)
projected_vision_outputs = torch.matmul(
normed_vision_outputs, self.mm_input_projection_weight
)
return projected_vision_outputs.type_as(vision_outputs)
class Gemma3ForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
if config.vision_config is not None:
config.vision_config.quantize = config.quantize
self.post_vision_model_layernorm = nn.LayerNorm.load(
prefix="vision_tower.vision_model.post_layernorm",
weights=weights,
eps=config.vision_config.layer_norm_eps,
)
self.multimodal_projector = Gemma3MultimodalInputProjection(
prefix="multi_modal_projector",
config=config,
weights=weights,
)
text_config = config.text_config
text_config.speculator = config.speculator
text_config.quantize = config.quantize
self.vision_model = load_vision_model(
prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
config=config.vision_config,
weights=weights,
)
self.text_model = load_text_model(
prefix="language_model" if not prefix else f"{prefix}.language_model",
config=config.text_config,
weights=weights,
)
else:
config.text_config.quantize = config.quantize
config.text_config.speculator = config.speculator
self.text_model = load_text_model(
prefix=prefix,
config=config.text_config,
weights=weights,
)
self.pad_token_id = (
config.pad_token_id if config.pad_token_id is not None else -1
)
self.dtype = weights.dtype
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
pixel_attention_mask: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
pixel_values = pixel_values.to(dtype=self.dtype)
image_outputs = self.vision_model(pixel_values)
vision_outputs = self.post_vision_model_layernorm(
image_outputs.last_hidden_state
)
image_features = self.multimodal_projector(vision_outputs)
image_features = image_features.view(-1, image_features.shape[-1])
return image_features
def get_inputs_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if vision_embeds is not None:
# Replace the image token embeddings with the vision features
image_token_mask = (input_ids == self.config.image_token_index).to(
input_ids.device
)
inputs_embeds[image_token_mask] = vision_embeds.view(
-1, vision_embeds.shape[-1]
)
return inputs_embeds
def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if cu_seqlen_prefill is not None:
position_ids += 1
if attention_mask is not None:
min_dtype = torch.finfo(inputs_embeds.dtype).min
# prefill may be larger than sliding window
effective_seq_len = max(
position_ids.shape[0], self.config.text_config.sliding_window
)
sliding_window_mask = torch.tril(
torch.ones_like(attention_mask, dtype=torch.bool),
diagonal=-self.config.text_config.sliding_window,
)
attention_mask_local = torch.where(
sliding_window_mask, min_dtype, attention_mask
)
offset = max(0, position_ids.shape[0] - effective_seq_len)
attention_mask_local = attention_mask_local[
:, :, :, offset : offset + effective_seq_len
]
else:
attention_mask_local = None
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
slots=slots,
seqlen=seqlen,
hpu_attention_meta=hpu_attention_meta,
adapter_data=adapter_data,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.text_model.lm_head(hidden_states)
return logits, speculative_logits

View File

@ -28,9 +28,9 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
reshape_and_cache,
Seqlen,
HPUPagedAttentionMetadata,
PREFILL_IN_KV_CACHE,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -39,13 +39,11 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from text_generation_server.utils.weights import UnquantizedWeight
import habana_frameworks.torch as htorch
class GemmaConfig(PretrainedConfig):
@ -163,12 +161,19 @@ def _load_gqa(config, prefix: str, weights):
class FlashGemmaAttention(torch.nn.Module):
def __init__(self, prefix: str, config, weights, causal: bool, rotary_emb):
def __init__(self, prefix: str, config, weights, causal: bool):
super().__init__()
self.num_heads = config.num_attention_heads
self.head_size = config.head_dim
self.causal = causal
self.rotary_emb = rotary_emb
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
base=config.rope_theta,
device=weights.device,
)
self.softmax_scale = self.head_size**-0.5
if self.num_heads % weights.process_group.size() != 0:
@ -182,7 +187,6 @@ class FlashGemmaAttention(torch.nn.Module):
)
self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load(
config,
@ -202,9 +206,10 @@ class FlashGemmaAttention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
qkv = self.query_key_value(hidden_states)
query, kv = qkv.split(
@ -219,36 +224,31 @@ class FlashGemmaAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# Prefill
if cu_seqlen_prefill is not None:
# sdpa
# flash attention
attn_output = attention(
query=query,
key=kv[:, 0],
value=kv[:, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
causal=self.causal,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
max_s,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -293,14 +293,10 @@ class GemmaMLP(nn.Module):
class FlashGemmaLayer(nn.Module):
def __init__(self, prefix: str, config, weights, causal: bool, rotary_emb):
def __init__(self, prefix: str, config, weights, causal: bool):
super().__init__()
self.self_attn = FlashGemmaAttention(
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
causal=causal,
rotary_emb=rotary_emb,
prefix=f"{prefix}.self_attn", config=config, weights=weights, causal=causal
)
self.mlp = GemmaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
@ -321,9 +317,10 @@ class FlashGemmaLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -334,9 +331,10 @@ class FlashGemmaLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
# faster post attention rms norm
@ -356,13 +354,6 @@ class FlashGemmaModel(torch.nn.Module):
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=config.head_dim,
base=config.rope_theta,
device=weights.device,
)
self.layers = nn.ModuleList(
[
FlashGemmaLayer(
@ -370,7 +361,6 @@ class FlashGemmaModel(torch.nn.Module):
config=config,
weights=weights,
causal=causal,
rotary_emb=rotary_emb,
)
for layer_id in range(config.num_hidden_layers)
]
@ -389,25 +379,20 @@ class FlashGemmaModel(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
adapter_data: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
@ -416,12 +401,11 @@ class FlashGemmaModel(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual)
@ -462,9 +446,11 @@ class FlashGemmaForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -474,10 +460,10 @@ class FlashGemmaForCausalLM(torch.nn.Module):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
adapter_data,
hpu_attention_meta,
max_s,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -24,12 +24,12 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
reshape_and_cache,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -38,8 +38,6 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
import habana_frameworks.torch as htorch
def load_qkv(config, prefix: str, weights, head_size, num_heads):
@ -49,6 +47,10 @@ def load_qkv(config, prefix: str, weights, head_size, num_heads):
prefix,
weights,
)
elif config.quantize == "marlin":
raise RuntimeError(
"GPT-2 models with marlin quantization are not yet supported"
)
else:
return _load_qkv(config, prefix, weights, head_size, num_heads)
@ -193,7 +195,6 @@ class FlashGPT2Attention(torch.nn.Module):
head_size=self.head_size,
num_heads=self.num_heads,
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = load_row(
config,
@ -211,9 +212,10 @@ class FlashGPT2Attention(torch.nn.Module):
hidden_states,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
query, key, value = self.query_key_value(hidden_states).split(
self.head_size * self.num_heads, dim=1
@ -222,35 +224,30 @@ class FlashGPT2Attention(torch.nn.Module):
key = key.view(-1, self.num_heads, self.head_size)
value = value.view(-1, self.num_heads, self.head_size)
kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
# Prefill
if cu_seqlen_prefill is not None:
# sdpa
# flash attention
attn_output = attention(
query=query,
key=key,
value=value,
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else key,
kv_cache[1] if PREFILL_IN_KV_CACHE else value,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
max_s,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -316,9 +313,10 @@ class FlashGPT2Layer(nn.Module):
residual,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
@ -328,9 +326,10 @@ class FlashGPT2Layer(nn.Module):
hidden_states,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
hidden_states = attn_output + residual
@ -380,33 +379,27 @@ class FlashGPT2Model(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
residual,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states = self.norm(hidden_states)
@ -439,9 +432,11 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -453,9 +448,12 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta=hpu_attention_meta,
max_s,
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -24,13 +24,12 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
reshape_and_cache,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -39,17 +38,13 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.rotary import (
PositionRotaryEmbedding,
)
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
from habana_frameworks.torch.hpex.kernels import (
RotaryPosEmbeddingMode,
apply_rotary_pos_emb,
)
import habana_frameworks.torch as htorch
def load_attention(config, prefix: str, weights):
@ -83,25 +78,39 @@ class GPTJRotary(PositionRotaryEmbedding):
cos: torch.Tensor,
sin: torch.Tensor,
):
num_tokens = query.shape[0]
head_size = query.shape[-1]
rope_mode = RotaryPosEmbeddingMode.PAIRWISE
sin = torch.repeat_interleave(sin, 2, dim=-1)
cos = torch.repeat_interleave(cos, 2, dim=-1)
rotary_dim = cos.shape[-1]
query_shape = query.shape
query = query.view(num_tokens, -1, head_size)
query_rot = query[..., :rotary_dim]
query_pass = query[..., rotary_dim:]
query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape))
# Such controlflows may add some overhead.
if SYSTEM == "cuda":
import rotary_emb
key_shape = key.shape
key = key.view(num_tokens, -1, head_size)
key_rot = key[..., :rotary_dim]
key_pass = key[..., rotary_dim:]
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape))
q1 = query[..., ::2]
q2 = query[..., 1::2]
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
k1 = key[..., ::2]
k2 = key[..., 1::2]
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
elif SYSTEM == "rocm":
from vllm._C import ops
# NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
# Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773
head_size = query.shape[-1]
# Inplace operation, updating query and key.
ops.rotary_embedding(query, key, head_size, cos, sin, False)
elif SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex
ipex.llm.functional.rotary_embedding(
query, key, sin, cos, query.size(-1), False
)
else:
raise ValueError(
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
)
class FlashGPTJAttention(torch.nn.Module):
@ -110,7 +119,6 @@ class FlashGPTJAttention(torch.nn.Module):
prefix: str,
config,
weights,
rotary_emb,
):
super().__init__()
self.num_heads = config.num_attention_heads
@ -132,7 +140,6 @@ class FlashGPTJAttention(torch.nn.Module):
prefix=prefix,
weights=weights,
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = load_row(
config,
@ -144,7 +151,13 @@ class FlashGPTJAttention(torch.nn.Module):
self.kv_head_mapping = torch.arange(
0, self.num_heads, dtype=torch.int32, device=weights.device
)
self.rotary_emb = rotary_emb
self.rotary_emb = GPTJRotary.static(
config=config,
dim=self.rotary_dim,
base=10000,
device=weights.device,
)
def forward(
self,
@ -153,9 +166,10 @@ class FlashGPTJAttention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
query, key, value = self.query_key_value(hidden_states).split(
self.head_size * self.num_heads, dim=1
@ -172,35 +186,30 @@ class FlashGPTJAttention(torch.nn.Module):
else:
self.rotary_emb(query, key, cos, sin)
kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
# Prefill
if cu_seqlen_prefill is not None:
# sdpa
# flash attention
attn_output = attention(
query=query,
key=key,
value=value,
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else key,
kv_cache[1] if PREFILL_IN_KV_CACHE else value,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
max_s,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -239,13 +248,10 @@ class GPTJMLP(nn.Module):
class FlashGPTJLayer(nn.Module):
def __init__(self, prefix: str, config, weights, rotary_emb):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.self_attn = FlashGPTJAttention(
prefix=f"{prefix}.attn",
config=config,
weights=weights,
rotary_emb=rotary_emb,
prefix=f"{prefix}.attn", config=config, weights=weights
)
self.mlp = GPTJMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
@ -261,9 +267,10 @@ class FlashGPTJLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
hidden_states, residual = self.input_layernorm(hidden_states, residual)
# Self Attention
@ -273,9 +280,10 @@ class FlashGPTJLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
feed_forward_hidden_states = self.mlp(hidden_states)
@ -289,12 +297,6 @@ class FlashGPTJModel(torch.nn.Module):
self.config = config
self.wte = TensorParallelEmbedding(prefix=f"{prefix}.wte", weights=weights)
rotary_emb = GPTJRotary.static(
config=config,
dim=config.rotary_dim,
base=10000,
device=weights.device,
)
self.layers = nn.ModuleList(
[
FlashGPTJLayer(
@ -303,7 +305,6 @@ class FlashGPTJModel(torch.nn.Module):
),
config=config,
weights=weights,
rotary_emb=rotary_emb,
)
for layer_id in range(config.num_hidden_layers)
]
@ -326,24 +327,21 @@ class FlashGPTJModel(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.wte(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
@ -352,12 +350,11 @@ class FlashGPTJModel(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.ln_f(hidden_states, residual)
@ -384,9 +381,11 @@ class FlashGPTJForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -395,9 +394,11 @@ class FlashGPTJForCausalLM(torch.nn.Module):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta=hpu_attention_meta,
max_s,
prefill_cache_indices=prefill_cache_indices,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -26,18 +26,15 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
import habana_frameworks.torch as htorch
from text_generation_server.layers.attention import (
KVCache,
get_kv_scales,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
reshape_and_cache,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -60,6 +57,15 @@ from text_generation_server.utils.weights import (
)
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
if SYSTEM != "ipex":
pass
if SYSTEM == "rocm":
try:
from vllm import _custom_C
except Exception as e:
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
def load_attention(config, prefix: str, weights, layer_id):
# Only defined in granite.
@ -133,7 +139,6 @@ class FlashLlamaAttention(torch.nn.Module):
prefix: str,
config,
weights,
rotary_emb,
):
super().__init__()
self.num_heads = config.num_attention_heads
@ -145,14 +150,15 @@ class FlashLlamaAttention(torch.nn.Module):
config.num_key_value_heads = getattr(
config, "num_key_value_heads", config.num_attention_heads
)
self.rotary_emb = rotary_emb
# `config.attention_multiplier` is used in Granite
self.softmax_scale = getattr(
config, "attention_multiplier", self.head_size**-0.5
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
base=config.rope_theta,
device=weights.device,
)
self.softmax_scale = self.head_size**-0.5
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
@ -171,13 +177,11 @@ class FlashLlamaAttention(torch.nn.Module):
self.query_key_value = load_attention(config, prefix, weights, index)
self.index = index
self.kv_scales = get_kv_scales(weights, f"{prefix}")
o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=getattr(config, "attention_bias", False),
bias=False,
)
self.o_proj = TensorParallelAdapterRowLinear.load(
@ -198,11 +202,12 @@ class FlashLlamaAttention(torch.nn.Module):
cos,
sin,
cu_seqlen_prefill,
kv_cache: KVCache,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
adapter_data,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
):
qkv = self.query_key_value(hidden_states, adapter_data)
query, kv = qkv.split(
@ -217,35 +222,30 @@ class FlashLlamaAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# Prefill
if cu_seqlen_prefill is not None:
# sdpa
# flash attention
attn_output = attention(
query=query,
key=kv[:, 0],
value=kv[:, 1],
kv_scales=self.kv_scales,
kv_cache=kv_cache,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
max_s,
)
return self.o_proj(
@ -363,15 +363,35 @@ class LlamaMLP(nn.Module):
self.hidden_size = config.hidden_size
def forward(self, hidden_states, adapter_data):
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
)
if (
SYSTEM == "rocm"
and self.hidden_act == "silu"
and hidden_states.dtype == torch.float16
and hidden_states.shape[0] == 1
and not self.quantize
and self.hidden_size
!= 16384 # TODO: Temporary workaround for `LLMM_Silu` kernel not working with LLama3.1 405B; needs refactoring once fixed.
):
out = torch.empty(
hidden_states.shape[0],
self.intermediate_size,
dtype=hidden_states.dtype,
device="cuda",
)
_custom_C.LLMM_Silu(
self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8
)
return self.down_proj(out, adapter_data)
else:
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
)
class FlashLlamaLayer(nn.Module):
def __init__(self, index, prefix, config, weights, rotary_emb):
def __init__(self, index, prefix, config, weights):
super().__init__()
with no_fp8(weights):
@ -380,7 +400,6 @@ class FlashLlamaLayer(nn.Module):
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
rotary_emb=rotary_emb,
)
if config.model_type == "phimoe":
@ -389,7 +408,7 @@ class FlashLlamaLayer(nn.Module):
if SparseMoELayer.is_supported(weights)
else DenseMoELayer
)
self.mlp = Phi3MoE(
self.dense = Phi3MoE(
f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights
)
# with moe the layernorms are are not rmsnorms and they have bias
@ -404,7 +423,7 @@ class FlashLlamaLayer(nn.Module):
eps=config.rms_norm_eps,
)
else:
self.mlp = LlamaMLP(
self.dense = LlamaMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
)
self.input_layernorm = FastRMSNorm.load(
@ -418,11 +437,6 @@ class FlashLlamaLayer(nn.Module):
eps=config.rms_norm_eps,
)
# Used in Granite
# This could eventually be baked into the weights like we do for the embeddings/lm_head
# but this would mean modifying the lora code
self.residual_multiplier = getattr(config, "residual_multiplier", None)
def forward(
self,
hidden_states,
@ -431,11 +445,12 @@ class FlashLlamaLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
adapter_data,
cross_attention_states,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -446,21 +461,19 @@ class FlashLlamaLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
adapter_data,
hpu_attention_meta=hpu_attention_meta,
)
if self.residual_multiplier is not None:
attn_output *= self.residual_multiplier
# faster post attention rms norm
normed_attn_res_output, attn_res = self.post_attention_layernorm(
attn_output, res
)
mlp_output = self.mlp(normed_attn_res_output, adapter_data)
if self.residual_multiplier is not None:
mlp_output *= self.residual_multiplier
mlp_output = self.dense(normed_attn_res_output, adapter_data)
return mlp_output, attn_res
@ -476,35 +489,33 @@ class FlashLlamaModel(torch.nn.Module):
# Skip fp8 quant for first and last layers
self.layers = nn.ModuleList()
self.cross_attention_layers = getattr(config, "cross_attention_layers", [])
rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=config.hidden_size // config.num_attention_heads,
base=config.rope_theta,
device=weights.device,
)
with no_fp8(weights):
self.layers.append(
FlashLlamaLayer(
index=0,
prefix=f"{prefix}.layers.0",
prefix=(
"model.layers.0" if not prefix else f"{prefix}.model.layers.0"
),
config=config,
weights=weights,
rotary_emb=rotary_emb,
)
)
# Skip first and last layers
for layer_id in range(1, config.num_hidden_layers - 1):
if layer_id in self.cross_attention_layers:
from text_generation_server.models.custom_modeling.flash_mllama import (
from text_generation_server.models.custom_modeling.mllama import (
FlashLlamaCrossLayer,
)
self.layers.append(
FlashLlamaCrossLayer(
index=layer_id,
prefix=(f"{prefix}.layers.{layer_id}"),
prefix=(
f"model.layers.{layer_id}"
if not prefix
else f"{prefix}.model.layers.{layer_id}"
),
config=config,
weights=weights,
)
@ -513,10 +524,13 @@ class FlashLlamaModel(torch.nn.Module):
self.layers.append(
FlashLlamaLayer(
index=layer_id,
prefix=(f"{prefix}.layers.{layer_id}"),
prefix=(
f"model.layers.{layer_id}"
if not prefix
else f"{prefix}.model.layers.{layer_id}"
),
config=config,
weights=weights,
rotary_emb=rotary_emb,
)
)
@ -525,15 +539,18 @@ class FlashLlamaModel(torch.nn.Module):
self.layers.append(
FlashLlamaLayer(
index=last_layer_id,
prefix=(f"{prefix}.layers.{last_layer_id}"),
prefix=(
f"model.layers.{last_layer_id}"
if not prefix
else f"{prefix}.model.layers.{last_layer_id}"
),
config=config,
weights=weights,
rotary_emb=rotary_emb,
)
)
self.norm = FastRMSNorm.load(
prefix=f"{prefix}.norm",
prefix="model.norm" if not prefix else f"{prefix}.model.norm",
weights=weights,
eps=config.rms_norm_eps,
)
@ -550,27 +567,24 @@ class FlashLlamaModel(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
adapter_data,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
cross_attention_states=None,
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
@ -579,14 +593,13 @@ class FlashLlamaModel(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
max_s,
adapter_data,
cross_attention_states,
hpu_attention_meta=hpu_attention_meta,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual)
@ -594,60 +607,42 @@ class FlashLlamaModel(torch.nn.Module):
class FlashLlamaForCausalLM(torch.nn.Module):
def __init__(self, prefix: str, config, weights, name=None):
if name is None:
name = "model"
def __init__(self, prefix: str, config, weights):
super().__init__()
with no_fp8(weights):
self.embed_tokens = TensorParallelEmbedding(
prefix=(
f"{name}.embed_tokens"
"model.embed_tokens"
if not prefix
else f"{prefix}.{name}.embed_tokens"
else f"{prefix}.model.embed_tokens"
),
weights=weights,
)
self.model = FlashLlamaModel(
prefix=name if not prefix else f"{prefix}.{name}",
config=config,
weights=weights,
)
self.model = FlashLlamaModel(prefix, config, weights)
if config.tie_word_embeddings:
suffix = "model.embed_tokens"
else:
suffix = "lm_head"
# Used in Granite
embedding_multiplier = getattr(config, "embedding_multiplier", None)
if embedding_multiplier is not None:
self.embed_tokens.weight.data *= embedding_multiplier
prefix = suffix if not prefix or name != "model" else f"{prefix}.{suffix}"
with no_fp8(weights):
self.lm_head = SpeculativeHead.load(
config,
prefix,
weights,
prefix=suffix if not prefix else f"{prefix}.{suffix}",
weights=weights,
)
# Used in Granite
self.logits_scaling = getattr(config, "logits_scaling", None)
if self.logits_scaling is not None and self.lm_head.head is not None:
try:
# Scale the weights directly
self.lm_head.head.linear.weight.data /= self.logits_scaling
self.logits_scaled = True
except Exception:
self.logits_scaled = False
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
cross_attention_states=None,
@ -658,20 +653,16 @@ class FlashLlamaForCausalLM(torch.nn.Module):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
adapter_data=adapter_data,
cross_attention_states=cross_attention_states,
hpu_attention_meta=hpu_attention_meta,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states)
# Used in Granite
if self.logits_scaling is not None and not self.logits_scaled:
logits /= self.logits_scaling
if speculative_logits is not None:
speculative_logits /= self.logits_scaling
return logits, speculative_logits

View File

@ -1,298 +0,0 @@
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Llava-NeXT model."""
from typing import List, Optional, Tuple
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
)
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelRowLinear,
)
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
"""
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
Args:
image_size (`tuple`):
The size of the input image in the format (height, width).
grid_pinpoints (`List`):
A list containing possible resolutions. Each item in the list should be a tuple or list
of the form `(height, width)`.
patch_size (`int`):
The size of each image patch.
Returns:
tuple: The shape of the image patch grid in the format (height, width).
"""
if not isinstance(grid_pinpoints, list):
raise ValueError("grid_pinpoints should be a list of tuples or lists")
height, width = select_best_resolution(image_size, grid_pinpoints)
return height // patch_size, width // patch_size
def unpad_image(tensor, original_size):
"""
Unpads a PyTorch tensor of a padded and resized image.
Args:
tensor (`torch.Tensor`):
The image tensor, assumed to be of shape (num_channels, height, width).
original_size (`tuple`):
The original size of the image (height, width).
Returns:
`torch.Tensor`: The unpadded image tensor.
"""
original_height, original_width = original_size
current_height, current_width = tensor.shape[1:]
original_aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = int(original_height * scale_factor)
padding = (current_height - new_height) // 2
unpadded_tensor = tensor[:, padding : current_height - padding, :]
else:
scale_factor = current_height / original_height
new_width = int(original_width * scale_factor)
padding = (current_width - new_width) // 2
unpadded_tensor = tensor[:, :, padding : current_width - padding]
return unpadded_tensor
# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
class LlavaNextMultiModalProjector(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.linear_1 = TensorParallelColumnLinear.load(
prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True
)
self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = TensorParallelRowLinear.load(
prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True
)
def forward(self, image_features):
hidden_states = self.linear_1(image_features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class FlashLlavaNextForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
config.vision_config.quantize = config.quantize
vision_config = config.vision_config
# Instead of selecting in hidden_states[-2].
# Instead compute only the n -2 + 1 layers and don't pool
if config.vision_feature_layer < 0:
vision_config.num_hidden_layers += config.vision_feature_layer + 1
else:
vision_config.num_hidden_layers = config.vision_feature_layer + 1
self.vision_tower = load_vision_model(
prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
config=config.vision_config,
weights=weights,
)
self.multi_modal_projector = LlavaNextMultiModalProjector(
prefix="multi_modal_projector", config=config, weights=weights
)
self.image_newline = weights.get_tensor("image_newline")
self.vocab_size = config.text_config.vocab_size
self.config = config
config.text_config.quantize = config.quantize
config.text_config.speculator = config.speculator
self.text_model = load_text_model(
prefix="language_model" if not prefix else f"{prefix}.language_model",
config=config.text_config,
weights=weights,
)
self.pad_token_id = (
config.pad_token_id if config.pad_token_id is not None else -1
)
def _merge_input_ids_with_image_features(
self,
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
image_features: torch.Tensor,
):
"""In place merges in vision_embeddings with inputs_embeds."""
mask = torch.where(input_ids == self.config.image_token_index)
# Let's pray we have enabled enough slots !
try:
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
except Exception as e:
raise RuntimeError(
f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
)
return inputs_embeds
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
pixel_attention_mask: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
# 1. Extract the input embeddings
# 2. Merge text and images
num_images, num_patches, channels, height, width = pixel_values.shape
pixel_values = pixel_values.view(
num_images * num_patches, channels, height, width
)
image_features = self.vision_tower(pixel_values)
# selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
# Already done within the clip model
selected_image_feature = image_features.last_hidden_state
if self.config.vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif self.config.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise RuntimeError(
f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
)
image_features = self.multi_modal_projector(selected_image_feature)
# split up image_features for each of the individual images
# hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
# if we assume each image has 5 image features (base image + 4 patches)
split_sizes = [num_patches] * num_images
image_features = torch.split(image_features, split_sizes, dim=0)
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
height = width = (
self.config.vision_config.image_size // self.config.vision_config.patch_size
)
new_image_features = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
if height * width != base_image_feature.shape[0]:
raise ValueError(
"The number of patches is not consistent with the image size."
)
# Dimensions are intentionally swapped to be bug-compatible with
# upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
image_feature = image_feature.view(
num_patch_height, num_patch_width, height, width, -1
)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature, image_sizes[image_idx])
image_feature = torch.cat(
(
image_feature,
self.image_newline[:, None, None].expand(
*image_feature.shape[:-1], 1
),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
else:
image_feature = image_feature[0]
image_feature = torch.cat(
(image_feature, self.image_newline[None]), dim=0
)
new_image_features.append(image_feature)
image_features = torch.stack(new_image_features, dim=0)
return image_features.view(-1, image_features.shape[-1])
def get_inputs_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
pixel_values: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if vision_embeds is not None:
# When we generate, we don't want to replace the potential image_token_id that we generated by images
# that simply don't exist
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, vision_embeds
)
return inputs_embeds
def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
):
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
slots=slots,
seqlen=seqlen,
hpu_attention_meta=hpu_attention_meta,
adapter_data=adapter_data,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.text_model.lm_head(hidden_states)
return logits, speculative_logits

View File

@ -26,13 +26,12 @@ from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
reshape_and_cache,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -42,11 +41,18 @@ from text_generation_server.layers import (
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
import habana_frameworks.torch as htorch
if SYSTEM == "rocm":
try:
from vllm import _custom_C
except Exception as e:
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
class MistralConfig(PretrainedConfig):
@ -104,20 +110,24 @@ class MistralConfig(PretrainedConfig):
class MistralAttention(torch.nn.Module):
def __init__(self, prefix: str, config, weights, layer_id, rotary_emb):
def __init__(self, prefix: str, config, weights, layer_id):
super().__init__()
self.max_past = (
config.sliding_window if config.sliding_window is not None else -1
)
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
if getattr(config, "head_dim", None) is not None:
if hasattr(config, "head_dim"):
self.head_size = config.head_dim
else:
self.head_size = self.hidden_size // self.num_heads
self.rotary_emb = rotary_emb
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
base=config.rope_theta,
device=weights.device,
)
self.softmax_scale = self.head_size**-0.5
@ -150,7 +160,6 @@ class MistralAttention(torch.nn.Module):
],
process_group=weights.process_group,
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
o_proj = TensorParallelRowLinear.load(
config,
@ -176,10 +185,12 @@ class MistralAttention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
adapter_data,
hpu_attention_meta,
):
qkv = self.query_key_value(hidden_states, adapter_data)
query, kv = qkv.split(
@ -194,37 +205,38 @@ class MistralAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
reshape_and_cache(
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
)
# Prefill
if cu_seqlen_prefill is not None:
# sdpa
# flash attention
attn_output = attention(
query=query,
key=kv[:, 0],
value=kv[:, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
seqlen,
block_tables,
self.softmax_scale,
window_size_left=self.max_past,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
window_size_left=self.max_past,
max_s,
)
return self.o_proj(
@ -288,22 +300,39 @@ class MistralMLP(nn.Module):
self.quantize = config.quantize
def forward(self, hidden_states, adapter_data):
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
)
if (
SYSTEM == "rocm"
and self.hidden_act == "silu"
and hidden_states.dtype == torch.float16
and hidden_states.shape[0] == 1
and not self.quantize
):
out = torch.empty(
hidden_states.shape[0],
self.intermediate_size,
dtype=hidden_states.dtype,
device="cuda",
)
_custom_C.LLMM_Silu(
self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8
)
return self.down_proj(out, adapter_data)
else:
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
)
class MistralLayer(nn.Module):
def __init__(self, prefix: str, config, weights, layer_id, rotary_emb):
def __init__(self, prefix: str, config, weights, layer_id):
super().__init__()
self.self_attn = MistralAttention(
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
layer_id=layer_id,
rotary_emb=rotary_emb,
)
self.mlp = MistralMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
@ -326,10 +355,12 @@ class MistralLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
adapter_data,
hpu_attention_meta,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -340,10 +371,12 @@ class MistralLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
adapter_data,
hpu_attention_meta,
)
# faster post attention rms norm
@ -363,19 +396,6 @@ class MistralModel(torch.nn.Module):
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
if getattr(config, "head_dim", None) is not None:
head_dim = config.head_dim
else:
head_dim = config.hidden_size // config.num_attention_heads
rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=head_dim,
base=config.rope_theta,
device=weights.device,
)
self.layers = nn.ModuleList(
[
MistralLayer(
@ -383,7 +403,6 @@ class MistralModel(torch.nn.Module):
config=config,
weights=weights,
layer_id=layer_id,
rotary_emb=rotary_emb,
)
for layer_id in range(config.num_hidden_layers)
]
@ -404,24 +423,22 @@ class MistralModel(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
adapter_data: Optional[torch.Tensor] = None,
):
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, true_max_s, hidden_states.dtype
)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
@ -430,13 +447,13 @@ class MistralModel(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
adapter_data,
hpu_attention_meta,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
@ -481,21 +498,35 @@ class FlashMistralForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
true_max_s = max_s
if prefill_cache_indices is not None:
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
seqlen = seqlen.clamp(max=self.max_past_tensor)
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(
inputs_embeds,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
true_max_s,
prefill_cache_indices,
adapter_data,
)
if lm_head_indices is not None:

View File

@ -37,15 +37,13 @@ from text_generation_server.layers.attention import (
Seqlen,
attention,
paged_attention,
set_block_mapping,
HPUPagedAttentionMetadata,
reshape_and_cache,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.utils.weights import UnquantizedWeight
import habana_frameworks.torch as htorch
class MixtralConfig(PretrainedConfig):
@ -188,7 +186,6 @@ class MixtralAttention(torch.nn.Module):
prefix: str,
config,
weights,
rotary_emb,
):
super().__init__()
self.max_past = (
@ -197,7 +194,13 @@ class MixtralAttention(torch.nn.Module):
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
self.rotary_emb = rotary_emb
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
base=config.rope_theta,
device=weights.device,
)
self.softmax_scale = self.head_size**-0.5
@ -212,7 +215,6 @@ class MixtralAttention(torch.nn.Module):
)
self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load(
config,
@ -232,9 +234,11 @@ class MixtralAttention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
prefill_cache_indices,
):
qkv = self.query_key_value(hidden_states)
query, kv = qkv.split(
@ -249,36 +253,38 @@ class MixtralAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
reshape_and_cache(
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
)
# Prefill
if cu_seqlen_prefill is not None:
# sdpa
# flash attention
attn_output = attention(
query=query,
key=kv[:, 0],
value=kv[:, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
seqlen,
block_tables,
self.softmax_scale,
window_size_left=self.max_past,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
max_s,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -340,15 +346,12 @@ class MixtralMoE(nn.Module):
class MixtralLayer(nn.Module):
def __init__(self, prefix: str, layer_id, config, weights, rotary_emb):
def __init__(self, prefix: str, layer_id, config, weights):
super().__init__()
prefix = f"{prefix}.layers.{layer_id}"
self.self_attn = MixtralAttention(
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
rotary_emb=rotary_emb,
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
moe_layer_cls = (
@ -375,9 +378,11 @@ class MixtralLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
prefill_cache_indices,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -388,9 +393,11 @@ class MixtralLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
prefill_cache_indices,
)
# faster post attention rms norm
@ -414,12 +421,6 @@ class MixtralModel(torch.nn.Module):
weights=weights,
)
rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=config.hidden_size // config.num_attention_heads,
base=config.rope_theta,
device=weights.device,
)
self.layers = nn.ModuleList(
[
MixtralLayer(
@ -427,7 +428,6 @@ class MixtralModel(torch.nn.Module):
layer_id,
config,
weights,
rotary_emb,
)
for layer_id in range(config.num_hidden_layers)
]
@ -448,24 +448,22 @@ class MixtralModel(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, true_max_s, hidden_states.dtype
)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
@ -474,12 +472,12 @@ class MixtralModel(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
prefill_cache_indices,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual)
@ -509,20 +507,34 @@ class FlashMixtralForCausalLM(torch.nn.Module):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
true_max_s = max_s
if prefill_cache_indices is not None:
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
seqlen = seqlen.clamp(max=self.max_past_tensor)
hidden_states = self.model(
input_ids,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
true_max_s,
prefill_cache_indices,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -29,9 +29,8 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
reshape_and_cache,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -40,7 +39,7 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
@ -48,7 +47,6 @@ from text_generation_server.layers.rotary import (
PositionRotaryEmbedding,
)
from text_generation_server.utils.weights import UnquantizedWeight
import habana_frameworks.torch as htorch
class GPTNeoXConfig(TransformersGPTNeoXConfig):
@ -99,7 +97,7 @@ def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
class FlashNeoxAttention(torch.nn.Module):
def __init__(self, config, prefix, weights, rotary_emb):
def __init__(self, config, prefix, weights):
super().__init__()
num_heads = config.num_attention_heads
hidden_size = config.hidden_size
@ -116,7 +114,14 @@ class FlashNeoxAttention(torch.nn.Module):
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.rotary_emb = rotary_emb
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.rotary_dim,
base=config.rotary_emb_base,
device=weights.device,
)
self.softmax_scale = self.head_size ** (-0.5)
self.query_key_value = load_qkv(
@ -127,7 +132,6 @@ class FlashNeoxAttention(torch.nn.Module):
head_size=self.head_size,
hidden_size=self.hidden_size,
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=True
)
@ -142,9 +146,10 @@ class FlashNeoxAttention(torch.nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
@ -160,35 +165,30 @@ class FlashNeoxAttention(torch.nn.Module):
qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1)
qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1)
kv_cache.store(
key=qkv[:, 1],
value=qkv[:, 2],
slots=slots,
kv_scales=self.kv_scales,
)
reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots)
# Prefill
if cu_seqlen_prefill is not None:
# sdpa
# flash attention
attn_output = attention(
query=qkv[:, 0],
key=qkv[:, 1],
value=qkv[:, 2],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
qkv[:, 0],
kv_cache[0] if PREFILL_IN_KV_CACHE else qkv[:, 1],
kv_cache[1] if PREFILL_IN_KV_CACHE else qkv[:, 2],
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
qkv[:, 0],
kv_cache,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
max_s,
)
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@ -224,7 +224,7 @@ class FlashMLP(nn.Module):
class FlashNeoXLayer(nn.Module):
def __init__(self, layer_id, config, weights, rotary_emb):
def __init__(self, layer_id, config, weights):
super().__init__()
layer_norm_eps = config.layer_norm_eps
@ -241,10 +241,7 @@ class FlashNeoXLayer(nn.Module):
eps=layer_norm_eps,
)
self.attention = FlashNeoxAttention(
config,
prefix=f"{prefix}.attention",
weights=weights,
rotary_emb=rotary_emb,
config, prefix=f"{prefix}.attention", weights=weights
)
self.mlp = FlashMLP(config, prefix=f"{prefix}.mlp", weights=weights)
@ -258,9 +255,10 @@ class FlashNeoXLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
):
if self.use_parallel_residual:
ln1_hidden_states, _ = self.input_layernorm(hidden_states)
@ -271,9 +269,10 @@ class FlashNeoXLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)
@ -294,9 +293,10 @@ class FlashNeoXLayer(nn.Module):
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
hidden_states, residual = self.post_attention_layernorm(
@ -324,18 +324,9 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
prefix=f"{prefix}.embed_in", weights=weights
)
rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=int(
config.rotary_pct * (config.hidden_size // config.num_attention_heads)
),
base=config.rotary_emb_base,
device=weights.device,
)
self.layers = nn.ModuleList(
[
FlashNeoXLayer(layer_id, config, weights, rotary_emb)
FlashNeoXLayer(layer_id, config, weights)
for layer_id in range(config.num_hidden_layers)
]
)
@ -356,24 +347,20 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_in(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(position_ids)
cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
@ -382,12 +369,11 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.final_layer_norm(hidden_states, residual)
@ -415,9 +401,11 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
@ -426,9 +414,10 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
hpu_attention_meta,
max_s,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -19,7 +19,7 @@ from torch import nn
from typing import Optional, List, Tuple
from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
from text_generation_server.layers.attention import Seqlen
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
@ -62,63 +62,54 @@ class PaliGemmaForConditionalGeneration(nn.Module):
self.pad_token_id = (
config.pad_token_id if config.pad_token_id is not None else -1
)
self.dtype = weights.dtype
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
pixel_attention_mask: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
pixel_values = pixel_values.to(dtype=self.dtype)
image_outputs = self.vision_tower(pixel_values)
last_hidden_state = self.post_vision_tower_layernorm(
image_outputs.last_hidden_state
)
image_features = self.multi_modal_projector(last_hidden_state)
image_features = image_features.view(-1, image_features.shape[-1])
return image_features
def get_inputs_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if vision_embeds is not None:
mask = input_ids == self.config.image_token_index
inputs_embeds[mask] = vision_embeds
return inputs_embeds
def forward(
self,
inputs_embeds: torch.Tensor,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
pixel_values: torch.FloatTensor = None,
# Unused here
pixel_attention_mask: Optional[torch.BoolTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.text_model.embed_tokens(input_ids)
# TODO This is odd but apparently pali gemma position ids start at 1.
if cu_seqlen_prefill is not None:
max_s += 1
position_ids += 1
if pixel_values is not None:
pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
image_outputs = self.vision_tower(pixel_values)
last_hidden_state = self.post_vision_tower_layernorm(
image_outputs.last_hidden_state
)
image_features = self.multi_modal_projector(last_hidden_state)
# mask where image or padding tokens
mask = input_ids == self.config.image_token_index
# insert image features into input embeddings
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
hpu_attention_meta=hpu_attention_meta,
adapter_data=adapter_data,
max_s=max_s,
)
if lm_head_indices is not None:

Some files were not shown because too many files have changed in this diff Show More