Merge branch 'main' into add_logs_gaudi_warmup

This commit is contained in:
regisss 2025-06-16 12:36:45 +00:00
commit 2ba396c4c1
144 changed files with 7340 additions and 8680 deletions

View File

@ -21,7 +21,7 @@ jobs:
nix_path: nixpkgs=channel:nixos-unstable
- uses: cachix/cachix-action@v14
with:
name: text-generation-inference
name: huggingface
# If you chose signing key for write access
authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}'
env:

View File

@ -20,7 +20,7 @@ jobs:
nix_path: nixpkgs=channel:nixos-unstable
- uses: cachix/cachix-action@v14
with:
name: text-generation-inference
name: huggingface
# If you chose signing key for write access
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
env:

View File

@ -25,7 +25,7 @@ jobs:
nix_path: nixpkgs=channel:nixos-unstable
- uses: cachix/cachix-action@v14
with:
name: text-generation-inference
name: huggingface
# If you chose signing key for write access
authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}'
env:

16
Cargo.lock generated
View File

@ -4650,7 +4650,7 @@ dependencies = [
[[package]]
name = "text-generation-backends-trtllm"
version = "3.3.0-dev0"
version = "3.3.2-dev0"
dependencies = [
"async-trait",
"clap 4.5.32",
@ -4671,7 +4671,7 @@ dependencies = [
[[package]]
name = "text-generation-benchmark"
version = "3.3.0-dev0"
version = "3.3.2-dev0"
dependencies = [
"average",
"clap 4.5.32",
@ -4691,7 +4691,7 @@ dependencies = [
[[package]]
name = "text-generation-client"
version = "3.3.0-dev0"
version = "3.3.2-dev0"
dependencies = [
"async-trait",
"base64 0.22.1",
@ -4709,7 +4709,7 @@ dependencies = [
[[package]]
name = "text-generation-launcher"
version = "3.3.0-dev0"
version = "3.3.2-dev0"
dependencies = [
"clap 4.5.32",
"ctrlc",
@ -4730,7 +4730,7 @@ dependencies = [
[[package]]
name = "text-generation-router"
version = "3.3.0-dev0"
version = "3.3.2-dev0"
dependencies = [
"anyhow",
"async-stream",
@ -4782,7 +4782,7 @@ dependencies = [
[[package]]
name = "text-generation-router-llamacpp"
version = "3.3.0-dev0"
version = "3.3.2-dev0"
dependencies = [
"async-trait",
"bindgen 0.71.1",
@ -4800,7 +4800,7 @@ dependencies = [
[[package]]
name = "text-generation-router-v2"
version = "3.3.0-dev0"
version = "3.3.2-dev0"
dependencies = [
"async-stream",
"async-trait",
@ -4849,7 +4849,7 @@ dependencies = [
[[package]]
name = "text-generation-router-v3"
version = "3.3.0-dev0"
version = "3.3.2-dev0"
dependencies = [
"async-stream",
"async-trait",

View File

@ -21,7 +21,7 @@ default-members = [
resolver = "2"
[workspace.package]
version = "3.3.0-dev0"
version = "3.3.2-dev0"
edition = "2021"
authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference"

View File

@ -48,7 +48,7 @@ FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS pytorch-install
WORKDIR /usr/src/
# NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099
ARG PYTORCH_VERSION=2.6
ARG PYTORCH_VERSION=2.7
ARG PYTHON_VERSION=3.11
# Keep in sync with `server/pyproject.toml
@ -121,13 +121,6 @@ COPY server/Makefile-awq Makefile
# Build specific version of transformers
RUN . .venv/bin/activate && make build-awq
# Build Lorax Punica kernels
FROM kernel-builder AS lorax-punica-builder
WORKDIR /usr/src
COPY server/Makefile-lorax-punica Makefile
# Build specific version of transformers
RUN . .venv/bin/activate && TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica
# Build Transformers CUDA kernels
FROM kernel-builder AS custom-kernels-builder
WORKDIR /usr/src
@ -210,8 +203,6 @@ COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311
COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
# Copy build artifacts from awq kernels builder
COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
# Copy build artifacts from lorax punica kernels builder
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
# Copy build artifacts from mamba builder
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages

View File

@ -5,7 +5,7 @@ RUN mkdir -p /tgi
# Fetch the optimum-neuron sources directly to avoid relying on pypi deployments
FROM alpine AS optimum-neuron
RUN mkdir -p /optimum-neuron
ADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.1.0.tar.gz /optimum-neuron/sources.tar.gz
ADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.2.0.tar.gz /optimum-neuron/sources.tar.gz
RUN tar -C /optimum-neuron -xf /optimum-neuron/sources.tar.gz --strip-components=1
# Build cargo components (adapted from TGI original Dockerfile)
@ -108,10 +108,10 @@ RUN wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEU
# Install neuronx packages
RUN apt-get update -y \
&& apt-get install -y --no-install-recommends \
aws-neuronx-dkms=2.19.64.0 \
aws-neuronx-collectives=2.23.135.0-3e70920f2 \
aws-neuronx-runtime-lib=2.23.112.0-9b5179492 \
aws-neuronx-tools=2.20.204.0 \
aws-neuronx-dkms=2.20.28.0 \
aws-neuronx-collectives=2.24.59.0-838c7fc8b \
aws-neuronx-runtime-lib=2.24.53.0-f239092cc \
aws-neuronx-tools=2.22.61.0 \
libxml2 \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
@ -125,11 +125,10 @@ RUN pip3 install \
--index-url https://download.pytorch.org/whl/cpu
RUN pip3 install \
neuronx-cc==2.16.372.0 \
torch-neuronx==2.5.1.2.4.0 \
transformers-neuronx==0.13.322 \
neuronx-distributed==0.10.1 \
libneuronxla==2.1.681.0 \
neuronx-cc==2.17.194.0 \
torch-neuronx==2.5.1.2.6.0 \
neuronx-distributed==0.11.0 \
libneuronxla==2.2.1630.0 \
--extra-index-url=https://pip.repos.neuron.amazonaws.com
# Install HuggingFace packages
@ -160,7 +159,7 @@ RUN pip install dist/text_generation_server*.tar.gz
# Final image
FROM neuron
COPY backends/neuron/tgi_env.py /tgi_env.py
COPY backends/neuron/tgi_entry_point.py /tgi_entry_point.py
COPY backends/neuron/tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh

View File

@ -6,7 +6,7 @@
FROM nixos/nix:2.18.8 AS builder
RUN echo "experimental-features = nix-command flakes" >> /etc/nix/nix.conf
RUN nix profile install nixpkgs#cachix
RUN cachix use text-generation-inference
RUN cachix use huggingface
WORKDIR /root
ADD . .
RUN nix build .

View File

@ -1,5 +1,5 @@
# Those arguments are required to build the image
ARG HABANA_VERSION=1.20.0
ARG HABANA_VERSION=1.21.0
ARG PYTORCH_VERSION=2.6.0
# Rust builder
@ -57,9 +57,12 @@ ARG PYTORCH_VERSION
FROM vault.habana.ai/gaudi-docker/${HABANA_VERSION}/ubuntu22.04/habanalabs/pytorch-installer-${PYTORCH_VERSION}:latest AS base
ENV ATTENTION=default
ENV ATTENTION=paged
ENV PREFIX_CACHING=0
ENV PREFILL_CHUNKING=0
ENV PT_HPU_LAZY_MODE=1
ENV PT_HPU_WEIGHT_SHARING=0
ENV VLLM_EXPONENTIAL_BUCKETING=true
# Text Generation Inference base env
ENV HF_HOME=/data \
@ -95,7 +98,9 @@ RUN cd server && \
pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
pip install . --no-cache-dir
RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git
RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git@bmax_fix
RUN pip install compressed-tensors==0.9.1
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router

View File

@ -84,7 +84,7 @@ model=HuggingFaceH4/zephyr-7b-beta
volume=$PWD/data
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.3.0 --model-id $model
ghcr.io/huggingface/text-generation-inference:3.3.2 --model-id $model
```
And then you can make requests like
@ -121,7 +121,7 @@ curl localhost:8080/v1/chat/completions \
**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.0-rocm --model-id $model` instead of the command above.
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.2-rocm --model-id $model` instead of the command above.
To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
```
@ -152,7 +152,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
token=<your cli READ token>
docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.3.0 --model-id $model
ghcr.io/huggingface/text-generation-inference:3.3.2 --model-id $model
```
### A note on Shared Memory (shm)
@ -256,7 +256,7 @@ Another option is to install `text-generation-inference` locally using [Nix](htt
we only support Nix on x86_64 Linux with CUDA GPUs. When using Nix, all dependencies can
be pulled from a binary cache, removing the need to build them locally.
First follow the instructions to [install Cachix and enable the TGI cache](https://app.cachix.org/cache/text-generation-inference).
First follow the instructions to [install Cachix and enable the Hugging Face cache](https://app.cachix.org/cache/huggingface).
Setting up the cache is important, otherwise Nix will build many of the dependencies
locally, which can take hours.

View File

@ -2,7 +2,7 @@ mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
mkfile_dir := $(dir $(mkfile_path))
root_dir := ${mkfile_dir}/../..
HABANA_VERSION := 1.20.0
HABANA_VERSION := 1.21.0
PYTORCH_VERSION := 2.6.0
.PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install
@ -50,6 +50,7 @@ local-dev-install: install-dependencies
# In order to run the integration tests, you need to first build the image (make -C backends/gaudi image)
run-integration-tests:
pip install -U pip uv
uv pip install -r ${root_dir}/backends/gaudi/server/integration-tests/requirements.txt
DOCKER_VOLUME=${root_dir}/data \
HF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \
@ -57,6 +58,7 @@ run-integration-tests:
# This is used to capture the expected outputs for the integration tests offering an easy way to add more models to the integration tests
capture-expected-outputs-for-integration-tests:
pip install -U pip uv
DOCKER_VOLUME=${root_dir}/data \
HF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \
uv run pytest --durations=0 -sv ${root_dir}/backends/gaudi/server/integration-tests/capture_expected_outputs.py

View File

@ -9,8 +9,8 @@ TEST_CONFIGS = {
"meta-llama/Llama-3.1-8B-Instruct-shared": {
"model_id": "meta-llama/Llama-3.1-8B-Instruct",
"input": "What is Deep Learning?",
"expected_greedy_output": " A Beginners Guide\nDeep learning is a subset of machine learning that involves the use",
"expected_batch_output": " A Beginners Guide\nDeep learning is a subset of machine learning that involves the use",
"expected_greedy_output": " A Beginners Guide\nDeep learning is a subset of machine learning that involves the use of artificial neural networks to analyze and interpret data. It is a type of",
"expected_batch_output": " A Beginners Guide\nDeep learning is a subset of machine learning that involves the use of artificial neural networks to analyze and interpret data. It is a type of",
"args": [
"--sharded",
"true",
@ -165,20 +165,6 @@ TEST_CONFIGS = {
"4",
],
},
"facebook/opt-125m": {
"model_id": "facebook/opt-125m",
"input": "What is Deep Learning?",
"expected_greedy_output": "\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout",
"expected_batch_output": "\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout",
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
],
},
"EleutherAI/gpt-j-6b": {
"model_id": "EleutherAI/gpt-j-6b",
"input": "What is Deep Learning?",

View File

@ -1058,199 +1058,6 @@ files = [
{file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"},
]
[[package]]
name = "nvidia-cublas-cu12"
version = "12.4.5.8"
description = "CUBLAS native runtime libraries"
optional = false
python-versions = ">=3"
groups = ["main"]
markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
files = [
{file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0f8aa1706812e00b9f19dfe0cdb3999b092ccb8ca168c0db5b8ea712456fd9b3"},
{file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b"},
{file = "nvidia_cublas_cu12-12.4.5.8-py3-none-win_amd64.whl", hash = "sha256:5a796786da89203a0657eda402bcdcec6180254a8ac22d72213abc42069522dc"},
]
[[package]]
name = "nvidia-cuda-cupti-cu12"
version = "12.4.127"
description = "CUDA profiling tools runtime libs."
optional = false
python-versions = ">=3"
groups = ["main"]
markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
files = [
{file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:79279b35cf6f91da114182a5ce1864997fd52294a87a16179ce275773799458a"},
{file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb"},
{file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:5688d203301ab051449a2b1cb6690fbe90d2b372f411521c86018b950f3d7922"},
]
[[package]]
name = "nvidia-cuda-nvrtc-cu12"
version = "12.4.127"
description = "NVRTC native runtime libraries"
optional = false
python-versions = ">=3"
groups = ["main"]
markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
files = [
{file = "nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0eedf14185e04b76aa05b1fea04133e59f465b6f960c0cbf4e37c3cb6b0ea198"},
{file = "nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a178759ebb095827bd30ef56598ec182b85547f1508941a3d560eb7ea1fbf338"},
{file = "nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:a961b2f1d5f17b14867c619ceb99ef6fcec12e46612711bcec78eb05068a60ec"},
]
[[package]]
name = "nvidia-cuda-runtime-cu12"
version = "12.4.127"
description = "CUDA Runtime native Libraries"
optional = false
python-versions = ">=3"
groups = ["main"]
markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
files = [
{file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:961fe0e2e716a2a1d967aab7caee97512f71767f852f67432d572e36cb3a11f3"},
{file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5"},
{file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:09c2e35f48359752dfa822c09918211844a3d93c100a715d79b59591130c5e1e"},
]
[[package]]
name = "nvidia-cudnn-cu12"
version = "9.1.0.70"
description = "cuDNN runtime libraries"
optional = false
python-versions = ">=3"
groups = ["main"]
markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
files = [
{file = "nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f"},
{file = "nvidia_cudnn_cu12-9.1.0.70-py3-none-win_amd64.whl", hash = "sha256:6278562929433d68365a07a4a1546c237ba2849852c0d4b2262a486e805b977a"},
]
[package.dependencies]
nvidia-cublas-cu12 = "*"
[[package]]
name = "nvidia-cufft-cu12"
version = "11.2.1.3"
description = "CUFFT native runtime libraries"
optional = false
python-versions = ">=3"
groups = ["main"]
markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
files = [
{file = "nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5dad8008fc7f92f5ddfa2101430917ce2ffacd86824914c82e28990ad7f00399"},
{file = "nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9"},
{file = "nvidia_cufft_cu12-11.2.1.3-py3-none-win_amd64.whl", hash = "sha256:d802f4954291101186078ccbe22fc285a902136f974d369540fd4a5333d1440b"},
]
[package.dependencies]
nvidia-nvjitlink-cu12 = "*"
[[package]]
name = "nvidia-curand-cu12"
version = "10.3.5.147"
description = "CURAND native runtime libraries"
optional = false
python-versions = ">=3"
groups = ["main"]
markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
files = [
{file = "nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1f173f09e3e3c76ab084aba0de819c49e56614feae5c12f69883f4ae9bb5fad9"},
{file = "nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a88f583d4e0bb643c49743469964103aa59f7f708d862c3ddb0fc07f851e3b8b"},
{file = "nvidia_curand_cu12-10.3.5.147-py3-none-win_amd64.whl", hash = "sha256:f307cc191f96efe9e8f05a87096abc20d08845a841889ef78cb06924437f6771"},
]
[[package]]
name = "nvidia-cusolver-cu12"
version = "11.6.1.9"
description = "CUDA solver native runtime libraries"
optional = false
python-versions = ">=3"
groups = ["main"]
markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
files = [
{file = "nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d338f155f174f90724bbde3758b7ac375a70ce8e706d70b018dd3375545fc84e"},
{file = "nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260"},
{file = "nvidia_cusolver_cu12-11.6.1.9-py3-none-win_amd64.whl", hash = "sha256:e77314c9d7b694fcebc84f58989f3aa4fb4cb442f12ca1a9bde50f5e8f6d1b9c"},
]
[package.dependencies]
nvidia-cublas-cu12 = "*"
nvidia-cusparse-cu12 = "*"
nvidia-nvjitlink-cu12 = "*"
[[package]]
name = "nvidia-cusparse-cu12"
version = "12.3.1.170"
description = "CUSPARSE native runtime libraries"
optional = false
python-versions = ">=3"
groups = ["main"]
markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
files = [
{file = "nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9d32f62896231ebe0480efd8a7f702e143c98cfaa0e8a76df3386c1ba2b54df3"},
{file = "nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1"},
{file = "nvidia_cusparse_cu12-12.3.1.170-py3-none-win_amd64.whl", hash = "sha256:9bc90fb087bc7b4c15641521f31c0371e9a612fc2ba12c338d3ae032e6b6797f"},
]
[package.dependencies]
nvidia-nvjitlink-cu12 = "*"
[[package]]
name = "nvidia-cusparselt-cu12"
version = "0.6.2"
description = "NVIDIA cuSPARSELt"
optional = false
python-versions = "*"
groups = ["main"]
markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
files = [
{file = "nvidia_cusparselt_cu12-0.6.2-py3-none-manylinux2014_aarch64.whl", hash = "sha256:067a7f6d03ea0d4841c85f0c6f1991c5dda98211f6302cb83a4ab234ee95bef8"},
{file = "nvidia_cusparselt_cu12-0.6.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:df2c24502fd76ebafe7457dbc4716b2fec071aabaed4fb7691a201cde03704d9"},
{file = "nvidia_cusparselt_cu12-0.6.2-py3-none-win_amd64.whl", hash = "sha256:0057c91d230703924c0422feabe4ce768841f9b4b44d28586b6f6d2eb86fbe70"},
]
[[package]]
name = "nvidia-nccl-cu12"
version = "2.21.5"
description = "NVIDIA Collective Communication Library (NCCL) Runtime"
optional = false
python-versions = ">=3"
groups = ["main"]
markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
files = [
{file = "nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0"},
]
[[package]]
name = "nvidia-nvjitlink-cu12"
version = "12.4.127"
description = "Nvidia JIT LTO Library"
optional = false
python-versions = ">=3"
groups = ["main"]
markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
files = [
{file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83"},
{file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"},
{file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1"},
]
[[package]]
name = "nvidia-nvtx-cu12"
version = "12.4.127"
description = "NVIDIA Tools Extension"
optional = false
python-versions = ">=3"
groups = ["main"]
markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
files = [
{file = "nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7959ad635db13edf4fc65c06a6e9f9e55fc2f92596db928d169c0bb031e88ef3"},
{file = "nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a"},
{file = "nvidia_nvtx_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:641dccaaa1139f3ffb0d3164b4b84f9d253397e38246a4f2f36728b48566d485"},
]
[[package]]
name = "opentelemetry-api"
version = "1.32.0"
@ -2650,63 +2457,6 @@ files = [
{file = "tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff"},
]
[[package]]
name = "torch"
version = "2.6.0"
description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
optional = false
python-versions = ">=3.9.0"
groups = ["main"]
files = [
{file = "torch-2.6.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:6860df13d9911ac158f4c44031609700e1eba07916fff62e21e6ffa0a9e01961"},
{file = "torch-2.6.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:c4f103a49830ce4c7561ef4434cc7926e5a5fe4e5eb100c19ab36ea1e2b634ab"},
{file = "torch-2.6.0-cp310-cp310-win_amd64.whl", hash = "sha256:56eeaf2ecac90da5d9e35f7f35eb286da82673ec3c582e310a8d1631a1c02341"},
{file = "torch-2.6.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:09e06f9949e1a0518c5b09fe95295bc9661f219d9ecb6f9893e5123e10696628"},
{file = "torch-2.6.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:7979834102cd5b7a43cc64e87f2f3b14bd0e1458f06e9f88ffa386d07c7446e1"},
{file = "torch-2.6.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:ccbd0320411fe1a3b3fec7b4d3185aa7d0c52adac94480ab024b5c8f74a0bf1d"},
{file = "torch-2.6.0-cp311-cp311-win_amd64.whl", hash = "sha256:46763dcb051180ce1ed23d1891d9b1598e07d051ce4c9d14307029809c4d64f7"},
{file = "torch-2.6.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:94fc63b3b4bedd327af588696559f68c264440e2503cc9e6954019473d74ae21"},
{file = "torch-2.6.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:2bb8987f3bb1ef2675897034402373ddfc8f5ef0e156e2d8cfc47cacafdda4a9"},
{file = "torch-2.6.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:b789069020c5588c70d5c2158ac0aa23fd24a028f34a8b4fcb8fcb4d7efcf5fb"},
{file = "torch-2.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:7e1448426d0ba3620408218b50aa6ada88aeae34f7a239ba5431f6c8774b1239"},
{file = "torch-2.6.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:9a610afe216a85a8b9bc9f8365ed561535c93e804c2a317ef7fabcc5deda0989"},
{file = "torch-2.6.0-cp313-cp313-manylinux1_x86_64.whl", hash = "sha256:4874a73507a300a5d089ceaff616a569e7bb7c613c56f37f63ec3ffac65259cf"},
{file = "torch-2.6.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:a0d5e1b9874c1a6c25556840ab8920569a7a4137afa8a63a32cee0bc7d89bd4b"},
{file = "torch-2.6.0-cp313-cp313-win_amd64.whl", hash = "sha256:510c73251bee9ba02ae1cb6c9d4ee0907b3ce6020e62784e2d7598e0cfa4d6cc"},
{file = "torch-2.6.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:ff96f4038f8af9f7ec4231710ed4549da1bdebad95923953a25045dcf6fd87e2"},
{file = "torch-2.6.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:9ea955317cfcd3852b1402b62af258ce735c2edeee42ca9419b6bc889e5ae053"},
{file = "torch-2.6.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:bb2c6c3e65049f081940f5ab15c9136c7de40d3f01192541c920a07c7c585b7e"},
{file = "torch-2.6.0-cp39-cp39-win_amd64.whl", hash = "sha256:683410f97984103148e31b38a8631acf31c3034c020c0f4d26171e7626d8317a"},
{file = "torch-2.6.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:265f70de5fd45b864d924b64be1797f86e76c8e48a02c2a3a6fc7ec247d2226c"},
]
[package.dependencies]
filelock = "*"
fsspec = "*"
jinja2 = "*"
networkx = "*"
nvidia-cublas-cu12 = {version = "12.4.5.8", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
nvidia-cuda-cupti-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
nvidia-cuda-nvrtc-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
nvidia-cuda-runtime-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
nvidia-cudnn-cu12 = {version = "9.1.0.70", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
nvidia-cufft-cu12 = {version = "11.2.1.3", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
nvidia-curand-cu12 = {version = "10.3.5.147", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
nvidia-cusolver-cu12 = {version = "11.6.1.9", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
nvidia-cusparse-cu12 = {version = "12.3.1.170", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
nvidia-cusparselt-cu12 = {version = "0.6.2", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
nvidia-nccl-cu12 = {version = "2.21.5", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
nvidia-nvjitlink-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
nvidia-nvtx-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
setuptools = {version = "*", markers = "python_version >= \"3.12\""}
sympy = {version = "1.13.1", markers = "python_version >= \"3.9\""}
triton = {version = "3.2.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
typing-extensions = ">=4.10.0"
[package.extras]
opt-einsum = ["opt-einsum (>=3.3)"]
optree = ["optree (>=0.13.0)"]
[[package]]
name = "tqdm"
version = "4.67.1"

View File

@ -22,10 +22,9 @@ opentelemetry-instrumentation-grpc = "^0.53b0"
hf-transfer = "^0.1.9"
sentencepiece = "^0.2.0"
peft = "^0.15"
optimum-habana = "1.17"
transformers = "^4.49"
transformers = "^4.52.4"
numpy = "^1.26"
accelerate = "^0.33"
accelerate = "^1.7.0"
outlines= { version = "^0.0.36", optional = true }
prometheus-client = "^0.21.1"
py-cpuinfo = "^9.0.0"

View File

@ -1,4 +1,4 @@
accelerate==0.33.0 ; python_version >= "3.9" and python_version < "3.13"
accelerate==1.7.0 ; python_version >= "3.9" and python_version < "3.13"
annotated-types==0.7.0 ; python_version >= "3.9" and python_version < "3.13"
attrs==25.3.0 ; python_version >= "3.9" and python_version < "3.13"
certifi==2025.1.31 ; python_version >= "3.9" and python_version < "3.13"
@ -36,19 +36,6 @@ nest-asyncio==1.6.0 ; python_version >= "3.9" and python_version < "3.13"
networkx==3.2.1 ; python_version >= "3.9" and python_version < "3.13"
numba==0.60.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
nvidia-cublas-cu12==12.4.5.8 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cuda-cupti-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cuda-nvrtc-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cuda-runtime-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cudnn-cu12==9.1.0.70 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cufft-cu12==11.2.1.3 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-curand-cu12==10.3.5.147 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cusolver-cu12==11.6.1.9 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cusparse-cu12==12.3.1.170 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cusparselt-cu12==0.6.2 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-nccl-cu12==2.21.5 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-nvjitlink-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-nvtx-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
opentelemetry-api==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-common==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
@ -59,7 +46,6 @@ opentelemetry-instrumentation==0.53b0 ; python_version >= "3.9" and python_versi
opentelemetry-proto==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.53b0 ; python_version >= "3.9" and python_version < "3.13"
optimum-habana==1.17.0 ; python_version >= "3.9" and python_version < "3.13"
optimum==1.24.0 ; python_version >= "3.9" and python_version < "3.13"
outlines==0.0.36 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.2 ; python_version >= "3.9" and python_version < "3.13"
@ -88,9 +74,8 @@ shellingham==1.5.4 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
threadpoolctl==3.6.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.21.1 ; python_version >= "3.9" and python_version < "3.13"
torch==2.6.0 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.67.1 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.49.0 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.52.4 ; python_version >= "3.9" and python_version < "3.13"
triton==3.2.0 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
typer==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.13.2 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -1,6 +1,4 @@
import os
import psutil
import signal
import sys
import typer
@ -19,6 +17,7 @@ class Quantization(str, Enum):
gptq = "gptq"
awq = "awq"
fp8 = "fp8"
compressed_tensors = "compressed-tensors"
class Dtype(str, Enum):
@ -26,6 +25,11 @@ class Dtype(str, Enum):
bloat16 = "bfloat16"
class KVCacheDtype(str, Enum):
fp8_e4m3fn = "fp8_e4m3fn"
fp8_e5m2 = "fp8_e5m2"
@app.command()
def serve(
model_id: str,
@ -34,6 +38,7 @@ def serve(
quantize: Optional[Quantization] = None,
speculate: Optional[int] = None,
dtype: Optional[Dtype] = None,
kv_cache_dtype: Optional[KVCacheDtype] = None,
trust_remote_code: bool = False,
uds_path: Path = "/tmp/text-generation-server",
logger_level: str = "INFO",
@ -93,7 +98,8 @@ def serve(
# Downgrade enum into str for easier management later on
quantize = None if quantize is None else quantize.value
dtype = "bfloat16" if dtype is None else dtype.value
logger.info(f"quantize={quantize}")
kv_cache_dtype = None if kv_cache_dtype is None else kv_cache_dtype.value
logger.info(f"quantize={quantize} kv_cache_dtype={kv_cache_dtype}")
if dtype is not None and quantize not in {
None,
"bitsandbytes",
@ -102,71 +108,11 @@ def serve(
"gptq",
"awq",
"fp8",
"compressed-tensors",
}:
raise RuntimeError(
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
)
logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype))
if sharded and os.getenv("ATTENTION", "default") not in {"paged"}:
tgi_file = Path(__file__).resolve().parent / "tgi_service.py"
num_shard = int(os.getenv("WORLD_SIZE", "1"))
logger.info("CLI SHARDED = {}".format(num_shard))
import subprocess
cmd = (
f"deepspeed --num_nodes 1 --num_gpus {num_shard} --no_local_rank {tgi_file}"
)
cmd += f" --model_id {model_id} --revision {revision} --sharded {sharded}"
cmd += f" --dtype {dtype} --trust_remote_code {trust_remote_code} --uds_path {uds_path}"
cmd += f" --quantize {quantize} --max_input_tokens {max_input_tokens}"
if speculate is not None:
cmd += f"--speculate {speculate}"
logger.info("CLI server start deepspeed ={} ".format(cmd))
sys.stdout.flush()
sys.stderr.flush()
with subprocess.Popen(cmd, shell=True, executable="/bin/bash") as proc:
do_terminate = False
current_handler = signal.getsignal(signal.SIGTERM)
def terminate_handler(sig, frame):
nonlocal do_terminate
do_terminate = True
if callable(current_handler):
current_handler(sig, frame)
signal.signal(signal.SIGTERM, terminate_handler)
finished = False
while not finished:
try:
if do_terminate:
parent = psutil.Process(proc.pid)
all_procs = parent.children(recursive=True) + [parent]
for p in all_procs:
try:
p.terminate()
except psutil.NoSuchProcess:
pass
_, alive = psutil.wait_procs(all_procs, timeout=30)
for p in alive:
p.kill()
do_terminate = False
proc.wait(timeout=3)
except subprocess.TimeoutExpired:
pass
else:
finished = True
sys.stdout.flush()
sys.stderr.flush()
if proc.returncode != 0:
logger.error(f"{cmd} exited with status = {proc.returncode}")
return proc.returncode
else:
server.serve(
model_id,
lora_adapters,
@ -175,6 +121,7 @@ def serve(
quantize,
speculate,
dtype,
kv_cache_dtype,
trust_remote_code,
uds_path,
max_input_tokens,

View File

@ -1,53 +0,0 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
import os
import habana_frameworks.torch as htorch
quant_config = os.getenv("QUANT_CONFIG", "")
is_quantization_enabled = quant_config != ""
if is_quantization_enabled:
os.environ.setdefault("ENABLE_EXPERIMENTAL_FLAGS", "true")
os.environ.setdefault("USE_DEFAULT_QUANT_PARAM", "true")
os.environ.setdefault("UPDATE_GRAPH_OUTPUT_MME", "false")
os.environ.setdefault("ENABLE_CALC_DYNAMIC_RANGE", "false")
os.environ.setdefault("UPDATE_MME_OUTPUT_PRECISION_FILTER", "v_proj,matmul_av")
os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE")
def patch_scoped_linear_all_reduce(model):
from deepspeed.module_inject.layers import LinearAllreduce
from optimum.habana.transformers.models.modeling_all_models import (
ScopedLinearAllReduce,
)
for name, module in model.named_children():
if type(module) is LinearAllreduce:
SL = ScopedLinearAllReduce(mod=module)
setattr(model, name, SL)
patch_scoped_linear_all_reduce(module)
def setup_quantization(model):
if is_quantization_enabled:
htorch.core.quantization._mark_params_as_const(model)
htorch.core.quantization._check_params_as_const(model)
htorch.core.hpu_initialize(model)
return model
def prepare_model_for_quantization(model):
if is_quantization_enabled:
if model.config.model_type in [
"llama",
"falcon",
"qwen2",
"starcoder2",
"gemma",
]:
patch_scoped_linear_all_reduce(model)
from neural_compressor.torch.quantization import FP8Config, convert
config = FP8Config.from_json_file(quant_config)
model = convert(model, config)
return model

View File

@ -12,6 +12,7 @@ from text_generation_server.layers.speculative import SpeculativeHead
# Just to add the `load` methods.
from text_generation_server.layers.layernorm import load_layer_norm
from text_generation_server.layers.conv import load_conv2d
from text_generation_server.layers.fp8 import Fp8Linear
from text_generation_server.layers.lora import (
LoraLinear,
@ -27,6 +28,7 @@ __all__ = [
"TensorParallelEmbedding",
"SpeculativeHead",
"LoraLinear",
"Fp8Linear",
"TensorParallelMultiAdapterLinear",
"TensorParallelAdapterRowLinear",
"load_layer_norm",

View File

@ -10,18 +10,23 @@ from .hpu import (
SUPPORTS_WINDOWING,
attention,
paged_attention,
paged_attention_mla,
set_block_mapping,
)
# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
from .kv_cache import KVCache, get_kv_scales
from .kv_cache import KVCache, get_kv_scales, KVCompressCache
__all__ = [
"attention",
"get_kv_scales",
"paged_attention",
"paged_attention_mla",
"set_block_mapping",
"SUPPORTS_WINDOWING",
"KVCache",
"KVCompressCache",
"Seqlen",
"HPUPagedAttentionMetadata",
"trim_seqlen_metadata",

View File

@ -90,6 +90,8 @@ class Seqlen:
def _async_h2d_tensor_copy(source, device="hpu"):
if source is None:
return None
if source.device.type == "hpu":
return source
assert source.device.type == "cpu", "Source tensor is not present in host memory!"
target = torch.empty(source.shape, dtype=source.dtype, device=device)
target.copy_(source, non_blocking=True)

View File

@ -7,15 +7,67 @@ from vllm_hpu_extension.utils import Matmul
from habana_frameworks.torch.hpex.kernels import FusedSDPA
from vllm_hpu_extension.utils import ModuleFusedSDPA
import os
from text_generation_server.models.globals import BLOCK_SIZE
import math
SUPPORTS_WINDOWING = False
def fetch_from_cache(cache, blocks):
class FP8Matmul(torch.nn.Module):
def __init__(self, scale_other):
super().__init__()
self.scale_input = torch.tensor(1.0, dtype=torch.bfloat16, device="hpu")
self.scale_other = scale_other
def quant_input(self, x, scale):
return torch.ops.hpu.cast_to_fp8_v2(
x, scale, False, False, torch.float8_e4m3fn
)[0]
def matmul_fp8(
self, x, other, out_dtype, scale_input_inv=None, scale_other_inv=None
):
return torch.ops.hpu.fp8_gemm_v2(
A=x,
trans_A=False,
B=other,
trans_B=False,
D=None,
out_dtype=out_dtype,
A_scale_inv=scale_input_inv,
B_scale_inv=scale_other_inv,
bias=None,
accumulate=False,
)
def forward(self, input, other):
qinput = self.quant_input(input, self.scale_input)
qother = self.quant_input(other, self.scale_other)
output = self.matmul_fp8(
qinput,
qother,
out_dtype=torch.bfloat16,
scale_input_inv=1.0 / self.scale_input,
scale_other_inv=1.0 / self.scale_other,
)
return output
class FetchFromCache(torch.nn.Module):
def __init__(self, scale_inv):
super().__init__()
self.scale_inv = scale_inv
def forward(self, cache, blocks):
if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
return cache[: blocks.size(0)]
out = cache[: blocks.size(0)]
else:
return cache.index_select(0, blocks)
out = cache.index_select(0, blocks)
if out.dtype == torch.float8_e4m3fn:
out = torch.ops.hpu.cast_from_fp8(out, self.scale_inv, torch.bfloat16)
return out
def attention(
@ -55,6 +107,21 @@ def attention(
return attn_output
def set_block_mapping(hpu_attention_meta: HPUPagedAttentionMetadata, batch_size):
block_mapping = torch.nn.functional.one_hot(
hpu_attention_meta.block_groups, num_classes=batch_size
)
dtype = hpu_attention_meta.block_usage.dtype
device = hpu_attention_meta.block_usage.device
mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
mask = mask >= hpu_attention_meta.block_usage.unsqueeze(-1)
attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
hpu_attention_meta = hpu_attention_meta._replace(
attn_bias=attn_bias, block_mapping=block_mapping.to(dtype)
)
return hpu_attention_meta
def paged_attention(
query: torch.Tensor,
kv_cache: KVCache,
@ -67,6 +134,7 @@ def paged_attention(
hpu_attention_meta: HPUPagedAttentionMetadata,
):
batch_size, head_num, head_size = query.shape
fp8_kv = kv_cache.dtype == torch.float8_e4m3fn
output = ops.flat_pa(
query=query.view(batch_size, 1, head_num * head_size),
key_cache=kv_cache.key,
@ -75,20 +143,59 @@ def paged_attention(
block_mapping=hpu_attention_meta.block_mapping,
block_bias=hpu_attention_meta.attn_bias,
block_groups=hpu_attention_meta.block_groups,
block_size=BLOCK_SIZE,
scale=softmax_scale,
matmul_qk_op=Matmul(),
matmul_av_op=Matmul(),
matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
batch2block_matmul_op=Matmul(),
block2batch_matmul_op=Matmul(),
keys_fetch_func=fetch_from_cache,
values_fetch_func=fetch_from_cache,
keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu),
values_fetch_func=FetchFromCache(1.0 / kv_scales.value_scale_cpu),
)
# Reshape the output tensor.
return output.view(batch_size, head_num, head_size)
def paged_attention_mla(
query: torch.Tensor,
kv_cache: KVCache,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
seqlen: Seqlen,
*,
kv_scales: KVScales,
softcap: Optional[float] = None,
hpu_attention_meta: HPUPagedAttentionMetadata,
kv_lora_rank: int = 0,
):
batch_size, head_num, head_size = query.shape
fp8_kv = kv_cache.dtype == torch.float8_e4m3fn
output = ops.flat_pa_mla(
query=query,
key_cache=kv_cache.key,
value_cache=None,
block_list=hpu_attention_meta.block_list,
block_mapping=hpu_attention_meta.block_mapping,
block_bias=hpu_attention_meta.attn_bias,
block_groups=hpu_attention_meta.block_groups,
block_size=BLOCK_SIZE,
scale=softmax_scale,
matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
batch2block_matmul_op=Matmul(),
block2batch_matmul_op=Matmul(),
keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu),
values_fetch_func=None,
kv_lora_rank=kv_lora_rank,
)
# Reshape the output tensor.
return output.view(batch_size, head_num, -1)
__all__ = [
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",
"paged_attention_mla",
"set_block_mapping",
]

View File

@ -5,7 +5,6 @@ import torch
from text_generation_server.models.globals import BLOCK_SIZE
from text_generation_server.utils.weights import Weights
from vllm_hpu_extension import cache_ops
@dataclass
@ -50,15 +49,17 @@ class KVCache:
):
"""Construct the key-value cache for a layer."""
## TODO FP8 kv cache support
if dtype is torch.float8_e5m2:
raise ValueError("torch.float8_e5m2 is not supported in hpu. ")
self.kv_cache = (
torch.zeros(
(num_blocks, BLOCK_SIZE, num_heads, head_size),
(num_blocks * BLOCK_SIZE, num_heads, head_size),
dtype=dtype,
device=device,
),
torch.zeros(
(num_blocks, BLOCK_SIZE, num_heads, head_size),
(num_blocks * BLOCK_SIZE, num_heads, head_size),
dtype=dtype,
device=device,
),
@ -101,24 +102,89 @@ class KVCache:
key_cache,
value_cache,
slots,
kv_scales.key_scale_cpu,
kv_scales.value_scale_cpu,
kv_scales.key_scale,
kv_scales.value_scale,
)
class KVCompressCache(KVCache):
"""
Key-value cache for attention layers.
"""
kv_cache: torch.Tensor
def __init__(
self,
*,
num_blocks: int,
head_size: int,
dtype: torch.dtype,
device: torch.device,
):
"""Construct the key-value cache for a layer."""
## TODO FP8 kv cache support
if dtype is torch.float8_e5m2:
raise ValueError("torch.float8_e5m2 is not supported in hpu. ")
self.kv_cache = torch.zeros(
(num_blocks * BLOCK_SIZE, 1, head_size),
dtype=dtype,
device=device,
)
@property
def dtype(self):
"""Get the data type of the cache."""
return self.kv_cache.dtype
@property
def key(self):
"""Get the key cache."""
return self.kv_cache
@property
def value(self):
"""Get the value cache."""
return self.kv_cache
def store(
self,
*,
key: torch.Tensor,
value: torch.Tensor,
slots: torch.Tensor,
kv_scales: KVScales,
):
"""Store the key and value at the given slots."""
## TODO FP8 kv cache support
if self.kv_cache.dtype == torch.float8_e4m3fn:
key = torch.ops.hpu.cast_to_fp8_v2(
key, kv_scales.key_scale, False, False, torch.float8_e4m3fn
)[0]
self.kv_cache.index_copy_(0, slots, key)
def paged_reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
k_scale: float = 1.0,
v_scale: float = 1.0,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
):
block_idx = slots // BLOCK_SIZE
block_offset = slots % BLOCK_SIZE
cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset)
cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset)
if key_cache.dtype == torch.float8_e4m3fn:
key = torch.ops.hpu.cast_to_fp8_v2(
key, k_scale, False, False, torch.float8_e4m3fn
)[0]
value = torch.ops.hpu.cast_to_fp8_v2(
value, v_scale, False, False, torch.float8_e4m3fn
)[0]
key_cache.index_copy_(0, slots, key)
value_cache.index_copy_(0, slots, value)
def get_kv_scales(weights: Weights, prefix: str) -> KVScales:

View File

@ -0,0 +1,3 @@
from .loader import CompressedTensorsLoader
__all__ = ["CompressedTensorsLoader"]

View File

@ -0,0 +1,169 @@
from typing import Any, Dict, List, Union
from compressed_tensors import QuantizationConfig, QuantizationStatus
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization import (
QuantizationScheme,
QuantizationType,
find_name_or_class_matches,
)
from loguru import logger
from pydantic import ValidationError
from torch import nn
from text_generation_server.layers.compressed_tensors.w8an_fp import W8ANFpLoader
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
UnquantizedWeight,
Weights,
WeightsLoader,
)
# compressed-tensors can match modules as quantization targets. However,
# they need to be objects rather than classes or class names. Since we
# need to match `Linear` targets, make an instance that can be re-used.
_EMPTY_LINEAR: nn.Module = nn.Linear(0, 0)
class CompressedTensorsLoader(WeightsLoader):
"""Loader for checkpoints stored in the compressed-tensors format."""
def __init__(self, config: Dict[str, Any]):
quantization_config_raw = config.get("quantization_config")
if quantization_config_raw is None:
# `compression_config` was renamed to `quantization_config`; support
# retained for backward compatibility.
quantization_config_raw = config.get("compression_config")
if quantization_config_raw is None:
raise ValueError(
"Checkpoint does not have compressed-tensors configuration"
)
try:
quantization_config = QuantizationConfig.model_validate(
quantization_config_raw
)
except ValidationError as e:
raise ValueError("Cannot parse compressed-tensors configuration") from e
if quantization_config.quantization_status not in (
QuantizationStatus.COMPRESSED,
QuantizationStatus.FROZEN,
):
raise ValueError(
f"Model quantization was not finished, status was: {quantization_config.quantization_status}"
)
self.ignore = (
quantization_config.ignore if quantization_config.ignore is not None else []
)
self.loaders = self._get_target_loaders(quantization_config)
for target, loader in self.loaders.items():
log_once(
logger.info,
f"Using {loader} for compressed-tensors target '{target}'",
)
def get_weights(self, weights: Weights, prefix: str):
loader = self._lookup_loader(prefix)
return loader.get_weights(weights, prefix)
def get_weights_col_packed(
self,
weights: "Weights",
prefix: str,
block_sizes: Union[int, List[int]],
):
loader = self._lookup_loader(prefix)
return loader.get_weights_col_packed(weights, prefix, block_sizes)
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
loader = self._lookup_loader(prefixes[0])
return loader.get_multi_weights_col(weights, prefixes, dim)
def get_multi_weights(self, weights: Weights, prefixes: List[str], dim: int):
loader = self._lookup_loader(prefixes[0])
return loader.get_multi_weights(weights, prefixes, dim)
def get_weights_row(self, weights: Weights, prefix: str):
loader = self._lookup_loader(prefix)
return loader.get_weights_row(weights, prefix)
def _get_target_loaders(
self, quantization_config: QuantizationConfig
) -> Dict[str, WeightsLoader]:
"""
A compressed-tensors checkpoint can use different quantizations
for different targets. This method returns a dictionary with a
loader per target.
"""
loaders: Dict[str, WeightsLoader] = {}
format = quantization_config.format
for group_name, group in quantization_config.config_groups.items():
# The group configuration can be a string, but does that ever
# happen in a serialized quantization config?
assert isinstance(group, QuantizationScheme)
loader = self._create_loader_for_group(format, group_name, group)
# A quantized parameter group can have multiple targets, add the
# loader for all the targets.
for target in group.targets:
if target in loaders:
raise ValueError(
f"Target '{target} has multiple configured loaders'"
)
loaders[target] = loader
return loaders
def _create_loader_for_group(
self, format: str, group_name: str, group: QuantizationScheme
) -> WeightsLoader:
"""
Find and create a loader for the group with the given quantization
scheme.
"""
# NOTE: we ignore group.output_activations because we don't support
# output quantization yet.
input_activations = group.input_activations
weights = group.weights
if (
format
in {
CompressionFormat.float_quantized.value,
CompressionFormat.naive_quantized.value,
}
and weights is not None
and weights.type == QuantizationType.FLOAT
and weights.num_bits == 8
):
# FP W8A8 or W8A16.
return W8ANFpLoader(input_activations=input_activations, weights=weights)
else:
raise ValueError(
f"Group '{group_name}' has unsupported compressed-tensors configurtion"
)
def _lookup_loader(self, prefix: str) -> WeightsLoader:
"""
Look up the loader to use for a given parameter name (prefix).
"""
if len(find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.ignore)) > 0:
return DefaultWeightsLoader(UnquantizedWeight)
# We currently only handle linear layers, so unconditionally pass
# a `Linear` instance.
targets = find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.loaders.keys())
if len(targets) == 0:
raise ValueError(
f"Cannot find compressed-tensors target for prefix: {prefix}"
)
return self.loaders[targets[0]]

View File

@ -0,0 +1,253 @@
from typing import List, Optional, Union
import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationType
from text_generation_server.layers.fp8 import (
Fp8Weight,
_load_scalar_or_matrix_scale,
requantize_with_max_scale,
)
from text_generation_server.utils.weights import Weights, WeightsLoader
class W8ANFpLoader(WeightsLoader):
"""
Loader for W8A8/W8A16 FP compressed-tensors parameters.
"""
def __init__(
self,
*,
input_activations: Optional[QuantizationArgs],
weights: QuantizationArgs,
):
assert weights.type == QuantizationType.FLOAT and weights.num_bits == 8
# We ignore the `strategy` option which sets the scales to be
# per-tensor, per-channel or per-token. What scales are supported
# is dependent on the kernels used (e.g. cutlass can do tokenwise,
# Torch cannot, and FP8-Marlin does not quantize inputs at all).
# So, instead we try to use the best-possible configuration.
self.load_weight_scale = not weights.dynamic
self.load_input_scale = (
input_activations is not None and not input_activations.dynamic
)
self.force_w8a16 = (
input_activations is not None and input_activations.num_bits == 16
)
def __str__(self) -> str:
def scale_to_str(scale):
return "static" if scale else "dynamic"
quantization_type = f"W8A{16 if self.force_w8a16 else 8}"
return f"{self.__class__.__name__} ({quantization_type}, weight: {scale_to_str(self.load_weight_scale)}, input: {scale_to_str(self.load_input_scale)})"
def get_weights(self, weights: "Weights", prefix: str):
w = weights.get_tensor(f"{prefix}.weight")
weight_scale = None
if self.load_weight_scale:
weight_scale = (
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
.reshape(-1)
.expand(w.shape[0])
)
logical_widths = [w.shape[0]]
w, weight_scale = requantize_with_max_scale(
w,
weight_scale.unsqueeze(-1).to(weights.device),
logical_widths,
weights.dtype,
)
input_scale = None
if self.load_input_scale:
input_scale = weights.get_tensor(
f"{prefix}.input_scale", to_dtype=False
).reshape(-1)
return Fp8Weight(
weight=w,
weight_scale=weight_scale,
input_scale=input_scale,
dtype=weights.dtype,
force_w8a16=self.force_w8a16,
)
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
w = weights.get_packed_sharded(
f"{prefix}.weight", dim=0, block_sizes=block_sizes
)
weight_scale = None
if self.load_weight_scale:
weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
if weight_scale.numel() > 1:
weight_scale = weights.get_packed_sharded(
f"{prefix}.weight_scale",
dim=0,
block_sizes=block_sizes,
to_dtype=False,
)
weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
logical_widths = [w.shape[0]]
w, weight_scale = requantize_with_max_scale(
w,
weight_scale.unsqueeze(-1).to(weights.device),
logical_widths,
weights.dtype,
)
input_scale = None
if self.load_input_scale:
input_scale = weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
if input_scale.numel() > 1:
input_scale = weights.get_packed_sharded(
f"{prefix}.input_scale",
dim=0,
block_sizes=block_sizes,
to_dtype=False,
)
input_scale = input_scale.reshape(-1).max()
return Fp8Weight(
weight=w,
weight_scale=weight_scale,
input_scale=input_scale,
dtype=weights.dtype,
force_w8a16=self.force_w8a16,
)
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
# FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
w = [
weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
]
shapes = [x.shape for x in w]
# Concat then send to the device
w = torch.cat(w, dim=dim).to(weights.device)
weight_scale = None
if self.load_weight_scale:
weight_scale = [
_load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
for p, shape in zip(prefixes, shapes)
]
weight_scale = torch.cat(weight_scale, dim=0).reshape(-1)
logical_widths = [x[0] for x in shapes]
w, weight_scale = requantize_with_max_scale(
w,
weight_scale.unsqueeze(-1).to(weights.device),
logical_widths,
weights.dtype,
)
input_scale = None
if self.load_input_scale:
input_scale = [
_load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
for p, shape in zip(prefixes, shapes)
if weights.has_tensor(f"{p}.input_scale")
]
assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
input_scale = (
torch.cat(input_scale, dim=0).reshape(-1).max()
if len(input_scale) != 0
else None
)
return Fp8Weight(
weight=w,
weight_scale=weight_scale,
input_scale=input_scale,
dtype=weights.dtype,
force_w8a16=self.force_w8a16,
)
def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
# FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
w = [weights.get_tensor(f"{p}.weight", to_device=False) for p in prefixes]
shapes = [x.shape for x in w]
# Concat then send to the device
w = torch.cat(w, dim=dim).to(weights.device)
weight_scale = None
if self.load_weight_scale:
weight_scale = [
weights.get_tensor(f"{p}.weight_scale", to_dtype=False)
.reshape(-1)
.expand(shape[0])
for p, shape in zip(prefixes, shapes)
]
weight_scale = torch.cat(weight_scale, dim=0).reshape(-1)
logical_widths = [x[0] for x in shapes]
w, weight_scale = requantize_with_max_scale(
w,
weight_scale.unsqueeze(-1).to(weights.device),
logical_widths,
weights.dtype,
)
input_scale = None
if self.load_input_scale:
input_scale = [
weights.get_tensor(f"{p}.input_scale", to_dtype=False)
.reshape(-1)
.expand(shape[0])
for p, shape in zip(prefixes, shapes)
if weights.has_tensor(f"{p}.input_scale")
]
assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
input_scale = (
torch.cat(input_scale, dim=0).reshape(-1).max()
if len(input_scale) != 0
else None
)
return Fp8Weight(
weight=w,
weight_scale=weight_scale,
input_scale=input_scale,
dtype=weights.dtype,
force_w8a16=self.force_w8a16,
)
def get_weights_row(self, weights: "Weights", prefix: str):
w = weights.get_sharded(f"{prefix}.weight", dim=1)
weight_scale = None
if self.load_weight_scale:
weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
logical_widths = [w.shape[0]]
w, weight_scale = requantize_with_max_scale(
w,
weight_scale.unsqueeze(-1).to(weights.device),
logical_widths,
weights.dtype,
)
input_scale = None
if self.load_input_scale:
input_scale = weights.get_tensor(
f"{prefix}.input_scale", to_dtype=False
).reshape(-1)
return Fp8Weight(
weight=w,
weight_scale=weight_scale,
input_scale=input_scale,
dtype=weights.dtype,
force_w8a16=self.force_w8a16,
)

View File

@ -12,11 +12,151 @@ from text_generation_server.utils.weights import (
from vllm_hpu_extension.ops import scaled_fp8_quant
from vllm_hpu_extension.scales import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2
import habana_frameworks.torch.utils.experimental as htexp
w8a8_block_fp8_matmul = None
per_token_group_quant_fp8 = None
quant_dtype: torch.dtype = torch.float8_e4m3fn
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
if is_hpu_gaudi2():
FP8_MAX = torch.finfo(torch.float8_e4m3fnuz).max
def pad_weight(weight, block_size):
"""Pads a matrix to make its dimensions multiples of block_size."""
M, N = weight.shape[-2:]
block_size_m, block_size_n = block_size
pad_M = (block_size_m - M % block_size_m) % block_size_m
pad_N = (block_size_n - N % block_size_n) % block_size_n
if pad_M == 0 and pad_N == 0:
return weight, M, N # No padding needed
padded_weight = torch.nn.functional.pad(
weight, (0, pad_N, 0, pad_M), mode="constant", value=0
)
return padded_weight, M, N # Return original dimensions for unpadding
def unpad_weight(weight, original_M, original_N, keep_first_dim=False):
"""Removes padding from the matrix to restore its original shape."""
if (weight.shape[-2] == original_M) and (weight.shape[-1] == original_N):
return weight
if keep_first_dim:
return weight[:, :original_M, :original_N]
else:
return weight[:original_M, :original_N]
def pad_block_fp8_weight_naive(weight, weight_scale, block_size):
assert len(block_size) == 2
block_size_m, block_size_n = block_size
weight_scale_m, weight_scale_n = weight_scale.shape[-2:]
weight, orig_M, orig_N = pad_weight(weight, block_size)
M, N = weight.shape[-2:]
assert weight_scale_m == M // block_size_m
assert weight_scale_n == N // block_size_n
return weight, orig_M, orig_N
def dynamic_quant(data, single_scale=False):
if single_scale:
scale = ((torch.abs(data)).max() + 1e-8) / FP8_MAX
else:
scale = ((torch.abs(data)).max(dim=-1).values + 1e-8) / FP8_MAX
scale = scale.unsqueeze(-1)
data_fp8 = torch.ops.hpu.cast_to_fp8_v2(
data, 1.0 / scale, False, False, torch.float8_e4m3fn
)[0]
return data_fp8, scale.float()
def dequant_block_fp8_weight_naive(
weight,
weight_scale,
block_size,
dtype=torch.bfloat16,
original_M=None,
original_N=None,
do_unpad=False,
):
if weight_scale is None:
return weight
assert len(block_size) == 2
weight_shape_len = len(weight.shape)
block_size_m, block_size_n = block_size
# mul scale
if weight_shape_len == 2:
weight_scale_m, weight_scale_n = weight_scale.shape
weight_scale = weight_scale.view(weight_scale_m, 1, weight_scale_n, 1)
weight = weight.view(weight_scale_m, block_size_m, weight_scale_n, block_size_n)
if is_hpu_gaudi2():
fake_weight = weight.cpu().to(dtype).to(weight.device)
dequant_weight = fake_weight * weight_scale.to(dtype)
else:
dequant_weight = weight.to(dtype) * weight_scale.to(dtype)
dequant_weight = dequant_weight.view(
weight_scale_m * block_size_m, weight_scale_n * block_size_n
)
keep_first_dim = False
elif weight_shape_len == 3:
fd, weight_scale_m, weight_scale_n = weight_scale.shape
weight_scale = weight_scale.view(fd, weight_scale_m, 1, weight_scale_n, 1)
weight = weight.view(
fd, weight_scale_m, block_size_m, weight_scale_n, block_size_n
)
if is_hpu_gaudi2():
fake_weight = weight.cpu().to(dtype).to(weight.device)
dequant_weight = fake_weight * weight_scale.to(dtype)
else:
dequant_weight = weight.to(dtype) * weight_scale.to(dtype)
dequant_weight = dequant_weight.view(
fd, weight_scale_m * block_size_m, weight_scale_n * block_size_n
)
keep_first_dim = True
else:
raise ValueError("Only support original weight shape is either 2 or 3")
if do_unpad:
dequant_weight = unpad_weight(
dequant_weight, original_M, original_N, keep_first_dim=keep_first_dim
)
return dequant_weight
def apply_block_fp8_linear_hpu_dynamic(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
x_fp8, x_scale = dynamic_quant(input_2d)
output = torch.ops.hpu.fp8_gemm_v2(
x_fp8,
False,
weight,
True,
None,
torch.bfloat16,
x_scale,
weight_scale,
None,
False,
)
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
@ -42,7 +182,7 @@ def per_tensor_dequantize(
) -> torch.Tensor:
device = tensor.device
dtype = torch.bfloat16
if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2:
if is_hpu_gaudi2():
# dequant on cpu to avoid nan on gaudi2
tensor = tensor.to("cpu")
@ -67,7 +207,7 @@ def requantize_with_max_scale(
for idx, logical_width in enumerate(logical_widths):
end = start + logical_width
weight_dq = per_tensor_dequantize(
weight[start:end, :], weight_scale[idx], dtype
weight[start:end, :], weight_scale[start:end, :], dtype
)
weight[start:end, :], max_w_scale_normalized = fp8_quantize(
weight_dq, max_w_scale
@ -130,6 +270,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
)
# FP8 branch
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
scale = scale.reshape(-1).expand(w.shape[0])
logical_widths = [w.shape[0]]
w, scale = requantize_with_max_scale(
w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
)
input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"):
@ -138,10 +283,6 @@ class HybridFP8UnquantLoader(WeightsLoader):
.reshape(-1)
.max()
)
logical_widths = [w.shape[0]]
w, scale = requantize_with_max_scale(
w, scale.unsqueeze(0), logical_widths, weights.dtype
)
return Fp8Weight(
weight=w,
@ -176,6 +317,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
block_sizes=block_sizes,
to_dtype=False,
)
scale = scale.reshape(-1).expand(w.shape[0])
logical_widths = [w.shape[0]]
w, scale = requantize_with_max_scale(
w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
)
input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"):
@ -190,10 +336,6 @@ class HybridFP8UnquantLoader(WeightsLoader):
to_dtype=False,
)
input_scale = input_scale.reshape(-1).max()
logical_widths = [w.shape[0]]
w, scale = requantize_with_max_scale(
w, scale.unsqueeze(0), logical_widths, weights.dtype
)
return Fp8Weight(
weight=w,
@ -240,6 +382,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
]
scale = torch.cat(scale, dim=0).reshape(-1)
logical_widths = [x[0] for x in shapes]
w, scale = requantize_with_max_scale(
w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
)
input_scale = [
_load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
for p, shape in zip(prefixes, shapes)
@ -252,9 +399,66 @@ class HybridFP8UnquantLoader(WeightsLoader):
else None
)
return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
if self.to_fp8:
return Fp8Weight(weight=w, dtype=weights.dtype)
return UnquantizedWeight(w)
def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
# FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
w = [weights.get_tensor(f"{p}.weight", to_device=False) for p in prefixes]
shapes = [x.shape for x in w]
# Concat then send to the device
w = torch.cat(w, dim=dim).to(weights.device)
# FP8 branch
if w.dtype == torch.float8_e4m3fn:
if self.weight_block_size is not None:
scale = [
weights.get_tensor(f"{p}.weight_scale_inv", to_device=False)
for p in prefixes
]
scale = torch.cat(scale, dim=dim)
scale = scale.to(weights.device)
return Fp8Weight(
weight=w,
weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
weight_block_size=self.weight_block_size,
)
scale = [
weights.get_tensor(f"{p}.weight_scale", to_dtype=False)
.reshape(-1)
.expand(shape[0])
for p, shape in zip(prefixes, shapes)
]
scale = torch.cat(scale, dim=0).reshape(-1)
logical_widths = [x[0] for x in shapes]
w, scale = requantize_with_max_scale(
w, scale.to(weights.device), logical_widths, weights.dtype
w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
)
input_scale = [
weights.get_tensor(f"{p}.input_scale", to_dtype=False).reshape(-1)
for p in prefixes
if weights.has_tensor(f"{p}.input_scale")
]
assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
input_scale = (
torch.cat(input_scale, dim=0).reshape(-1).max()
if len(input_scale) != 0
else None
)
return Fp8Weight(
@ -285,7 +489,15 @@ class HybridFP8UnquantLoader(WeightsLoader):
weight_block_size=self.weight_block_size,
)
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
scale = (
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
.reshape(-1)
.expand(w.shape[0])
)
logical_widths = [w.shape[0]]
w, scale = requantize_with_max_scale(
w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
)
input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"):
@ -294,10 +506,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
.reshape(-1)
.max()
)
logical_widths = [w.shape[0]]
w, scale = requantize_with_max_scale(
w, scale.unsqueeze(0), logical_widths, weights.dtype
)
return Fp8Weight(
weight=w,
weight_scale=scale,
@ -389,6 +598,22 @@ class Fp8Linear(torch.nn.Module):
scale_upper_bound = kwargs.get("scale_upper_bound", None)
weight_block_size = kwargs.get("weight_block_size", None)
if weight_block_size is not None:
weight, orig_M, orig_N = pad_block_fp8_weight_naive(
weight, scale, weight_block_size
)
weight, scale = dynamic_quant(
dequant_block_fp8_weight_naive(
weight,
scale,
weight_block_size,
original_M=orig_M,
original_N=orig_N,
do_unpad=True,
)
)
scale = scale.squeeze(-1)
return cls(
qweight=weight,
scale=scale,
@ -399,60 +624,32 @@ class Fp8Linear(torch.nn.Module):
weight_block_size=weight_block_size,
)
@classmethod
def get_shared_device_identity(cls, device):
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
if device not in cls._device_identity_cache:
cls._device_identity_cache[device] = torch.ones(1, device=device)
return cls._device_identity_cache[device]
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.weight_block_size is not None:
# https://arxiv.org/pdf/2412.19437
# At a more granular level. As illustrated in Figure 7 (a), (1) for activations, we group and
# scale elements on a 1x128 tile basis (i.e., per token per 128 channels); and (2) for weights, we
# group and scale elements on a 128x128 block basis (i.e., per 128 input channels per 128 output
# channels).
qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1])
output = w8a8_block_fp8_matmul(
qinput,
self.qweight,
scale,
self.scale,
self.weight_block_size,
output_dtype=input.dtype,
if self.weight_block_size is not None or self.input_scale is None:
return apply_block_fp8_linear_hpu_dynamic(
input, self.qweight, self.scale, self.input_scale, self.bias
)
if self.bias is not None:
output = output + self.bias
return output.to(dtype=input.dtype)
qinput, scale = fp8_quantize(
input,
self.input_scale,
scale_upper_bound=self.scale_upper_bound,
scalar=True,
)
output = torch._scaled_mm(
qinput,
self.qweight.t(),
out_dtype=self.dtype,
scale_a=scale,
scale_b=self.scale,
x_fp8 = torch.ops.hpu.cast_to_fp8_v2(
input, 1.0 / self.input_scale, False, False, torch.float8_e4m3fn
)[0]
return torch.ops.hpu.fp8_gemm_v2(
A=x_fp8,
trans_A=False,
B=self.qweight,
trans_B=True,
D=None,
out_dtype=input.dtype,
A_scale_inv=self.input_scale,
B_scale_inv=self.scale,
bias=self.bias,
accumulate=False,
)
if isinstance(output, tuple) and len(output) == 2:
output = output[0]
return output
def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
scale = weights.get_tensor(prefix, to_dtype=False)
if scale.numel() > 1:
scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
return scale.reshape(-1)
return scale.reshape(-1).expand(shape[0])

View File

@ -4,7 +4,12 @@ from typing import List, Optional, Union
import torch
from loguru import logger
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
from text_generation_server.utils.weights import (
Weight,
Weights,
WeightsLoader,
DefaultWeightsLoader,
)
from .hpu import QuantLinear
@ -72,6 +77,7 @@ class GPTQWeightsLoader(WeightsLoader):
quant_method: str,
quantize: str,
sym: bool,
modules_to_not_convert: List[str],
):
self.bits = bits
self.desc_act = desc_act
@ -79,6 +85,12 @@ class GPTQWeightsLoader(WeightsLoader):
self.quant_method = quant_method
self.quantize = quantize
self.sym = sym
self.modules_to_not_convert = modules_to_not_convert
def is_layer_skipped_quantization(
self, prefix: str, modules_to_not_convert: List[str]
):
return any(module_name in prefix for module_name in modules_to_not_convert)
def get_weights(self, weights: Weights, prefix: str):
self._get_gptq_params(weights)
@ -91,6 +103,9 @@ class GPTQWeightsLoader(WeightsLoader):
log_once(logger.warning, "Disabling exllama because desc_act=True")
use_exllama = False
if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
return DefaultWeightsLoader.get_weights(weights, prefix)
try:
qweight = weights.get_tensor(f"{prefix}.qweight")
except RuntimeError:
@ -145,6 +160,10 @@ class GPTQWeightsLoader(WeightsLoader):
prefix: str,
block_sizes: Union[int, List[int]],
):
if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
return DefaultWeightsLoader.get_weights_col_packed(
weights, prefix, block_sizes
)
try:
qweight = weights.get_packed_sharded(
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
@ -196,6 +215,8 @@ class GPTQWeightsLoader(WeightsLoader):
)
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
return DefaultWeightsLoader.get_multi_weights_col(weights, prefixes, dim)
try:
qweight = torch.cat(
[weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
@ -255,6 +276,63 @@ class GPTQWeightsLoader(WeightsLoader):
use_exllama=use_exllama,
)
def get_multi_weights(self, weights: Weights, prefixes: List[str], dim: int):
if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
return DefaultWeightsLoader.get_multi_weights(weights, prefixes, dim)
try:
qweight = torch.cat(
[weights.get_tensor(f"{p}.qweight") for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
)
scales = torch.cat([weights.get_tensor(f"{p}.scales") for p in prefixes], dim=1)
self._get_gptq_params(weights)
qzeros = torch.cat([weights.get_tensor(f"{p}.qzeros") for p in prefixes], dim=1)
use_exllama = self.bits == 4 and self.quantize == "gptq" and not self.desc_act
if self.quantize == "gptq" and self.quant_method == "gptq":
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
elif self.quantize == "gptq" and self.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
if use_exllama:
g_idx = None
else:
g_idx = (
torch.arange(
qweight.shape[0] * (32 // self.bits),
device=qweight.device,
)
).to(dtype=torch.int32)
else:
g_idx = None
return GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=self.bits,
groupsize=self.groupsize,
use_awq_kernel=self.quantize == "awq",
use_exllama=use_exllama,
)
def get_weights_row(self, weights: Weights, prefix: str):
self._get_gptq_params(weights)
@ -263,6 +341,9 @@ class GPTQWeightsLoader(WeightsLoader):
if self.bits != 4:
use_exllama = False
if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
return DefaultWeightsLoader.get_weights_row(weights, prefix)
if self.desc_act:
log_once(logger.warning, "Disabling exllama because desc_act=True")
use_exllama = False

View File

@ -53,15 +53,10 @@ class FastRMSNorm(nn.Module):
return cls(weight, eps)
def forward(self, hidden_states, residual=None):
from vllm_hpu_extension.kernels import rms_norm
orig_shape = hidden_states.shape
if residual is not None:
residual += hidden_states.view(residual.shape)
else:
hidden_states += residual
residual = hidden_states
# Note: HPUFusedRMSNorm requires 3D tensors as inputs
if len(orig_shape) == 2:
residual = residual.unsqueeze(0)
x = rms_norm().apply(residual, self.weight, self.variance_epsilon)
return x.view(orig_shape), residual.view(orig_shape)
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(self.weight.dtype), residual

View File

@ -2,6 +2,7 @@ from typing import Optional
import torch
import torch.nn as nn
import os
from text_generation_server.utils.weights import Weights
from text_generation_server.layers.fp8 import (
@ -9,12 +10,11 @@ from text_generation_server.layers.fp8 import (
fp8_quantize,
quant_dtype,
normalize_e4m3fn_to_native_float8,
dynamic_quant,
dequant_block_fp8_weight_naive,
)
try:
from .unquantized import fused_moe
except Exception:
fused_moe = None
from text_generation_server.layers.moe.fused_moe import select_experts
import habana_frameworks.torch as htorch
class FP8SparseMoELayer(nn.Module):
@ -47,6 +47,16 @@ class FP8SparseMoELayer(nn.Module):
self.weight_block_size = weights.weights_loader.weight_block_size
self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias
self.world_size = weights.process_group.size()
self.rank = weights.process_group.rank()
self.ep_rank = self.rank
self.use_ep = os.getenv("USE_EXPERT_PARALLEL", "true").lower() == "true"
if self.use_ep:
n_experts = (n_experts + self.world_size - 1) // self.world_size
self.ep_offset = self.ep_rank * n_experts
else:
self.ep_offset = 0
(
self.gate_up_proj,
@ -58,6 +68,8 @@ class FP8SparseMoELayer(nn.Module):
gate_proj_name=gate_proj_name,
up_proj_name=up_proj_name,
weights=weights,
use_ep=self.use_ep,
ep_offset=self.ep_offset,
)
self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = (
@ -66,29 +78,89 @@ class FP8SparseMoELayer(nn.Module):
n_experts=n_experts,
name=down_proj_name,
weights=weights,
use_ep=self.use_ep,
ep_offset=self.ep_offset,
)
)
if self.weight_block_size is not None:
self.gate_up_proj, self.gate_up_proj_weight_scale = dynamic_quant(
dequant_block_fp8_weight_naive(
self.gate_up_proj,
self.gate_up_proj_weight_scale,
self.weight_block_size,
)
)
self.down_proj, self.down_proj_weight_scale = dynamic_quant(
dequant_block_fp8_weight_naive(
self.down_proj, self.down_proj_weight_scale, self.weight_block_size
)
)
self.gate_up_proj_weight_scale, self.down_proj_weight_scale = (
self.gate_up_proj_weight_scale.squeeze(-1),
self.down_proj_weight_scale.squeeze(-1),
)
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
return fused_moe(
x,
w1=self.gate_up_proj,
w2=self.down_proj,
gating_output=gating_output,
topk=self.topk,
renormalize=self.renormalize,
inplace=True,
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=gating_output,
use_grouped_topk=self.n_expert_group is not None,
num_expert_group=self.n_expert_group,
top_k=self.topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.n_expert_group,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
use_fp8_w8a8=True,
w1_scale=self.gate_up_proj_weight_scale,
w2_scale=self.down_proj_weight_scale,
a1_scale=self.gate_up_proj_input_scale,
a2_scale=self.down_proj_input_scale,
)
total_num_experts = gating_output.size(-1)
x_fp8, x_scale = dynamic_quant(x, single_scale=True)
if self.use_ep:
moe_n_slice = 1
n_expert_slice = (
total_num_experts + self.world_size - 1
) // self.world_size
else:
moe_n_slice = 1
n_expert_slice = (total_num_experts + moe_n_slice - 1) // moe_n_slice
for i in range(moe_n_slice):
min_expert = i * n_expert_slice
max_expert = min((i + 1) * n_expert_slice, total_num_experts)
w13_list_slice = [
self.gate_up_proj[j, ...] for j in range(min_expert, max_expert)
]
w2_list_slice = [
self.down_proj[j, ...] for j in range(min_expert, max_expert)
]
w13_weight_scale = [
self.gate_up_proj_weight_scale[j, ...]
for j in range(min_expert, max_expert)
]
w2_weight_scale = [
self.down_proj_weight_scale[j, ...]
for j in range(min_expert, max_expert)
]
current_hidden_states = torch.ops.hpu.mixture_of_experts(
hidden_states=x_fp8,
expert_routing_table=topk_ids.to(torch.int64),
router_weights=topk_weights.to(x.dtype),
w12=w13_list_slice,
w3=w2_list_slice,
d_scale_hidden_states=x_scale,
d_scale_w12=w13_weight_scale,
d_scale_w3=w2_weight_scale,
permuted_weights=True,
activation="silu",
experts_min=min_expert + self.ep_offset,
experts_max=max_expert + self.ep_offset - 1,
)
htorch.core.mark_step()
if i == 0:
final_hidden_states = current_hidden_states
else:
final_hidden_states.add_(current_hidden_states)
return final_hidden_states
def _load_expert_weights(
@ -98,13 +170,14 @@ def _load_expert_weights(
n_experts: int,
name: str,
weights: Weights,
ep_offset: int = 0,
) -> torch.Tensor:
all_weight = None
all_weight_scales = None
max_input_scale = None
for i in range(n_experts):
weight = get_weight_fn(prefix, i, name, weights)
weight = get_weight_fn(prefix, i + ep_offset, name, weights)
assert isinstance(weight, Fp8Weight)
@ -147,14 +220,26 @@ def _load_expert_multi_weights_col(
gate_proj_name: str,
up_proj_name: str,
weights: Weights,
use_ep: bool = False,
ep_offset: int = 0,
) -> torch.Tensor:
def get_weight_fn(prefix, i, name, weights):
def get_weight_fn_sharded(prefix, i, name, weights):
return weights.get_multi_weights_col(
[f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
)
def get_weight_fn(prefix, i, name, weights):
return weights.get_multi_weights(
[f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
)
return _load_expert_weights(
get_weight_fn, prefix=prefix, n_experts=n_experts, name=None, weights=weights
get_weight_fn if use_ep else get_weight_fn_sharded,
prefix=prefix,
n_experts=n_experts,
name=None,
weights=weights,
ep_offset=ep_offset if use_ep else 0,
)
@ -164,10 +249,20 @@ def _load_expert_weights_row(
n_experts: int,
name: str,
weights: Weights,
use_ep: bool = False,
ep_offset: int = 0,
) -> torch.Tensor:
def get_weight_fn(prefix, i, name, weights):
def get_weight_fn_sharded(prefix, i, name, weights):
return weights.get_weights_row(f"{prefix}.{i}.{name}")
def get_weight_fn(prefix, i, name, weights):
return weights.get_weights(f"{prefix}.{i}.{name}")
return _load_expert_weights(
get_weight_fn, prefix=prefix, n_experts=n_experts, name=name, weights=weights
get_weight_fn if use_ep else get_weight_fn_sharded,
prefix=prefix,
n_experts=n_experts,
name=name,
weights=weights,
ep_offset=ep_offset if use_ep else 0,
)

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
from typing import Tuple, Optional
import torch
@ -25,12 +25,36 @@ def grouped_topk(
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
gating_output = gating_output.float()
if e_score_correction_bias is not None:
e_score_correction_bias = e_score_correction_bias.float()
if scoring_func == "softmax":
scores = torch.softmax(gating_output, dim=-1)
elif scoring_func == "sigmoid":
scores = gating_output.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
num_token = scores.shape[0]
if e_score_correction_bias is not None:
# Store original scores before applying correction bias. We use biased
# scores for expert selection but original scores for routing weights
original_scores = scores
scores = scores + e_score_correction_bias.unsqueeze(0)
group_scores = (
scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
)
else:
group_scores = (
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
) # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
1
] # [n, top_k_group]
@ -41,13 +65,19 @@ def grouped_topk(
.expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
.reshape(num_token, -1)
) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e]
if e_score_correction_bias is not None:
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
# Use original unbiased scores for the routing weights
topk_weights = original_scores.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
def fused_topk(
@ -63,3 +93,39 @@ def fused_topk(
if renormalize:
topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
def select_experts(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
):
# DeekSeekv2 uses grouped_top_k
if use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
topk_weights, topk_ids = grouped_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
else:
topk_weights, topk_ids = fused_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
)
return topk_weights, topk_ids

View File

@ -4,7 +4,9 @@ import torch
import torch.nn as nn
from text_generation_server.utils.weights import UnquantizedWeight, Weights
from vllm_hpu_extension.ops import DynamicFusedMOE
from vllm_hpu_extension.ops import VllmMixtureOfExpertsOp
import habana_frameworks.torch as htorch
import torch.nn.functional as F
class UnquantizedSparseMoELayer(nn.Module):
@ -53,13 +55,29 @@ class UnquantizedSparseMoELayer(nn.Module):
weights=weights,
)
self.hpu_fused_moe = DynamicFusedMOE(n_experts)
self.MoeOp = VllmMixtureOfExpertsOp(n_experts, 0, n_experts - 1)
for i in range(n_experts):
self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.down_proj[i])
self.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
self.MoeOp.w2_list[i].set_weight(self.down_proj[i])
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
return self.hpu_fused_moe(x, gating_output, self.topk)
htorch.core.mark_step()
routing_weights = F.softmax(gating_output, dim=1, dtype=torch.float32)
routing_weights, selected_experts = torch.topk(
routing_weights, self.topk, dim=-1
)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(x.dtype)
final_hidden_states = self.MoeOp(
hidden_states=x,
expert_routing_table=selected_experts,
router_weights=routing_weights,
permuted_weights=True,
activation="silu",
)
return final_hidden_states.view(-1, x.shape[1])
def _load_expert_multi_weights_col(

View File

@ -36,9 +36,7 @@ class PositionRotaryEmbedding(nn.Module):
self._sin_k_cached = None
self.scaling_factor = scaling_factor
self.dynamic_args = None
self._update_cos_sin_cache(
torch.float32, inv_freq.device, max_position_embeddings
)
self.max_position_embeddings = max_position_embeddings
def forward(
self,
@ -270,7 +268,9 @@ class PositionRotaryEmbedding(nn.Module):
self._sin_cached = torch.sin(freqs).to(dtype)
def get_cos_sin(self, position_ids: torch.Tensor):
self._update_cos_sin_cache(
torch.float32, position_ids.device, seqlen=self.max_position_embeddings
)
cos = torch.index_select(self._cos_cached, 0, position_ids)
sin = torch.index_select(self._sin_cached, 0, position_ids)
@ -298,9 +298,6 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
self._cos_k_cached = None
self._sin_k_cached = None
self.dynamic_args = None
self._update_cos_sin_cache(
torch.float32, short_inv_freq.device, max_position_embeddings
)
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
@ -354,9 +351,6 @@ class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
self._cos_k_cached = None
self._sin_k_cached = None
self.dynamic_args = None
self._update_cos_sin_cache(
torch.float32, short_inv_freq.device, max_position_embeddings
)
def _update_cos_sin_cache(self, dtype, device, seqlen):
if (
@ -470,9 +464,6 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
mscale_all_dim: float,
):
inv_freq = _create_inv_freq(dim, base, device)
super().__init__(
inv_freq, scaling_factor, max_position_embeddings * self.scaling_factor
)
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
@ -487,6 +478,7 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
/ get_mscale(self.scaling_factor, mscale_all_dim)
* self.attn_factor
) # Get n-d magnitude scaling corrected for interpolation
super().__init__(inv_freq, scaling_factor, max_position_embeddings)
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
@ -600,6 +592,9 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
position_ids: torch.Tensor,
):
slen = position_ids.shape[0]
self._update_cos_sin_cache(
torch.float32, position_ids.device, seqlen=self.max_position_embeddings
)
cos = self._cos_cached[position_ids].gather(1, self._sections[:slen])
sin = self._sin_cached[position_ids].gather(1, self._sections[:slen])

View File

@ -5,7 +5,6 @@ import os
from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download, HfApi
from typing import Optional
from pathlib import Path
@ -16,9 +15,6 @@ import enum
from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.bloom import BLOOM
from text_generation_server.models.starcoder import StarCoder
from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
PhiMoEConfig,
)
@ -32,7 +28,6 @@ from text_generation_server.utils.adapter import (
from text_generation_server.adapters.lora import LoraWeights
from text_generation_server.utils.log import log_master
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
__all__ = [
"Model",
@ -40,12 +35,9 @@ __all__ = [
"Seq2SeqLM",
"get_model_with_lora_adapters",
]
from text_generation_server.models.globals import ATTENTION
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
VLM_BATCH_TYPES = set()
FLASH_ATTENTION = False
if ATTENTION == "paged":
FLASH_ATTENTION = True
try:
@ -63,6 +55,9 @@ try:
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_llama4_modeling import (
Llama4ForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
FlashCohereForCausalLM,
)
@ -83,9 +78,6 @@ try:
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
FlashGPTNeoXForCausalLM,
)
from text_generation_server.models.pali_gemma import (
PaliGemmaBatch,
)
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
PaliGemmaForConditionalGeneration,
)
@ -109,6 +101,12 @@ try:
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
Qwen2ForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_qwen3_modeling import (
Qwen3ForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_qwen3_moe_modeling import (
Qwen3MoeForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
)
@ -140,10 +138,23 @@ except ImportError as e:
log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
SUPPORTS_WINDOWING = False
FLASH_ATTENTION = False
VLM_BATCH_TYPES = set()
if FLASH_ATTENTION:
__all__.append(FlashCausalLM)
from text_generation_server.models.flash_vlm_causal_lm import (
FlashVlmCausalLMBatch,
)
VLM_BATCH_TYPES = {
FlashVlmCausalLMBatch,
FlashMllamaCausalLMBatch,
}
__all__.append(VLM_BATCH_TYPES)
class ModelType(enum.Enum):
DEEPSEEK_V2 = {
@ -179,6 +190,11 @@ class ModelType(enum.Enum):
"name": "Llama",
"url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
}
LLAMA4 = {
"type": "llama4",
"name": "Llama4",
"url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
}
PHI3 = {
"type": "phi3",
"name": "Phi 3",
@ -274,6 +290,16 @@ class ModelType(enum.Enum):
"name": "Qwen 2.5 VL",
"url": "https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e",
}
QWEN3 = {
"type": "qwen3",
"name": "Qwen 3",
"url": "https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f",
}
QWEN3_MOE = {
"type": "qwen3_moe",
"name": "Qwen 3 Moe",
"url": "https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f",
}
GALACTICA = {
"type": "galactica",
"name": "Galactica",
@ -324,6 +350,7 @@ def get_model(
quantize: Optional[str],
speculate: Optional[int],
dtype: Optional[torch.dtype],
kv_cache_dtype: Optional[str],
trust_remote_code: bool,
max_input_tokens: int,
) -> Model:
@ -449,6 +476,11 @@ def get_model(
model_type = config_dict["model_type"]
if kv_cache_dtype == "fp8_e4m3fn":
kv_cache_dtype = torch.float8_e4m3fn
elif kv_cache_dtype == "fp8_e5m2":
kv_cache_dtype = torch.float8_e5m2
else:
kv_cache_dtype = dtype
if FLASH_ATTENTION:
@ -589,6 +621,20 @@ def get_model(
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == LLAMA4:
print(f"Llama4 model detected: {model_id}")
return FlashVlmCausalLM(
model_id=model_id,
model_class=Llama4ForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
support_chunking=False,
)
elif model_type == BAICHUAN:
return FlashCausalLM(
model_id=model_id,
@ -737,6 +783,8 @@ def get_model(
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
# TODO: Fix bug in rust image_text_replacement implementation
support_chunking=False,
)
elif model_type == QWEN2_5_VL:
return FlashVlmCausalLM(
@ -752,6 +800,32 @@ def get_model(
lora_adapter_ids=lora_adapter_ids,
config_class=Qwen2_5_VLConfig,
processor_class=Qwen2_5_VLProcessor,
# TODO: Fix bug in rust image_text_replacement implementation
support_chunking=False,
)
elif model_type == QWEN3:
return FlashCausalLM(
model_id=model_id,
model_class=Qwen3ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == QWEN3_MOE:
return FlashCausalLM(
model_id=model_id,
model_class=Qwen3MoeForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == MLLAMA:
return FlashMllamaCausalLM(
@ -765,6 +839,7 @@ def get_model(
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
support_chunking=False,
)
elif model_type == IDEFICS2:
return FlashVlmCausalLM(
@ -809,7 +884,6 @@ def get_model(
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
batch_class=PaliGemmaBatch,
)
elif model_type == LLAVA_NEXT:
return FlashVlmCausalLM(
@ -823,60 +897,6 @@ def get_model(
trust_remote_code=trust_remote_code,
)
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
from text_generation_server.models.custom_modeling.mllama import (
MllamaForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.llava_next import (
LlavaNextForConditionalGeneration,
)
adapt_transformers_to_gaudi()
if SDP_ON_BF16 == 1:
torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
if model_type == "gpt_bigcode":
return StarCoder(model_id=model_id, revision=revision, dtype=dtype)
if model_type == "bloom":
return BLOOM(
model_id=model_id,
revision=revision,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if model_type == "llava_next":
return VlmCausalLM(
model_class=LlavaNextForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=None,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if model_type == "mllama":
return VlmCausalLM(
model_class=MllamaForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=None,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
raise ValueError(f"Unsupported model type {model_type}")
@ -890,6 +910,7 @@ def get_model_with_lora_adapters(
quantize: Optional[str],
speculate: Optional[int],
dtype: Optional[torch.dtype],
kv_cache_dtype: Optional[str],
trust_remote_code: bool,
max_input_tokens: int,
adapter_to_index: Dict[str, int],
@ -903,6 +924,7 @@ def get_model_with_lora_adapters(
quantize,
speculate,
dtype,
kv_cache_dtype,
trust_remote_code,
max_input_tokens,
)

View File

@ -1,52 +0,0 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
import torch
from typing import Optional, Type
from transformers import PreTrainedTokenizerBase
from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
class BloomCausalLMBatch(CausalLMBatch):
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "CausalLMBatch":
batch = super().from_pb(
pb=pb,
tokenizer=tokenizer,
dtype=dtype,
device=device,
)
batch.keys_head_dim_last = False
return batch
class BLOOM(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
super(BLOOM, self).__init__(
model_id=model_id,
revision=revision,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@property
def batch_type(self) -> Type[CausalLMBatch]:
return BloomCausalLMBatch

View File

@ -28,6 +28,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@ -51,6 +52,8 @@ from habana_frameworks.torch.hpex.kernels import (
apply_rotary_pos_emb,
)
import habana_frameworks.torch as htorch
class CohereRotary(PositionRotaryEmbedding):
def forward(
@ -413,6 +416,10 @@ class FlashCohereModel(torch.nn.Module):
seqlen: torch.Tensor,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
@ -420,7 +427,9 @@ class FlashCohereModel(torch.nn.Module):
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
@ -433,6 +442,8 @@ class FlashCohereModel(torch.nn.Module):
seqlen,
hpu_attention_meta,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual)

View File

@ -26,6 +26,7 @@ from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@ -44,6 +45,7 @@ from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
from vllm_hpu_extension.ops import DynamicFusedMOE
import habana_frameworks.torch as htorch
class DbrxAttentionConfig(PretrainedConfig):
@ -677,13 +679,19 @@ class DbrxModel(torch.nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
@ -696,6 +704,8 @@ class DbrxModel(torch.nn.Module):
seqlen,
hpu_attention_meta,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual)

View File

@ -33,6 +33,7 @@ from text_generation_server.layers.attention import (
Seqlen,
attention,
paged_attention,
set_block_mapping,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
@ -40,6 +41,7 @@ from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
from text_generation_server.utils.weights import Weights
import habana_frameworks.torch as htorch
class DeepseekV2Config(PretrainedConfig):
@ -568,6 +570,10 @@ class DeepseekV2Model(torch.nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
@ -575,6 +581,9 @@ class DeepseekV2Model(torch.nn.Module):
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
@ -587,6 +596,8 @@ class DeepseekV2Model(torch.nn.Module):
seqlen,
hpu_attention_meta,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual)

View File

@ -28,11 +28,13 @@ from text_generation_server.layers import (
TensorParallelEmbedding,
TensorParallelRowLinear,
get_linear,
Fp8Linear,
)
from text_generation_server.layers.attention import (
Seqlen,
attention,
paged_attention,
paged_attention_mla,
set_block_mapping,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
@ -40,6 +42,19 @@ from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
from text_generation_server.utils.weights import Weights
import habana_frameworks.torch as htorch
def get_and_maybe_dequant_weights(layer: torch.nn.Module) -> torch.Tensor:
if isinstance(layer, Fp8Linear):
eye = torch.eye(
layer.qweight.shape[-1], dtype=torch.bfloat16, device=layer.qweight.device
)
dequant_weights = layer(eye)
del eye
# standardize to (output, input)
return dequant_weights.T
return layer.weight
class DeepseekV3Config(PretrainedConfig):
@ -249,6 +264,44 @@ class DeepseekV3Attention(torch.nn.Module):
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj.linear).T
kv_b_proj_weight = kv_b_proj_weight.view(
self.kv_lora_rank,
self.num_heads,
self.qk_nope_head_dim + self.value_head_size,
)
W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.value_head_size], dim=-1
)
# Convert from (L, N, V) to (N, L, V)
self.W_UV = W_UV.transpose(0, 1)
# Convert from (L, N, P) to (N, P, L)
self.W_UK_T = W_UK.permute(1, 2, 0)
def _q_proj_and_k_up_proj(self, x):
q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
q_nope, q_pe = (
q_proj(x)
.view(-1, self.num_heads, self.head_size)
.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
)
# Convert from (B, N, P) to (N, B, P)
q_nope = q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope = torch.bmm(q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
return ql_nope.transpose(0, 1), q_pe
def _v_up_proj_and_o_proj(self, x):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x = torch.bmm(x, self.W_UV)
# Convert from (N, B, V) to (B, N * V)
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.value_head_size)
return self.o_proj(x)
def forward(
self,
hidden_states: torch.Tensor,
@ -261,14 +314,9 @@ class DeepseekV3Attention(torch.nn.Module):
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
):
if self.q_lora_rank is None:
query = self.q_proj(hidden_states)
hidden_states_or_q_c = hidden_states
else:
query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0])
query = query.view(-1, self.num_heads, self.head_size)
_, query_pe = torch.split(
query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
hidden_states_or_q_c = self.q_a_layernorm(self.q_a_proj(hidden_states))[0]
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, key_pe = torch.split(
@ -276,13 +324,18 @@ class DeepseekV3Attention(torch.nn.Module):
)
key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view(
-1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size
)
kv_c_normed = self.kv_a_layernorm(compressed_kv.contiguous())[0]
key_nope, value = torch.split(
kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
# Prefill
if cu_seqlen_prefill is not None:
q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
query = q_proj(hidden_states_or_q_c)
query = query.view(-1, self.num_heads, self.head_size)
query_nope, query_pe = torch.split(
query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
else:
query_nope, query_pe = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
batch_size, heads, head_dim = query_pe.shape
query_pe = (
@ -297,7 +350,30 @@ class DeepseekV3Attention(torch.nn.Module):
.reshape(batch_size, heads, head_dim)
)
self.rotary_emb(query_pe, key_pe, cos, sin)
latent_vec_k = torch.concat(
(kv_c_normed, key_pe.view(-1, self.qk_rope_head_dim)), dim=-1
)
latent_vec_k = latent_vec_k.view(-1, self.qk_rope_head_dim + self.kv_lora_rank)
latent_vec_k = latent_vec_k.unflatten(0, (slots.size(0), -1))
kv_cache.store(
key=latent_vec_k,
value=None,
slots=slots,
kv_scales=self.kv_scales,
)
if cu_seqlen_prefill is not None:
kv = self.kv_b_proj(kv_c_normed).view(
-1,
self.num_key_value_heads,
self.qk_nope_head_dim + self.value_head_size,
)
key_nope, value = torch.split(
kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
)
query[..., self.qk_nope_head_dim :] = query_pe
key = torch.empty_like(query)
key[..., : self.qk_nope_head_dim] = key_nope
@ -315,15 +391,6 @@ class DeepseekV3Attention(torch.nn.Module):
value, (0, self.head_pad_size - self.value_head_size), value=0
)
kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query=query,
@ -334,9 +401,15 @@ class DeepseekV3Attention(torch.nn.Module):
seqlen=seqlen,
softmax_scale=self.softmax_scale,
)
# Decode
attn_output = attn_output[..., : self.value_head_size]
return self.o_proj(
attn_output.reshape(-1, self.num_heads * self.value_head_size)
)
else:
attn_output = paged_attention(
# Decode
query = torch.cat([query_nope, query_pe], dim=-1)
attn_output = paged_attention_mla(
query,
kv_cache,
self.kv_head_mapping,
@ -344,14 +417,10 @@ class DeepseekV3Attention(torch.nn.Module):
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
kv_lora_rank=self.kv_lora_rank,
)
# Remove padding.
attn_output = attn_output[..., : self.value_head_size]
return self.o_proj(
attn_output.reshape(-1, self.num_heads * self.value_head_size)
)
attn_output = self._v_up_proj_and_o_proj(attn_output)
return attn_output
class DeepseekV3MLP(nn.Module):
@ -577,6 +646,10 @@ class DeepseekV3Model(torch.nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
@ -584,6 +657,9 @@ class DeepseekV3Model(torch.nn.Module):
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
@ -596,6 +672,8 @@ class DeepseekV3Model(torch.nn.Module):
seqlen,
hpu_attention_meta,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual)

View File

@ -28,6 +28,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@ -46,6 +47,7 @@ from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from text_generation_server.utils.weights import UnquantizedWeight
import habana_frameworks.torch as htorch
class Gemma2Config(PretrainedConfig):
@ -465,6 +467,10 @@ class FlashGemma2Model(torch.nn.Module):
adapter_data: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward
@ -472,6 +478,10 @@ class FlashGemma2Model(torch.nn.Module):
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
@ -485,6 +495,8 @@ class FlashGemma2Model(torch.nn.Module):
adapter_data,
hpu_attention_meta,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual)

View File

@ -28,6 +28,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@ -44,6 +45,7 @@ from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from text_generation_server.utils.weights import UnquantizedWeight
import habana_frameworks.torch as htorch
class GemmaConfig(PretrainedConfig):
@ -387,6 +389,10 @@ class FlashGemmaModel(torch.nn.Module):
adapter_data: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward
@ -394,6 +400,9 @@ class FlashGemmaModel(torch.nn.Module):
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
@ -406,6 +415,8 @@ class FlashGemmaModel(torch.nn.Module):
seqlen,
hpu_attention_meta,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual)

View File

@ -27,6 +27,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@ -38,6 +39,7 @@ from text_generation_server.layers import (
get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
import habana_frameworks.torch as htorch
def load_qkv(config, prefix: str, weights, head_size, num_heads):
@ -382,9 +384,17 @@ class FlashGPT2Model(torch.nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
@ -395,6 +405,8 @@ class FlashGPT2Model(torch.nn.Module):
seqlen,
hpu_attention_meta,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states = self.norm(hidden_states)

View File

@ -28,6 +28,7 @@ from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@ -48,6 +49,7 @@ from habana_frameworks.torch.hpex.kernels import (
RotaryPosEmbeddingMode,
apply_rotary_pos_emb,
)
import habana_frameworks.torch as htorch
def load_attention(config, prefix: str, weights):
@ -323,6 +325,10 @@ class FlashGPTJModel(torch.nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.wte(input_ids)
# Get rotary cos and sin for this forward
@ -330,6 +336,9 @@ class FlashGPTJModel(torch.nn.Module):
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
@ -342,6 +351,8 @@ class FlashGPTJModel(torch.nn.Module):
seqlen,
hpu_attention_meta,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.ln_f(hidden_states, residual)

View File

@ -26,7 +26,7 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
import habana_frameworks.torch as htorch
from text_generation_server.layers.attention import (
KVCache,
get_kv_scales,
@ -35,6 +35,7 @@ from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoE
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@ -143,6 +144,8 @@ class FlashLlamaAttention(torch.nn.Module):
config.num_key_value_heads = getattr(
config, "num_key_value_heads", config.num_attention_heads
)
if config.model_type != "llama4_text":
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
@ -547,6 +550,11 @@ class FlashLlamaModel(torch.nn.Module):
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
cross_attention_states=None,
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward
@ -554,6 +562,9 @@ class FlashLlamaModel(torch.nn.Module):
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
@ -568,6 +579,8 @@ class FlashLlamaModel(torch.nn.Module):
cross_attention_states,
hpu_attention_meta=hpu_attention_meta,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual)

View File

@ -163,25 +163,13 @@ class FlashLlavaNextForConditionalGeneration(nn.Module):
)
return inputs_embeds
def forward(
def get_vision_embeds(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
# Unused for this model
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor,
pixel_attention_mask: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None and len(pixel_values) > 0:
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
# 1. Extract the input embeddings
@ -216,8 +204,7 @@ class FlashLlavaNextForConditionalGeneration(nn.Module):
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
height = width = (
self.config.vision_config.image_size
// self.config.vision_config.patch_size
self.config.vision_config.image_size // self.config.vision_config.patch_size
)
new_image_features = []
@ -254,9 +241,7 @@ class FlashLlavaNextForConditionalGeneration(nn.Module):
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat(
(base_image_feature, image_feature), dim=0
)
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
else:
image_feature = image_feature[0]
image_feature = torch.cat(
@ -264,10 +249,38 @@ class FlashLlavaNextForConditionalGeneration(nn.Module):
)
new_image_features.append(image_feature)
image_features = torch.stack(new_image_features, dim=0)
return image_features.view(-1, image_features.shape[-1])
def get_inputs_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
pixel_values: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if vision_embeds is not None:
# When we generate, we don't want to replace the potential image_token_id that we generated by images
# that simply don't exist
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_features
input_ids, inputs_embeds, vision_embeds
)
return inputs_embeds
def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
):
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,

View File

@ -30,6 +30,7 @@ from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@ -45,6 +46,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
import habana_frameworks.torch as htorch
class MistralConfig(PretrainedConfig):
@ -109,7 +111,8 @@ class MistralAttention(torch.nn.Module):
)
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
if hasattr(config, "head_dim"):
if getattr(config, "head_dim", None) is not None:
self.head_size = config.head_dim
else:
self.head_size = self.hidden_size // self.num_heads
@ -395,12 +398,19 @@ class MistralModel(torch.nn.Module):
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
adapter_data: Optional[torch.Tensor] = None,
):
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
@ -414,6 +424,8 @@ class MistralModel(torch.nn.Module):
adapter_data,
hpu_attention_meta,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states

View File

@ -37,6 +37,7 @@ from text_generation_server.layers.attention import (
Seqlen,
attention,
paged_attention,
set_block_mapping,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
@ -44,6 +45,7 @@ from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.utils.weights import UnquantizedWeight
import habana_frameworks.torch as htorch
class MixtralConfig(PretrainedConfig):
@ -445,6 +447,10 @@ class MixtralModel(torch.nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
@ -452,6 +458,9 @@ class MixtralModel(torch.nn.Module):
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
@ -464,6 +473,8 @@ class MixtralModel(torch.nn.Module):
seqlen,
hpu_attention_meta,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual)
@ -499,7 +510,6 @@ class FlashMixtralForCausalLM(torch.nn.Module):
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.model(
input_ids,
position_ids,

View File

@ -38,6 +38,7 @@ from text_generation_server.models.custom_modeling.flash_llama_modeling import (
)
from habana_frameworks.torch.hpex.kernels import FusedSDPA
from vllm_hpu_extension.utils import ModuleFusedSDPA
import habana_frameworks.torch as htorch
def _prepare_aspect_ratio_attention_mask(
@ -236,10 +237,19 @@ class MllamaVisionSdpaAttention(nn.Module):
key = key.transpose(1, 2)
value = value.transpose(1, 2)
attn_output = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask
fsdpa_op = ModuleFusedSDPA(FusedSDPA)
attn_output = fsdpa_op(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
scale=None,
softmax_mode="None",
recompute_mode=None,
valid_sequence_lengths=None,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_seq_len, -1)
@ -320,6 +330,9 @@ class MllamaVisionEncoder(nn.Module):
attention_mask: Optional[torch.Tensor] = None,
):
encoder_states = [hidden_states]
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for encoder_layer in self.layers:
layer_outputs = encoder_layer(
hidden_states,
@ -328,6 +341,8 @@ class MllamaVisionEncoder(nn.Module):
hidden_states = layer_outputs
encoder_states.append(hidden_states)
if lazy_mode:
htorch.core.mark_step()
return hidden_states, encoder_states
@ -699,8 +714,6 @@ class MllamaTextCrossAttention(nn.Module):
# key_states = key_states.repeat(1, self.num_key_value_groups, 1)
# value_states = value_states.repeat(1, self.num_key_value_groups, 1)
causal = False
# logger.info(
# f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}"
# )
@ -715,7 +728,7 @@ class MllamaTextCrossAttention(nn.Module):
value_states,
attn_mask=None,
dropout_p=0.0,
is_causal=causal,
is_causal=False,
scale=None,
softmax_mode="None",
recompute_mode=None,

View File

@ -29,6 +29,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@ -47,6 +48,7 @@ from text_generation_server.layers.rotary import (
PositionRotaryEmbedding,
)
from text_generation_server.utils.weights import UnquantizedWeight
import habana_frameworks.torch as htorch
class GPTNeoXConfig(TransformersGPTNeoXConfig):
@ -353,6 +355,10 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_in(input_ids)
# Get rotary cos and sin for this forward
@ -360,6 +366,9 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(position_ids)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
@ -372,6 +381,8 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
seqlen,
hpu_attention_meta,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.final_layer_norm(hidden_states, residual)

View File

@ -62,10 +62,40 @@ class PaliGemmaForConditionalGeneration(nn.Module):
self.pad_token_id = (
config.pad_token_id if config.pad_token_id is not None else -1
)
self.dtype = weights.dtype
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
pixel_attention_mask: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
pixel_values = pixel_values.to(dtype=self.dtype)
image_outputs = self.vision_tower(pixel_values)
last_hidden_state = self.post_vision_tower_layernorm(
image_outputs.last_hidden_state
)
image_features = self.multi_modal_projector(last_hidden_state)
image_features = image_features.view(-1, image_features.shape[-1])
return image_features
def get_inputs_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if vision_embeds is not None:
mask = input_ids == self.config.image_token_index
inputs_embeds[mask] = vision_embeds
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@ -73,32 +103,13 @@ class PaliGemmaForConditionalGeneration(nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
# Unused here
pixel_attention_mask: Optional[torch.BoolTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.text_model.embed_tokens(input_ids)
# TODO This is odd but apparently pali gemma position ids start at 1.
if cu_seqlen_prefill is not None:
position_ids += 1
if pixel_values is not None:
pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
image_outputs = self.vision_tower(pixel_values)
last_hidden_state = self.post_vision_tower_layernorm(
image_outputs.last_hidden_state
)
image_features = self.multi_modal_projector(last_hidden_state)
# mask where image or padding tokens
mask = input_ids == self.config.image_token_index
# insert image features into input embeddings
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,

View File

@ -9,6 +9,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@ -26,6 +27,7 @@ from text_generation_server.layers.layernorm import (
from text_generation_server.layers.rotary import (
PositionRotaryEmbedding,
)
import habana_frameworks.torch as htorch
class PhiConfig(PretrainedConfig):
@ -346,6 +348,10 @@ class FlashPhiModel(torch.nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
@ -353,6 +359,9 @@ class FlashPhiModel(torch.nn.Module):
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
@ -365,6 +374,8 @@ class FlashPhiModel(torch.nn.Module):
seqlen,
hpu_attention_meta,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual)

View File

@ -18,7 +18,6 @@
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)

View File

@ -8,6 +8,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@ -22,6 +23,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
import habana_frameworks.torch as htorch
def load_attention(config, prefix, weights):
@ -287,6 +289,10 @@ class Qwen2Model(torch.nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
@ -294,6 +300,9 @@ class Qwen2Model(torch.nn.Module):
)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states = layer(
hidden_states,
@ -306,6 +315,8 @@ class Qwen2Model(torch.nn.Module):
seqlen,
hpu_attention_meta,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states)
@ -353,7 +364,6 @@ class Qwen2ForCausalLM(torch.nn.Module):
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(

View File

@ -0,0 +1,359 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, List
import torch
from torch import nn
import habana_frameworks.torch as htorch
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers import (
TensorParallelEmbedding,
TensorParallelRowLinear,
TensorParallelColumnLinear,
SpeculativeHead,
)
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from .flash_qwen2_modeling import Qwen2MLP
from text_generation_server.layers.rotary import PositionRotaryEmbedding
class Qwen3Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config, prefix, weights, layer_idx):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(
config, "head_dim", config.hidden_size // config.num_attention_heads
)
self.num_key_value_groups = (
config.num_attention_heads // config.num_key_value_heads
)
self.num_heads = config.num_attention_heads
self.attention_dropout = config.attention_dropout
self.softmax_scale = self.head_dim**-0.5
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_dim,
base=config.rope_theta,
device=weights.device,
)
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
config.num_key_value_heads // weights.process_group.size()
)
self.query_key_value = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=False,
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
self.max_past = (
config.sliding_window if config.sliding_window is not None else -1
)
self.q_norm = FastRMSNorm.load(
prefix=f"{prefix}.q_norm",
weights=weights,
eps=config.rms_norm_eps,
)
self.k_norm = FastRMSNorm.load(
prefix=f"{prefix}.k_norm",
weights=weights,
eps=config.rms_norm_eps,
)
self.sliding_window = config.sliding_window
if not (
self.config.use_sliding_window
and getattr(self.config, "sliding_window", None) is not None
and self.layer_idx >= self.config.max_window_layers
):
self.sliding_window = None
def forward(
self,
hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
slots,
seqlen,
hpu_attention_meta,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
qkv = self.query_key_value(hidden_states)
query_states, key_states, value_states = qkv.split(
[
self.head_dim * self.num_heads,
self.head_dim * self.num_key_value_heads,
self.head_dim * self.num_key_value_heads,
],
dim=1,
)
query_states, _ = self.q_norm(query_states.view(hidden_shape))
key_states, _ = self.k_norm(key_states.view(hidden_shape))
value_states = value_states.view(hidden_shape)
self.rotary_emb(query_states, key_states, cos, sin)
kv_cache.store(
key=key_states,
value=value_states,
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
# sdpa
attn_output = attention(
query=query_states,
key=key_states,
value=value_states,
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
window_size_left=self.max_past,
)
# Decode
else:
attn_output = paged_attention(
query_states,
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
return self.o_proj(attn_output)
class Qwen3DecoderLayer(nn.Module):
def __init__(self, config, prefix, weights, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Qwen3Attention(
config=config,
prefix=f"{prefix}.self_attn",
weights=weights,
layer_idx=layer_idx,
)
self.mlp = Qwen2MLP(config=config, prefix=f"{prefix}.mlp", weights=weights)
self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.post_attention_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
def forward(
self,
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
slots,
seqlen,
hpu_attention_meta,
) -> torch.Tensor:
residual = hidden_states
hidden_states, _ = self.input_layernorm(hidden_states)
# Self Attention
hidden_states = self.self_attn(
hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
slots,
seqlen,
hpu_attention_meta,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states, _ = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class Qwen3Model(nn.Module):
def __init__(self, config, prefix: str, weights):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.layers = nn.ModuleList(
[
Qwen3DecoderLayer(
config=config,
prefix=f"{prefix}.layers.{layer_idx}",
weights=weights,
layer_idx=layer_idx,
)
for layer_idx in range(config.num_hidden_layers)
]
)
self.norm = FastRMSNorm.load(
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
)
def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids,
)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, decoder_layer in enumerate(self.layers):
hidden_states = decoder_layer(
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache[i],
slots,
seqlen,
hpu_attention_meta,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states)
# add hidden states from the last decoder layer
return hidden_states
class Qwen3ForCausalLM(nn.Module):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.model = Qwen3Model(config=config, prefix="model", weights=weights)
self.vocab_size = config.vocab_size
if config.tie_word_embeddings:
suffix = "model.embed_tokens"
else:
suffix = "lm_head"
self.lm_head = SpeculativeHead.load(
config,
prefix=f"{prefix}.{suffix}" if prefix else suffix,
weights=weights,
)
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens" if prefix else "model.embed_tokens",
weights=weights,
)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
inputs_embeds = self.embed_tokens(input_ids)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
hidden_states = self.model(
inputs_embeds,
position_ids,
cu_seqlen_prefill,
kv_cache,
slots,
seqlen,
hpu_attention_meta,
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
return logits

View File

@ -0,0 +1,542 @@
# coding=utf-8
# Copyright 5 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Type
import torch
from torch import nn
import torch.nn.functional as F
from text_generation_server.layers.attention import (
attention,
paged_attention,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers import (
TensorParallelEmbedding,
TensorParallelColumnLinear,
TensorParallelRowLinear,
SpeculativeHead,
FastLinear,
)
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from .flash_qwen2_modeling import Qwen2MLP
from .flash_qwen3_modeling import Qwen3Attention
from transformers.activations import ACT2FN
from text_generation_server.layers.rotary import PositionRotaryEmbedding
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class Qwen3MoeAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config, prefix, weights, layer_idx):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(
config, "head_dim", config.hidden_size // config.num_attention_heads
)
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = (
config.num_attention_heads // config.num_key_value_heads
)
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True
self.q_proj = FastLinear.load(
config, f"{prefix}.q_proj", weights, bias=config.attention_bias
)
self.k_proj = FastLinear.load(
config, f"{prefix}.k_proj", weights, bias=config.attention_bias
)
self.v_proj = FastLinear.load(
config, f"{prefix}.v_proj", weights, bias=config.attention_bias
)
self.o_proj = FastLinear.load(
config, f"{prefix}.o_proj", weights, bias=config.attention_bias
)
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_dim,
base=config.rope_theta,
device=weights.device,
)
self.q_norm = FastRMSNorm.load(
prefix=f"{prefix}.q_norm",
weights=weights,
eps=config.rms_norm_eps,
)
self.k_norm = FastRMSNorm.load(
prefix=f"{prefix}.k_norm",
weights=weights,
eps=config.rms_norm_eps,
)
self.max_past = (
config.sliding_window if config.sliding_window is not None else -1
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_key_value_groups)
self.sliding_window = config.sliding_window
if not (
self.config.use_sliding_window
and getattr(self.config, "sliding_window", None) is not None
and self.layer_idx >= self.config.max_window_layers
):
self.sliding_window = None
def forward(
self,
hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
slots,
seqlen,
hpu_attention_meta,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states, _ = self.q_norm(self.q_proj(hidden_states).view(hidden_shape))
key_states, _ = self.k_norm(self.k_proj(hidden_states).view(hidden_shape))
value_states = self.v_proj(hidden_states).view(hidden_shape)
self.rotary_emb(query_states, key_states, cos, sin)
# query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
kv_cache.store(
key=key_states,
value=value_states,
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
# sdpa
attn_output = attention(
query=query_states,
key=key_states,
value=value_states,
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.scaling,
window_size_left=self.max_past,
)
# Decode
else:
attn_output = paged_attention(
query_states,
kv_cache,
self.kv_head_mapping,
self.scaling,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output
class Qwen3MoE(nn.Module):
def __init__(self, prefix, config, moe_layer_cls: Type[MoELayer], weights):
super().__init__()
# gating
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
self.moe = moe_layer_cls(
n_expert_group=None,
n_experts=config.num_experts,
prefix=f"{prefix}.experts",
renormalize=True,
topk=config.num_experts_per_tok,
topk_group=None,
weights=weights,
)
# gate_proj_name="w1",
# up_proj_name="w3",
# down_proj_name="w2",
assert isinstance(self.moe, MoELayer)
self.process_group = weights.process_group
def forward(self, x: torch.Tensor) -> torch.Tensor:
router_logits = self.gate(x)
out = self.moe(x, gating_output=router_logits)
# Reduce sum
if self.process_group.size() > 1:
torch.distributed.all_reduce(out, group=self.process_group)
return out.view(*x.shape)
class Qwen3MoeMLP(nn.Module):
def __init__(self, prefix, config, weights, intermediate_size=None):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = (
intermediate_size
if intermediate_size is not None
else config.intermediate_size
)
# Fuse gate and up proj
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
self.down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
)
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
gate_up_states = self.gate_up_proj(x)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
class Qwen3MoeSparseMoeBlock(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.num_experts = config.num_experts
self.top_k = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
# gating
# self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
self.experts = nn.ModuleList(
[
Qwen3MoeMLP(
prefix=f"{prefix}.experts.{i}",
config=config,
weights=weights,
intermediate_size=config.moe_intermediate_size,
)
for i in range(self.num_experts)
]
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
""" """
input_shape = hidden_states.shape
_, hidden_dim = hidden_states.shape
# hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=hidden_states.dtype)
routing_weights, selected_experts = torch.topk(
routing_weights, self.top_k, dim=-1
)
if self.norm_topk_prob: # only diff with mixtral sparse moe block!
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
final_hidden_states = torch.zeros(
(input_shape), dtype=hidden_states.dtype, device=hidden_states.device
)
# One hot encode the selected experts to create an expert mask
# this will be used to easily index which expert is going to be sollicitated
expert_mask = torch.nn.functional.one_hot(
selected_experts, num_classes=self.num_experts
).permute(2, 1, 0)
# Loop over all available experts in the model and perform the computation on each expert
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx])
# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
current_hidden_states = (
expert_layer(current_state) * routing_weights[top_x, idx, None]
)
# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
final_hidden_states.index_add_(
0, top_x, current_hidden_states.to(hidden_states.dtype)
)
final_hidden_states = final_hidden_states.reshape(input_shape)
return final_hidden_states
class Qwen3MoeDecoderLayer(nn.Module):
def __init__(self, config, prefix, weights, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
if config.num_key_value_heads // weights.process_group.size() > 0:
self.self_attn = Qwen3Attention(
config,
prefix=f"{prefix}.self_attn",
weights=weights,
layer_idx=layer_idx,
)
else:
self.self_attn = Qwen3MoeAttention(
config,
prefix=f"{prefix}.self_attn",
weights=weights,
layer_idx=layer_idx,
)
moe_layer_cls = (
SparseMoELayer if SparseMoELayer.is_supported(weights) else DenseMoELayer
)
if (layer_idx not in config.mlp_only_layers) and (
config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
):
self.mlp = Qwen3MoE(f"{prefix}.mlp", config, moe_layer_cls, weights)
# self.mlp = Qwen3MoeSparseMoeBlock(f"{prefix}.mlp", config, weights)
else:
self.mlp = Qwen2MLP(config=config, prefix=f"{prefix}.mlp", weights=weights)
self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.post_attention_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
def forward(
self,
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
slots,
seqlen,
hpu_attention_meta,
) -> torch.Tensor:
if residual is None:
residual = hidden_states
hidden_states, _ = self.input_layernorm(hidden_states)
# Self Attention
hidden_states = self.self_attn(
hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
slots,
seqlen,
hpu_attention_meta,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states, _ = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class Qwen3MoeModel(nn.Module):
def __init__(self, config, prefix: str, weights):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.layers = nn.ModuleList(
[
Qwen3MoeDecoderLayer(
config=config,
prefix=f"{prefix}.layers.{layer_idx}",
weights=weights,
layer_idx=layer_idx,
)
for layer_idx in range(config.num_hidden_layers)
]
)
self.norm = FastRMSNorm.load(
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
)
def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids,
)
residual = None
for i, decoder_layer in enumerate(self.layers):
hidden_states = decoder_layer(
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache[i],
slots,
seqlen,
hpu_attention_meta,
)
hidden_states, _ = self.norm(hidden_states)
# add hidden states from the last decoder layer
return hidden_states
class Qwen3MoeForCausalLM(nn.Module):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.model = Qwen3MoeModel(config=config, prefix="model", weights=weights)
self.vocab_size = config.vocab_size
if config.tie_word_embeddings:
suffix = "model.embed_tokens"
else:
suffix = "lm_head"
self.lm_head = SpeculativeHead.load(
config,
prefix=f"{prefix}.{suffix}" if prefix else suffix,
weights=weights,
)
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens" if prefix else "model.embed_tokens",
weights=weights,
)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
inputs_embeds = self.embed_tokens(input_ids)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
hidden_states = self.model(
inputs_embeds,
position_ids,
cu_seqlen_prefill,
kv_cache,
slots,
seqlen,
hpu_attention_meta,
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
return logits

View File

@ -18,9 +18,11 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.attention import (
attention,
paged_attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
import habana_frameworks.torch as htorch
def load_row(config, prefix: str, weights, bias: bool):
@ -627,6 +629,10 @@ class FlashRWModel(FlashRWPreTrainedModel):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.word_embeddings(input_ids)
# Get rotary cos and sin for this forward
@ -634,6 +640,9 @@ class FlashRWModel(FlashRWPreTrainedModel):
cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(position_ids)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.h):
hidden_states, residual = layer(
hidden_states,
@ -646,6 +655,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
seqlen,
hpu_attention_meta,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.ln_f(hidden_states, residual)

View File

@ -8,6 +8,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@ -23,6 +24,7 @@ from text_generation_server.layers.gptq import GPTQWeightsLoader
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
import habana_frameworks.torch as htorch
def load_multi_mqa(
@ -436,12 +438,19 @@ class FlashSantacoderModel(nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
if self.process_group.size() > 1:
torch.distributed.all_reduce(hidden_states, group=self.process_group)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
@ -452,6 +461,8 @@ class FlashSantacoderModel(nn.Module):
seqlen,
hpu_attention_meta,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.ln_f(hidden_states, residual)

View File

@ -29,6 +29,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@ -50,6 +51,7 @@ from text_generation_server.layers.rotary import (
PositionRotaryEmbedding,
)
from text_generation_server.utils.weights import UnquantizedWeight
import habana_frameworks.torch as htorch
class Starcoder2Config(PretrainedConfig):
@ -510,6 +512,10 @@ class Starcoder2Model(torch.nn.Module):
adapter_data,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
@ -517,6 +523,9 @@ class Starcoder2Model(torch.nn.Module):
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
@ -530,6 +539,8 @@ class Starcoder2Model(torch.nn.Module):
adapter_data,
hpu_attention_meta,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states, residual)
@ -578,7 +589,6 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.model(
input_ids,
position_ids,

View File

@ -734,33 +734,20 @@ class Idefics2ForConditionalGeneration(nn.Module):
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
return inputs_embeds
def forward(
def get_vision_embeds(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
pixel_attention_mask: Optional[torch.BoolTensor] = None,
# Unused here
pixel_values: torch.FloatTensor,
pixel_attention_mask: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None:
assert pixel_values is not None
batch_size, num_images, num_channels, height, width = pixel_values.shape
all_states = []
all_pixel_values = pixel_values
all_pixel_mask = pixel_attention_mask
for i in range(batch_size):
pixel_values = all_pixel_values.to(
dtype=self.dtype
) # fp16 compatibility
pixel_values = all_pixel_values.to(dtype=self.dtype) # fp16 compatibility
pixel_values = pixel_values[i : i + 1]
pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
@ -813,9 +800,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
conv_kernel,
stride=patch_size,
).squeeze(1)
patch_attention_mask = torch.eq(
patches_subgrid, (patch_size * patch_size)
)
patch_attention_mask = torch.gt(patches_subgrid, 0)
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(
@ -830,12 +815,36 @@ class Idefics2ForConditionalGeneration(nn.Module):
)
all_states.append(image_hidden_states)
image_hidden_states = torch.stack(all_states, dim=0)
return image_hidden_states.view(-1, image_hidden_states.shape[-1])
def get_inputs_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if vision_embeds is not None:
# When we generate, we don't want to replace the potential image_token_id that we generated by images
# that simply don't exist
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_hidden_states
input_ids, inputs_embeds, vision_embeds
)
return inputs_embeds
def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
):
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,

View File

@ -477,36 +477,19 @@ class Idefics3ForConditionalGeneration(nn.Module):
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
return inputs_embeds
def forward(
def get_vision_embeds(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
pixel_attention_mask: Optional[torch.BoolTensor] = None,
# Unused here
pixel_values: torch.FloatTensor,
pixel_attention_mask: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None:
batch_size, num_images, num_channels, height, width = pixel_values.shape
all_states = []
all_pixel_values = pixel_values
all_pixel_mask = pixel_attention_mask
for i in range(batch_size):
pixel_values = all_pixel_values.to(
dtype=self.dtype
) # fp16 compatibility
pixel_values = all_pixel_values.to(dtype=self.dtype) # fp16 compatibility
pixel_values = pixel_values[i : i + 1]
pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
@ -538,6 +521,7 @@ class Idefics3ForConditionalGeneration(nn.Module):
].contiguous()
patch_size = self.config.vision_config.patch_size
"""
patches_subgrid = pixel_attention_mask.unfold(
dimension=1, size=patch_size, step=patch_size
@ -558,9 +542,7 @@ class Idefics3ForConditionalGeneration(nn.Module):
conv_kernel,
stride=patch_size,
).squeeze(1)
patch_attention_mask = torch.eq(
patches_subgrid, (patch_size * patch_size)
)
patch_attention_mask = torch.gt(patches_subgrid, 0)
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(
@ -576,10 +558,37 @@ class Idefics3ForConditionalGeneration(nn.Module):
all_states.append(image_hidden_states)
image_hidden_states = torch.stack(all_states, dim=0)
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_hidden_states
)
return image_hidden_states.view(-1, image_hidden_states.shape[-1])
def get_inputs_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if vision_embeds is not None:
# When we generate, we don't want to replace the potential image_token_id that we generated by images
# that simply don't exist
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, vision_embeds
)
return inputs_embeds
def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_indices=None,
):
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,

View File

@ -1,467 +0,0 @@
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Llava-NeXT model."""
from typing import List, Optional, Union
import torch
import torch.utils.checkpoint
import numpy as np
from loguru import logger
from transformers.models.llava_next.modeling_llava_next import (
unpad_image,
)
from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration
from transformers.image_processing_utils import select_best_resolution
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
"""
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
Args:
image_size (`tuple`):
The size of the input image in the format (width, height).
grid_pinpoints (`List`):
A list containing possible resolutions. Each item in the list should be a tuple or list
of the form `(height, width)`.
patch_size (`int`):
The size of each image patch.
Returns:
tuple: The shape of the image patch grid in the format (width, height).
"""
if not isinstance(grid_pinpoints, list):
raise ValueError("grid_pinpoints should be a list of tuples or lists")
height, width = select_best_resolution(image_size, grid_pinpoints)
return height // patch_size, width // patch_size
# Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L79
def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
"""
Calculate the number of patches after the preprocessing for images of any resolution.
Args:
image_size (`torch.LongTensor` or `np.ndarray` or `Tuple[int, int]`):
The size of the input image in the format (height, width). ?
grid_pinpoints (`List`):
A list containing possible resolutions. Each item in the list should be a tuple or list
of the form `(height, width)`.
patch_size (`int`):
The size of each image patch.
Returns:
int: the number of patches
"""
if not isinstance(grid_pinpoints, list):
raise TypeError("grid_pinpoints should be a list of tuples or lists")
# ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate
if not isinstance(image_size, (list, tuple)):
if not isinstance(image_size, (torch.Tensor, np.ndarray)):
raise TypeError(
f"image_size invalid type {type(image_size)} with value {image_size}"
)
image_size = image_size.tolist()
best_resolution = select_best_resolution(image_size, grid_pinpoints)
height, width = best_resolution
num_patches = 0
# consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1
for i in range(0, height, patch_size):
for j in range(0, width, patch_size):
num_patches += 1
# add the base patch
num_patches += 1
return num_patches
class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[int] = None,
vision_feature_select_strategy: Optional[str] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
token_idx: Optional[torch.Tensor] = None,
use_flash_attention: Optional[bool] = True,
flash_attention_recompute: Optional[bool] = True,
):
if token_idx is not None:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
token_idx=token_idx,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
)
logits = outputs[0]
if not return_dict:
output = (logits,) + outputs[1:]
return output
return outputs
# Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L411
def pack_image_features(
self,
image_features,
image_sizes,
vision_feature_select_strategy,
image_newline=None,
):
"""
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
Args:
image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
List of image feature tensor, each contains all the visual feature of all patches.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
Actual image size of each images (H, W).
vision_feature_select_strategy (`str`)
The feature selection strategy used to select the vision feature from the vision backbone.
image_newline (`torch.Tensor` of shape `(embed_dim)`)
New line embedding vector.
Returns:
image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
feature_lens (`List[int]`)
token length of each image in image_features
"""
new_image_features = []
feature_lens = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = (
self.config.vision_config.image_size
// self.config.vision_config.patch_size
)
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
if (
np.prod(image_feature.shape)
% (num_patch_height * num_patch_width * height * width)
!= 0
and vision_feature_select_strategy == "default"
):
logger.warning_once(
"Image feature shape does not line up with the provided patch size. "
"You may be using the `default` vision_feature_select_strategy with a"
" visual encoder that does not have CLS."
)
image_feature = image_feature.view(
num_patch_height, num_patch_width, height, width, -1
)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature, image_sizes[image_idx])
if image_newline is not None:
image_feature = torch.cat(
(
image_feature,
image_newline[:, None, None]
.expand(*image_feature.shape[:-1], 1)
.to(image_feature.device, image_feature.dtype),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
else:
image_feature = image_feature[0]
if image_newline is not None:
image_feature = torch.cat(
(image_feature, image_newline[None].to(image_feature)), dim=0
)
new_image_features.append(image_feature)
feature_lens.append(image_feature.size(0))
image_features = torch.cat(new_image_features, dim=0)
feature_lens = torch.tensor(
feature_lens, dtype=torch.long, device=image_features.device
)
return image_features, feature_lens
# Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L479
def get_image_features(
self,
pixel_values: torch.FloatTensor,
image_sizes: torch.Tensor,
vision_feature_layer: Union[int, List[int]],
vision_feature_select_strategy: str,
):
"""
Obtains image last hidden states from the vision tower and apply multimodal projection.
Args:
pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_patches, channels, height, width)`)
The tensors corresponding to the input images.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
Actual image size of each images (H, W).
vision_feature_layer (`Union[int, List[int]]`):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`
Returns:
image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches
and are of shape `(num_patches, image_length, embed_dim)`).
"""
# ! infer image_num_patches from image_sizes
image_num_patches = [
image_size_to_num_patches(
image_size=imsize,
grid_pinpoints=self.config.image_grid_pinpoints,
patch_size=self.config.vision_config.image_size,
)
for imsize in image_sizes
]
if pixel_values.dim() == 5:
# stacked if input is (batch_size, num_patches, num_channels, height, width)
_pixel_values_list = [
pix_val[:num_patch]
for pix_val, num_patch in zip(pixel_values, image_num_patches)
]
pixel_values = torch.cat(_pixel_values_list, dim=0)
elif pixel_values.dim() != 4:
# otherwise has to be stacked from list of (num_patches, num_channels, height, width)
raise ValueError(
f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions"
)
image_features = self.vision_tower(pixel_values, output_hidden_states=True)
# If we have one vision feature layer, return the corresponding hidden states,
# otherwise, select the hidden states of each feature layer and concatenate them
if isinstance(vision_feature_layer, int):
selected_image_feature = image_features.hidden_states[vision_feature_layer]
else:
hs_pool = [
image_features.hidden_states[layer_idx]
for layer_idx in vision_feature_layer
]
selected_image_feature = torch.cat(hs_pool, dim=-1)
if vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
image_features = self.multi_modal_projector(selected_image_feature)
image_features = torch.split(image_features, image_num_patches, dim=0)
return image_features
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
inputs_embeds=None,
pixel_values=None,
image_sizes=None,
attention_mask=None,
**kwargs,
):
"""
Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L635
The only differences are:
- add new args token_idx
- add the process of merging images into inputs_embeds
"""
token_idx = kwargs.get("token_idx", None)
if token_idx is None:
return super().prepare_inputs_for_generation(
input_ids=input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
pixel_values=pixel_values,
image_sizes=image_sizes,
attention_mask=attention_mask,
**kwargs,
)
else:
use_flash_attention = kwargs.get("use_flash_attention", True)
flash_attention_recompute = kwargs.get("flash_attention_recompute", True)
position_ids = kwargs.get("position_ids", None)
labels = kwargs.get("labels", None)
if (
past_key_values is None
and pixel_values is not None
and input_ids.shape[1] != 1
):
vision_feature_select_strategy = kwargs.get(
"vision_feature_select_strategy", None
)
vision_feature_layer = kwargs.get("vision_feature_layer", None)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
vision_feature_layer = (
vision_feature_layer
if vision_feature_layer is not None
else self.config.vision_feature_layer
)
# 1. Extract the input embeddings
inputs_embeds = self.get_input_embeddings()(input_ids)
# 2. Merge text and images
image_features = self.get_image_features(
pixel_values,
image_sizes,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
)
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
image_features, feature_lens = self.pack_image_features(
image_features,
image_sizes,
vision_feature_select_strategy=vision_feature_select_strategy,
image_newline=self.image_newline,
)
special_image_mask = (
input_ids == self.config.image_token_index
).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
inputs_embeds.device
)
if inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_index).sum()
n_image_features = image_features.shape[0]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(
inputs_embeds.device, inputs_embeds.dtype
)
inputs_embeds = inputs_embeds.masked_scatter(
special_image_mask, image_features
)
# In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
# generation with cache
elif past_key_values is not None:
seq_len = input_ids.shape[1]
pad_len = seq_len - token_idx.item()
input_ids = torch.index_select(input_ids, 1, token_idx - 1)
# Retrieve the first layer to inspect the logits and mask out the hidden states
# that are set to 0
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
batch_index, non_attended_tokens = torch.where(
first_layer_past_key_value.float().sum(-2) == 0
)
# Get the target length
past_length = first_layer_past_key_value.shape[-1]
extended_attention_mask = torch.ones(
(attention_mask.shape[0], past_length),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
# Filter out only the tokens that can be un-attended, this can happen
# if one uses Llava + Fused modules where the cache on the
# first iteration is already big enough, or if one passes custom cache
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]
# Zero-out the places where we don't need to attend
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
attention_mask = extended_attention_mask
attention_mask[:, -pad_len:] = 0
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
if token_idx is not None:
position_ids = (
torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
)
else:
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"token_idx": token_idx,
"labels": labels,
"use_flash_attention": use_flash_attention,
"flash_attention_recompute": flash_attention_recompute,
}
)
return model_inputs

View File

@ -1,292 +0,0 @@
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Mllama model."""
from typing import Optional, Tuple, List, Union
import torch
import torch.utils.checkpoint
from optimum.habana.transformers.models import GaudiMllamaForConditionalGeneration
from optimum.habana.transformers.models.mllama.modeling_mllama import (
_prepare_cross_attention_mask,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
class MllamaForConditionalGeneration(GaudiMllamaForConditionalGeneration):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
aspect_ratio_mask: Optional[torch.Tensor] = None,
aspect_ratio_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_mask: Optional[torch.Tensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
token_idx: Optional[torch.Tensor] = None,
use_flash_attention: Optional[bool] = True,
flash_attention_recompute: Optional[bool] = True,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
"""
Copied from MllamaForConditionalGeneration::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2077
The only differences are:
- add token_idx input
- add use_flash_attention and flash_attention_recompute
"""
full_text_row_masked_out_mask = kwargs.get(
"full_text_row_masked_out_mask", None
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
outputs = self.language_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
cross_attention_states=cross_attention_states,
cross_attention_mask=cross_attention_mask,
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
past_key_values=past_key_values,
use_cache=use_cache,
inputs_embeds=inputs_embeds,
labels=labels,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
return_dict=return_dict,
cache_position=cache_position,
num_logits_to_keep=num_logits_to_keep,
token_idx=token_idx,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
)
logits = outputs[0]
if not return_dict:
output = (logits,) + outputs[1:]
return output
return outputs
def prepare_inputs_for_generation(
self,
input_ids=None,
inputs_embeds=None,
attention_mask=None,
position_ids=None,
pixel_values=None,
aspect_ratio_ids=None,
aspect_ratio_mask=None,
cross_attention_mask=None,
past_key_values=None,
use_cache=False,
cache_position=None,
num_logits_to_keep=None,
**kwargs,
):
"""
Copied from MllamaForConditionalGeneration::prepare_inputs_for_generation: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2208
The only differences are:
- add token_idx handling
- add bucket_internal handling
- add use_flash_attention and flash_attention_recompute
"""
token_idx = kwargs.get("token_idx", None)
if token_idx is None:
return super().prepare_inputs_for_generation(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
pixel_values=pixel_values,
aspect_ratio_ids=aspect_ratio_ids,
aspect_ratio_mask=aspect_ratio_mask,
cross_attention_mask=cross_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
else:
use_flash_attention = kwargs.get("use_flash_attention", True)
flash_attention_recompute = kwargs.get("flash_attention_recompute", True)
position_ids = kwargs.get("position_ids", None)
output_attentions = kwargs.get("output_attentions", None)
output_hidden_states = kwargs.get("output_hidden_states", None)
return_dict = kwargs.get("return_dict", None)
labels = kwargs.get("labels", None)
cross_attention_states = kwargs.get("cross_attention_states", None)
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
bucket_internal = kwargs.get("bucket_internal", None)
if past_key_values is not None:
if token_idx is not None:
input_ids = torch.index_select(input_ids, 1, token_idx - 1)
elif inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0] :]
elif (
input_ids.shape[1] != cache_position.shape[0]
): # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
elif bucket_internal and token_idx is not None:
# for the 1st token we can slice the inputs till token idx for the fwd pass.
input_ids = input_ids[:, :token_idx]
attention_mask = attention_mask[:, :token_idx]
if cross_attention_mask is not None:
cross_attention_mask = cross_attention_mask[:, :token_idx, ...]
# TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
if token_idx is not None:
position_ids = torch.index_select(
position_ids, 1, token_idx - 1
)
else:
position_ids = position_ids[:, -input_ids.shape[1] :]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids.clone(
memory_format=torch.contiguous_format
)
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if pixel_values is not None and cross_attention_states is not None:
raise ValueError(
"`pixel_values` and `cross_attention_states` cannot be provided simultaneously"
)
if pixel_values is not None:
if aspect_ratio_ids is None:
raise ValueError(
"`aspect_ratio_ids` must be provided if `pixel_values` is provided"
)
# get vision tokens from vision model
vision_outputs = self.vision_model(
pixel_values=pixel_values,
aspect_ratio_ids=aspect_ratio_ids,
aspect_ratio_mask=aspect_ratio_mask,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
return_dict=return_dict,
use_flash_attention=use_flash_attention,
)
cross_attention_states = vision_outputs[0]
cross_attention_states = self.multi_modal_projector(
cross_attention_states
).reshape(-1, cross_attention_states.shape[-2], self.hidden_size)
if cross_attention_mask is not None:
cross_attention_mask, full_text_row_masked_out_mask = (
_prepare_cross_attention_mask(
cross_attention_mask,
num_vision_tokens=self.vision_model.num_patches,
dtype=self.dtype,
token_idx=token_idx,
)
)
else:
full_text_row_masked_out_mask = None
if cross_attention_mask is not None:
if cache_position is not None:
cross_attention_mask = cross_attention_mask[:, :, cache_position]
full_text_row_masked_out_mask = full_text_row_masked_out_mask[
:, :, cache_position
]
elif past_key_values is not None:
if token_idx is not None:
cross_attention_mask = torch.index_select(
cross_attention_mask, -2, token_idx - 1
)
full_text_row_masked_out_mask = torch.index_select(
full_text_row_masked_out_mask, -2, token_idx - 1
)
else:
cross_attention_mask = cross_attention_mask[:, :, -1:]
full_text_row_masked_out_mask = full_text_row_masked_out_mask[
:, :, -1:
]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
# The clone here is for the same reason as for `position_ids`.
model_inputs = {
"input_ids": input_ids.clone(memory_format=torch.contiguous_format),
"inputs_embeds": None,
}
if num_logits_to_keep is not None:
model_inputs["num_logits_to_keep"] = num_logits_to_keep
# keep cache_position implementation as None for HPU
cache_position = None
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"token_idx": token_idx,
"labels": labels,
"return_dict": kwargs.get("return_dict"),
"full_text_row_masked_out_mask": full_text_row_masked_out_mask,
"use_flash_attention": use_flash_attention,
"cross_attention_mask": cross_attention_mask,
"cross_attention_states": cross_attention_states,
"output_attentions": output_attentions,
"flash_attention_recompute": flash_attention_recompute,
}
)
return model_inputs

View File

@ -45,11 +45,17 @@ from text_generation_server.layers.attention import (
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
Qwen2Model,
)
from habana_frameworks.torch.hpex.kernels import (
RotaryPosEmbeddingMode,
apply_rotary_pos_emb,
)
import habana_frameworks.torch as htorch
# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py
from typing import Union
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput, VideoInput
from transformers.image_utils import ImageInput
from transformers.video_utils import VideoInput
from transformers.processing_utils import (
ProcessingKwargs,
ProcessorMixin,
@ -375,28 +381,6 @@ class Qwen2_5_VLConfig(PretrainedConfig):
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb_vision(
tensor: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
orig_dtype = tensor.dtype
tensor = tensor.float()
cos = freqs.cos()
sin = freqs.sin()
cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
output = (tensor * cos) + (rotate_half(tensor) * sin)
output = output.to(orig_dtype)
return output
class Qwen2_5VLAttention(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
@ -426,7 +410,8 @@ class Qwen2_5VLAttention(nn.Module):
self,
hidden_state: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
max_seqlen: int,
) -> torch.Tensor:
# apply the qkv linear layer to the hidden state
@ -444,29 +429,37 @@ class Qwen2_5VLAttention(nn.Module):
query = query.view(*_shape)
key = key.view(*_shape)
value = value.view(*_shape)
# apply rotary positional embeddings
query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze(
0
)
key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)
rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
rotary_dim = cos.shape[-1]
query_rot = query[..., :rotary_dim]
query_pass = query[..., rotary_dim:]
query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query.shape))
# calc maximum sequence length for any batch
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
causal = False
key_rot = key[..., :rotary_dim]
key_pass = key[..., rotary_dim:]
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape))
# execute sdpa
query = query.unsqueeze(0).transpose(1, 2)
key = key.unsqueeze(0).transpose(1, 2)
value = value.unsqueeze(0).transpose(1, 2)
causal = False
query = query.transpose(0, 1)
key = key.transpose(0, 1)
value = value.transpose(0, 1)
fsdpa_op = ModuleFusedSDPA(FusedSDPA)
attention_mask = torch.zeros(
[1, max_seqlen, max_seqlen], device=query.device, dtype=torch.bool
)
for i in range(1, len(cu_seqlens)):
attention_mask[
:, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]
] = True
attn_output = fsdpa_op(
query,
key,
value,
attn_mask=None,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=causal,
scale=None,
@ -474,7 +467,7 @@ class Qwen2_5VLAttention(nn.Module):
recompute_mode=None,
valid_sequence_lengths=None,
)
attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
attn_output = attn_output.transpose(0, 1)
# reshape output to original dimensions
attn_output = attn_output.reshape(hidden_state.shape[0], -1)
@ -533,11 +526,9 @@ class Qwen2_5VLVisionBlock(nn.Module):
weights=weights,
)
def forward(
self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
) -> torch.Tensor:
def forward(self, hidden_states, cu_seqlens, cos, sin, max_seqlen) -> torch.Tensor:
norm1_out, _ = self.norm1(hidden_states)
attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen)
attn_out = self.attn(norm1_out, cu_seqlens, cos, sin, max_seqlen)
hidden_states = hidden_states + attn_out
norm2_out, _ = self.norm2(hidden_states)
mlp_out = self.mlp(norm2_out)
@ -608,7 +599,7 @@ class Qwen2_5VisionModel(nn.Module):
config=config,
weights=weights,
)
# import ipdb; ipdb.set_trace()
self.temporal_patch_size = config.temporal_patch_size
self.spatial_patch_size = config.spatial_patch_size
self.in_channels = config.in_channels
@ -736,6 +727,10 @@ class Qwen2_5VisionModel(nn.Module):
)
rotary_pos_emb = rotary_pos_emb.to(device=hidden_states.device)
cos = rotary_pos_emb.cos()
sin = rotary_pos_emb.sin()
cos = torch.cat((cos, cos), dim=-1).unsqueeze(1)
sin = torch.cat((sin, sin), dim=-1).unsqueeze(1)
cu_window_seqlens = torch.tensor(
cu_window_seqlens,
@ -754,6 +749,9 @@ class Qwen2_5VisionModel(nn.Module):
max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
# iterately apply the blocks to the hidden states
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for layer_num, block in enumerate(self.blocks):
# NOTE: qwen2_5_vl.py has a concept of full attention blocks
# that are applied at specific layers.
@ -762,9 +760,9 @@ class Qwen2_5VisionModel(nn.Module):
else:
cu_seqlens_now = cu_window_seqlens
hidden_states = block(
hidden_states, cu_seqlens_now, rotary_pos_emb, max_seqlen
)
hidden_states = block(hidden_states, cu_seqlens_now, cos, sin, max_seqlen)
if lazy_mode:
htorch.core.mark_step()
# apply the final patch merger to the hidden states
hidden_states = self.merger(hidden_states)
@ -886,9 +884,6 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
full_llm_pos_ids_list = [
item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
]
# import ipdb
# ipdb.set_trace()
max_s = full_llm_pos_ids_list[-1].max() + 1
final_text_len = input_ids_len - vision_ends[-1]
if final_text_len > 0:
@ -900,9 +895,33 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
)
return position_ids
def forward(
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
pixel_attention_mask: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
return image_embeds
def get_inputs_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
):
inputs_embeds = self.embed_tokens(input_ids)
# apply the visual model to the pixel values if they are provided
if vision_embeds is not None:
mask = torch.where(input_ids == self.image_token_id)
inputs_embeds[mask] = vision_embeds
return inputs_embeds
def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@ -910,26 +929,10 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor],
pixel_values: torch.FloatTensor = None,
image_grid_thw: Optional[torch.LongTensor] = None,
# Unused in this model
video_grid_thw: Optional[torch.LongTensor] = None,
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
):
inputs_embeds = self.embed_tokens(input_ids)
# apply the visual model to the pixel values if they are provided
if pixel_values is not None and len(pixel_values) > 0:
if pixel_values is not None:
image_embeds = self.visual(
pixel_values, grid_thw=image_grid_thw
).squeeze(0)
mask = torch.where(input_ids == self.image_token_id)
inputs_embeds[mask] = image_embeds
hidden_states = self.text_model(
inputs_embeds=inputs_embeds,

View File

@ -44,28 +44,11 @@ from text_generation_server.layers.attention import (
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
Qwen2Model,
)
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb_vision(
tensor: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
orig_dtype = tensor.dtype
tensor = tensor.float()
cos = freqs.cos()
sin = freqs.sin()
cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
output = (tensor * cos) + (rotate_half(tensor) * sin)
output = output.to(orig_dtype)
return output
from habana_frameworks.torch.hpex.kernels import (
RotaryPosEmbeddingMode,
apply_rotary_pos_emb,
)
import habana_frameworks.torch as htorch
class Qwen2VLAttention(nn.Module):
@ -96,7 +79,8 @@ class Qwen2VLAttention(nn.Module):
self,
hidden_state: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
max_seqlen: int,
) -> torch.Tensor:
# apply the qkv linear layer to the hidden state
@ -116,27 +100,36 @@ class Qwen2VLAttention(nn.Module):
value = value.view(*_shape)
# apply rotary positional embeddings
query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze(
0
)
key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)
rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
rotary_dim = cos.shape[-1]
query_rot = query[..., :rotary_dim]
query_pass = query[..., rotary_dim:]
query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query.shape))
# calc maximum sequence length for any batch
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
causal = False
key_rot = key[..., :rotary_dim]
key_pass = key[..., rotary_dim:]
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape))
# execute sdpa
query = query.unsqueeze(0).transpose(1, 2)
key = key.unsqueeze(0).transpose(1, 2)
value = value.unsqueeze(0).transpose(1, 2)
causal = False
query = query.transpose(0, 1)
key = key.transpose(0, 1)
value = value.transpose(0, 1)
fsdpa_op = ModuleFusedSDPA(FusedSDPA)
attention_mask = torch.zeros(
[1, max_seqlen, max_seqlen], device=query.device, dtype=torch.bool
)
for i in range(1, len(cu_seqlens)):
attention_mask[
:, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]
] = True
attn_output = fsdpa_op(
query,
key,
value,
attn_mask=None,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=causal,
scale=None,
@ -144,7 +137,7 @@ class Qwen2VLAttention(nn.Module):
recompute_mode=None,
valid_sequence_lengths=None,
)
attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
attn_output = attn_output.transpose(0, 1)
# reshape output to original dimensions
attn_output = attn_output.reshape(hidden_state.shape[0], -1)
attn_output = self.proj(attn_output)
@ -193,11 +186,9 @@ class Qwen2VLVisionBlock(nn.Module):
weights=weights,
)
def forward(
self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
) -> torch.Tensor:
def forward(self, hidden_states, cu_seqlens, cos, sin, max_seqlen) -> torch.Tensor:
norm1_out, residual = self.norm1(hidden_states)
attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen)
attn_out = self.attn(norm1_out, cu_seqlens, cos, sin, max_seqlen)
hidden_states = attn_out + residual
norm2_out, residual = self.norm2(hidden_states)
hidden_states = hidden_states + self.mlp(norm2_out)
@ -330,6 +321,11 @@ class Qwen2VisionModel(nn.Module):
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, hidden_states.dtype)
cos = rotary_pos_emb.cos()
sin = rotary_pos_emb.sin()
cos = torch.cat((cos, cos), dim=-1).unsqueeze(1)
sin = torch.cat((sin, sin), dim=-1).unsqueeze(1)
# create a cu_seqlens tensor to be used in the attention mask
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
@ -337,8 +333,13 @@ class Qwen2VisionModel(nn.Module):
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
# iterately apply the blocks to the hidden states
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for block in self.blocks:
hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen)
hidden_states = block(hidden_states, cu_seqlens, cos, sin, max_seqlen)
if lazy_mode:
htorch.core.mark_step()
# apply the final patch merger to the hidden states
hidden_states = self.merger(hidden_states)
@ -474,9 +475,33 @@ class Qwen2VLForConditionalGeneration(nn.Module):
)
return position_ids
def forward(
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
pixel_attention_mask: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
return image_embeds
def get_inputs_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
):
inputs_embeds = self.embed_tokens(input_ids)
# apply the visual model to the pixel values if they are provided
if vision_embeds is not None:
mask = torch.where(input_ids == self.image_token_id)
inputs_embeds[mask] = vision_embeds
return inputs_embeds
def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@ -484,26 +509,10 @@ class Qwen2VLForConditionalGeneration(nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor],
pixel_values: torch.FloatTensor = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
):
inputs_embeds = self.embed_tokens(input_ids)
# apply the visual model to the pixel values if they are provided
if pixel_values is not None and len(pixel_values) > 0:
if pixel_values is not None:
image_embeds = self.visual(
pixel_values, grid_thw=image_grid_thw
).squeeze(0)
mask = torch.where(input_ids == self.image_token_id)
inputs_embeds[mask] = image_embeds
hidden_states = self.text_model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,

View File

@ -53,6 +53,7 @@ from text_generation_server.models.globals import (
)
from text_generation_server.layers.attention import (
KVCache,
KVCompressCache,
Seqlen,
HPUPagedAttentionMetadata,
trim_attn_metadata,
@ -68,11 +69,14 @@ from text_generation_server.utils.import_utils import (
synchronize,
get_free_memory,
)
from text_generation_server.utils.prefill_chunking import (
get_max_prefill_tokens,
)
import vllm_hpu_extension.environment as environment
import habana_frameworks.torch as htorch
import itertools
from vllm_hpu_extension.bucketing import HPUBucketingContext
from vllm_hpu_extension.bucketing.common import get_bucketing_context
from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
tracer = trace.get_tracer(__name__)
@ -149,19 +153,14 @@ def prepare_for_decode(
block_list_device = _async_h2d_tensor_copy(block_list)
block_groups_device = _async_h2d_tensor_copy(block_groups)
block_usage_device = _async_h2d_tensor_copy(block_usage)
block_mapping = torch.nn.functional.one_hot(
block_groups_device, num_classes=batch_size
)
mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
mask = mask >= block_usage.unsqueeze(-1)
attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
return trim_attn_metadata(
HPUPagedAttentionMetadata(
block_list=block_list_device,
block_groups=block_groups_device,
block_usage=block_usage_device,
block_mapping=block_mapping.to(dtype),
attn_bias=attn_bias,
block_mapping=None,
attn_bias=None,
)
)
@ -424,7 +423,7 @@ class FlashCausalLMBatch(Batch):
for i, input_ids in enumerate(all_input_ids):
all_input_ids_tensor[i, : len(input_ids)] = input_ids
# Create tensors on device
# put on cpu temporarily, move to hpu in prepare_for_prefill
all_input_ids_tensor = torch.tensor(all_input_ids_tensor, dtype=torch.int64)
top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64)
@ -628,14 +627,16 @@ class FlashCausalLMBatch(Batch):
# Index into tensors
input_ids = self.input_ids[indices]
position_ids = self.position_ids[indices]
adapter_indices = self.adapter_meta.adapter_indices[indices]
input_lengths_tensor = self.input_lengths_tensor[indices]
cache_lengths_tensor = self.cache_lengths_tensor[indices]
# Move to GPU now that we have the whole tensor
slot_indices = slot_indices.to(device)
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
if self.adapter_meta is not None:
adapter_indices = self.adapter_meta.adapter_indices[indices]
adapter_segments, adapter_segment_indices = find_segments(
adapter_indices
)
adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
adapter_meta = AdapterBatchMetadata(
adapter_indices=adapter_indices,
@ -643,6 +644,8 @@ class FlashCausalLMBatch(Batch):
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
)
else:
adapter_meta = None
htorch.core.mark_step()
return type(self)(
batch_id=self.batch_id,
@ -691,7 +694,9 @@ class FlashCausalLMBatch(Batch):
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
def concatenate(
cls, batches: List["FlashCausalLMBatch"], padded_total_bs: int = 0
) -> "FlashCausalLMBatch":
# Batch attributes
requests = []
requests_idx_mapping = {}
@ -704,6 +709,7 @@ class FlashCausalLMBatch(Batch):
max_length = 0
max_input_length = 0
max_current_length = 0
ADAPTER_TO_INDEX = get_adapter_to_index()
for b in batches:
total_batch_size += len(b)
max_blocks = max(max_blocks, b.max_blocks)
@ -738,6 +744,9 @@ class FlashCausalLMBatch(Batch):
input_lengths_tensor = None
adapter_meta = None
adapter_segment_builder = None
else:
if padded_total_bs == batches[0].input_ids.shape[0]:
input_ids = batches[0].input_ids
else:
input_ids = batches[0].input_ids.new_empty(total_batch_size)
if (
@ -757,6 +766,7 @@ class FlashCausalLMBatch(Batch):
cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty(
total_batch_size
)
if ADAPTER_TO_INDEX:
total_indices_size = sum(
b.adapter_meta.adapter_indices.shape[0] for b in batches
)
@ -772,9 +782,7 @@ class FlashCausalLMBatch(Batch):
block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
(total_batch_size, max_blocks)
)
all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
(total_batch_size, max_length)
)
all_input_ids_tensor = batches[0].all_input_ids_tensor
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size,
)
@ -815,13 +823,14 @@ class FlashCausalLMBatch(Batch):
start_index = cumulative_batch_size
end_index = cumulative_batch_size + valid_bsize
index = torch.tensor(
list(range(start_index, end_index)), device=batch.input_ids.device
)
index = torch.tensor(list(range(start_index, end_index)), device="cpu")
top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor)
all_input_ids_tensor[
start_index:end_index, : batch.all_input_ids_tensor.shape[1]
] = batch.all_input_ids_tensor[:valid_bsize, :max_length]
if i > 0:
all_input_ids_tensor.index_copy_(
0,
index.to(batch.all_input_ids_tensor.device),
batch.all_input_ids_tensor[:valid_bsize, :],
)
block_tables_tensor[
start_index:end_index, : batch.block_tables_tensor.shape[1]
@ -841,7 +850,10 @@ class FlashCausalLMBatch(Batch):
)
if not prefilling:
input_ids.index_copy_(0, index, batch.input_ids[:valid_bsize])
if padded_total_bs != batches[0].input_ids.shape[0] or i > 0:
input_ids.index_copy_(
0, index.to(input_ids.device), batch.input_ids[:valid_bsize]
)
position_ids.index_copy_(0, index, batch.position_ids[:valid_bsize])
slot_indices.index_copy_(
0, index, batch.slot_indices + cumulative_slots
@ -852,6 +864,7 @@ class FlashCausalLMBatch(Batch):
cache_lengths_tensor.index_copy_(
0, index, batch.cache_lengths_tensor[:valid_bsize]
)
if ADAPTER_TO_INDEX:
adapter_start_index = cumulative_adapter_indices_size
adapter_end_index = (
cumulative_adapter_indices_size
@ -908,7 +921,7 @@ class FlashCausalLMBatch(Batch):
else:
speculative_ids = None
if adapter_segment_builder is not None:
if ADAPTER_TO_INDEX and adapter_segment_builder is not None:
adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
adapter_meta = AdapterBatchMetadata(
adapter_indices=adapter_indices,
@ -955,7 +968,7 @@ class FlashCausalLMBatch(Batch):
num_blocks=num_blocks,
max_blocks=max_blocks,
speculative_ids=speculative_ids,
adapter_meta=adapter_meta,
adapter_meta=adapter_meta if ADAPTER_TO_INDEX else None,
hpu_attn_meta=None,
next_token_logits=None,
speculative_logits=None,
@ -974,7 +987,6 @@ class FlashCausalLMBatch(Batch):
else:
padded_bs = self.input_ids.shape[0]
slots = self.slots[self.slot_indices]
extra_pad = padded_bs - self.input_ids.shape[0]
self.hpu_attn_meta = prepare_for_decode(
dtype,
@ -985,17 +997,29 @@ class FlashCausalLMBatch(Batch):
padded_bs,
bucketing_ctx,
)
self.input_ids = F.pad(self.input_ids, (0, extra_pad), value=0)
self.position_ids = F.pad(self.position_ids, (0, extra_pad), value=1)
self.input_ids = F.pad(
self.input_ids, (0, padded_bs - self.input_ids.shape[0]), value=0
)
if self.position_ids.dim() == 2:
# Qwen VL case
self.position_ids = F.pad(
self.position_ids,
(0, 0, 0, padded_bs - self.position_ids.shape[0]),
value=1,
)
else:
self.position_ids = F.pad(
self.position_ids, (0, padded_bs - self.position_ids.shape[0]), value=1
)
self.input_lengths_tensor = F.pad(
self.input_lengths_tensor, (0, extra_pad), value=0
self.input_lengths_tensor,
(0, padded_bs - self.input_lengths_tensor.shape[0]),
value=0,
)
self.cache_lengths_tensor = F.pad(
self.cache_lengths_tensor, (0, extra_pad), value=0
)
self.all_input_ids_tensor = F.pad(
self.all_input_ids_tensor,
(0, 0, 0, extra_pad),
self.cache_lengths_tensor,
(0, padded_bs - self.cache_lengths_tensor.shape[0]),
value=0,
)
next_token_chooser_parameters = []
@ -1015,7 +1039,9 @@ class FlashCausalLMBatch(Batch):
fsm_grammar_states,
)
def prepare_for_prefill(self, max_padded_input_len, max_padded_bs):
def prepare_for_prefill(
self, max_padded_input_len, max_padded_bs, max_total_tokens
):
# Prepare values if we need to continue prefilling
# Speculation must be ignored while we prefill even with chunking
# it simplifies everything
@ -1031,6 +1057,7 @@ class FlashCausalLMBatch(Batch):
# need extra pad to match warmup seq
extra_pad = max_padded_input_len - self.max_input_length
extra_pad_bs = max_padded_bs - len(self)
device = "hpu"
if isinstance(self.input_ids, list) and len(self) > 1:
input_ids_padded_length = []
input_ids = []
@ -1041,15 +1068,26 @@ class FlashCausalLMBatch(Batch):
input_ids.append(input_id)
input_ids_padded_length.append(padded)
input_ids = np.concatenate(input_ids, dtype=np.int64)
self.input_ids = torch.tensor(input_ids, dtype=torch.int64)
self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
elif isinstance(self.input_ids, list):
input_ids = self.input_ids[0]
input_ids_padded_length.append(extra_pad)
input_ids = [0] * extra_pad + input_ids
self.input_ids = torch.tensor(input_ids, dtype=torch.int64)
self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
else:
self.input_ids = F.pad(self.input_ids, (extra_pad, 0), value=0)
input_ids_padded_length.extend([extra_pad] * len(self))
input_ids = self.input_ids.new_zeros(max_padded_input_len * len(self))
src_pos = 0
for i in range(len(self)):
end_pos = (i + 1) * max_padded_input_len
start_pos = end_pos - self.input_lengths[i]
input_ids[start_pos:end_pos] = self.input_ids[
src_pos : src_pos + self.input_lengths[i]
]
input_ids_padded_length.append(
max_padded_input_len - self.input_lengths[i]
)
src_pos += self.input_lengths[i]
self.input_ids = input_ids
self.input_ids = F.pad(
self.input_ids, (0, extra_pad_bs * max_padded_input_len), value=0
@ -1239,7 +1277,9 @@ class FlashCausalLMBatch(Batch):
self.slot_indices = slot_indices
self.prefill_cu_outlens = prefill_cu_outlens
self.prefill_cache_indices = torch.zeros_like(self.input_ids, dtype=torch.bool)
self.prefill_cache_indices = torch.zeros_like(
self.input_ids, dtype=torch.bool, device="cpu"
)
self.prefill_cache_indices[prefill_cache_indices] = True
if all_prefill_logprobs:
@ -1272,12 +1312,17 @@ class FlashCausalLMBatch(Batch):
self.prefill_next_token_indices = (
self.prefill_next_token_indices + input_ids_padded_length_tensor
)
self.all_input_ids_tensor = F.pad(
self.all_input_ids_tensor,
(0, 0, 0, extra_pad_bs),
value=0,
all_input_ids_tensor = torch.zeros(
(max_padded_bs, max(max_total_tokens, self.all_input_ids_tensor.shape[-1])),
dtype=torch.int64,
device="hpu",
)
for i in range(len(self)):
all_input_ids_tensor[i, : self.all_input_ids_tensor.shape[-1]] = (
self.all_input_ids_tensor[i]
)
self.all_input_ids_tensor = all_input_ids_tensor
next_token_chooser_parameters = []
next_token_chooser_parameters.extend([r.parameters for r in self.requests])
pad_next_token_chooser_parameters(next_token_chooser_parameters, max_padded_bs)
@ -1295,9 +1340,12 @@ class FlashCausalLMBatch(Batch):
fsm_grammar_states,
)
if ADAPTER_TO_INDEX:
if adapter_set:
adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64)
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
adapter_segments, adapter_segment_indices = find_segments(
adapter_indices
)
else:
adapter_indices = torch.zeros_like(self.input_ids)
adapter_segments = [0, len(adapter_indices)]
@ -1352,6 +1400,8 @@ class FlashCausalLM(Model):
):
self.quantize = quantize
self.process_group, rank, world_size = initialize_torch_distributed()
if world_size > 1:
self.process_group_cpu = torch.distributed.new_group(backend="gloo")
device = torch.device("hpu")
dtype = torch.bfloat16 if dtype is None else dtype
@ -1419,7 +1469,7 @@ class FlashCausalLM(Model):
raise ValueError("Cannot get the number of key/value heads")
self.num_kv_heads = (
num_kv_heads // self.process_group.size()
if num_kv_heads > 1
if num_kv_heads // self.process_group.size() > 0
else num_kv_heads
)
assert self.num_kv_heads > 0
@ -1427,7 +1477,7 @@ class FlashCausalLM(Model):
if head_size is None:
# Some models use GQA and different sizes for o_proj
# and q_proj, that allows for that.
if hasattr(config, "head_dim"):
if getattr(config, "head_dim", None) is not None:
self.head_size = config.head_dim
else:
self.head_size = config.hidden_size // config.num_attention_heads
@ -1438,15 +1488,20 @@ class FlashCausalLM(Model):
self.kv_cache = []
self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
self.bucketing_ctx = None
self.max_total_tokens = None
self.max_input_tokens = None
htorch.core.hpu_set_env()
if htorch.utils.internal.is_lazy():
htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=True)
environment.set_model_config(self.config)
self.use_contiguous_pa = (
os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true"
)
self.limit_hpu_graphs = (
os.environ.get("LIMIT_HPU_GRAPHS", "false").lower() == "true"
self.limit_hpu_graph = (
os.environ.get("LIMIT_HPU_GRAPH", "false").lower() == "true"
)
self.skip_warmup = os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true"
self.max_seq_len_to_capture = 8192
super().__init__(
model_id=model_id,
model=model,
@ -1478,6 +1533,17 @@ class FlashCausalLM(Model):
):
self.kv_cache = []
empty_cache()
if self.config.model_type == "deepseek_v3":
self.kv_cache = [
KVCompressCache(
num_blocks=num_blocks,
head_size=self.config.kv_lora_rank + self.config.qk_rope_head_dim,
dtype=dtype,
device=device,
)
for _ in range(num_layers)
]
else:
self.kv_cache = [
KVCache(
num_blocks=num_blocks,
@ -1495,16 +1561,48 @@ class FlashCausalLM(Model):
max_input_tokens: Optional[int],
max_total_tokens: Optional[int],
):
if os.environ.get("MAX_BATCH_SIZE") is None:
raise RuntimeError(
"MAX_BATCH_SIZE is not set, it should be set in the launcher "
"using `--max-batch-size xxx`"
)
# The warmup batch is the biggest batch we could ever receive
self.kv_cache = []
empty_cache()
self.graphed_buckets = set()
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
# Calculate the number of blocks that can be allocated with the free memory
dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
if self.config.model_type == "deepseek_v3":
cache_block_size = BLOCK_SIZE * (
self.config.kv_lora_rank + self.config.qk_rope_head_dim
)
else:
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
cache_block_size = cache_block_size * 2
total_cache_size = self.num_layers * cache_block_size * dtype_size
free_memory = get_free_memory(self.device, TGI_WIGGLE_ROOM)
self.mem_reserved = int(free_memory * (1 - MEMORY_FRACTION))
graph_reserved_mem = (
float(os.environ.get("TGI_GRAPH_RESERVED_MEM", "0.1"))
if htorch.utils.internal.is_lazy()
else 0
)
mem_used_from_graph = int(
(free_memory - self.mem_reserved) * graph_reserved_mem
)
log_master(
logger.info,
f"Free memory on device {self.device}: {format_bytes(free_memory)} used_for_graph: {format_bytes(mem_used_from_graph)} ratio {graph_reserved_mem} reserved_for_runtime: {format_bytes(self.mem_reserved)}",
)
if max_total_tokens is None:
max_total_tokens = sum(batch.input_lengths)
if max_input_tokens is None:
max_input_tokens = max_total_tokens - 1
self.max_total_tokens = max_total_tokens
self.max_input_tokens = max_input_tokens
try:
self.init_kv_cache(
batch.num_blocks,
@ -1519,15 +1617,6 @@ class FlashCausalLM(Model):
num_tokens = batch.to_pb().current_tokens
synchronize(self.device)
free_memory = get_free_memory(
self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM
)
real_free_memory = get_free_memory(self.device, MEMORY_FRACTION)
log_master(
logger.debug,
f"Free memory {free_memory / 1e9:.2f}GB , (real: {real_free_memory / 1e9:.2f}GB",
)
_, _batch, _ = self.generate_token([batch])
except Exception:
raise RuntimeError(
@ -1536,8 +1625,9 @@ class FlashCausalLM(Model):
)
synchronize(self.device)
free_memory = get_free_memory(self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM)
kv_memory = free_memory
free_memory = get_free_memory(self.device, TGI_WIGGLE_ROOM)
kv_memory = free_memory - self.mem_reserved - mem_used_from_graph
num_blocks = (
# Leave 5% for some wiggle room
int(kv_memory // total_cache_size)
@ -1546,15 +1636,9 @@ class FlashCausalLM(Model):
)
log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}")
if max_total_tokens is None:
max_total_tokens = sum(batch.input_lengths)
if max_input_tokens is None:
max_input_tokens = max_total_tokens - 1
self.kv_cache = []
empty_cache()
self.init_kv_cache(
num_blocks,
self.num_layers,
@ -1563,56 +1647,177 @@ class FlashCausalLM(Model):
self.kv_cache_dtype,
self.device,
)
max_num_seqs = int(os.getenv("MAX_BATCH_SIZE", 128))
if os.getenv("VLLM_PROMPT_SEQ_BUCKET_MAX") is None:
os.environ["VLLM_PROMPT_SEQ_BUCKET_MAX"] = str(max_input_tokens)
if os.getenv("VLLM_DECODE_BLOCK_BUCKET_MAX") is None:
max_total_blocks = (
math.ceil(max_total_tokens / BLOCK_SIZE) * max_num_seqs + 1
self.max_batch_prefill_tokens = get_max_prefill_tokens()
max_num_seqs = int(os.getenv("MAX_BATCH_SIZE"))
HPUBucketingContext = get_bucketing_context()
# need to warmup one more step since block is allocated from 1
block_step = os.getenv("VLLM_DECODE_BLOCK_BUCKET_STEP", BLOCK_SIZE)
max_total_tokens_aligned = math.ceil(
max_total_tokens / BLOCK_SIZE
) * BLOCK_SIZE + math.ceil(block_step * BLOCK_SIZE / max_num_seqs)
model_max_length = self.tokenizer.model_max_length
max_position_embeddings = getattr(
self.config, "max_position_embeddings", model_max_length
)
os.environ["VLLM_DECODE_BLOCK_BUCKET_MAX"] = str(max_total_blocks)
self.bucketing_ctx = HPUBucketingContext(
max_num_seqs,
os.getenv("PREFILL_MAX_BS", 64), # self.max_num_prefill_seqs, #TODO
max_num_seqs, # self.max_num_prefill_seqs, #TODO
BLOCK_SIZE,
num_blocks * BLOCK_SIZE,
max_num_seqs * max_total_tokens_aligned,
False,
min(model_max_length, max_position_embeddings),
max_input_tokens,
max_total_tokens_aligned,
)
max_blocks = max(
BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned // BLOCK_SIZE
)
self.bucketing_ctx.num_hpu_blocks = min(max_blocks, num_blocks)
synchronize(self.device)
if self.skip_warmup:
self.bucketing_ctx.generate_prompt_buckets()
self.bucketing_ctx.generate_decode_buckets(
self.bucketing_ctx.num_hpu_blocks
)
log_master(
logger.info, "skip warmup hpu graph, not recommmended, may cause OOM"
)
self.bucketing_ctx.num_hpu_blocks = num_blocks
if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true":
logger.info("skip warmup hpu graph, not recommmended")
del _batch, batch
return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
self.warmup_hpu_graph(batch)
del _batch, batch
return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
def log_warmup(self, prefilling, i, max_i, batch_size, seq_len):
free_mem = format_bytes(HabanaMemoryProfiler.current_free_device_memory())
phase = "Prompt" if prefilling else "Decode"
dim = "seq_len" if prefilling else "num_blocks"
graphed_bucket = (batch_size, seq_len, prefilling)
bypass = graphed_bucket not in self.graphed_buckets
msg = (
f"[Warmup][{phase}][{i+1}/{max_i}] "
f"batch_size:{batch_size} "
f"{dim}:{seq_len} "
f"bypass:{bypass} "
f"free_mem:{free_mem}"
)
log_master(logger.info, msg)
def use_graphs(self, prefill, seq_len, batch_size):
if self.limit_hpu_graph and prefill:
return False
if self.skip_warmup:
return True
return (batch_size, seq_len, prefill) in self.graphed_buckets
def align_workers(self, value, op):
if self.world_size <= 1:
return value
value_t = torch.tensor(value, device="cpu")
torch.distributed.all_reduce(value_t, op=op, group=self.process_group_cpu)
return value_t.item()
def warmup_hpu_graph(self, batch):
prompt_graph_mem_ratio = float(os.environ.get("VLLM_GRAPH_PROMPT_RATIO", "0.3"))
free_mem = HabanaMemoryProfiler.current_free_device_memory()
graph_free_mem = free_mem - self.mem_reserved
graph_free_mem = self.align_workers(
graph_free_mem, torch.distributed.ReduceOp.MIN
)
prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem
decode_available_memory = graph_free_mem - prompt_available_memory
msg = (
f"Using {format_bytes(graph_free_mem)}"
f"/{format_bytes(free_mem)} "
"of free device memory for HPUGraphs, "
f"{format_bytes(prompt_available_memory)} for prompt and "
f"{format_bytes(decode_available_memory)} for decode "
f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})"
)
log_master(logger.info, msg)
start_time = time.time()
warmup_shape_count = 0
warmup_times = 3
self.bucketing_ctx.generate_prompt_buckets()
for i, (batch_size, seq_len) in enumerate(
reversed(self.bucketing_ctx.prompt_buckets)
def ordering_function_min_tokens(b):
return (b[0] * b[1], b[1], b[0])
buckets = list(
sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
)
total_batch_seq = 0.001
total_mem = 0
available_mem = prompt_available_memory
for i, (batch_size, seq_len) in enumerate(buckets):
if batch_size * seq_len > self.max_batch_prefill_tokens:
continue
# Graph memory usage is proportional to seq dimension in a batch
batch_seq = batch_size * seq_len
mem_estimate = batch_seq / total_batch_seq * total_mem
graphed_bucket = (batch_size, seq_len, True)
if not (
mem_estimate >= available_mem or batch_seq > self.max_seq_len_to_capture
):
log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
if graphed_bucket not in self.graphed_buckets:
self.graphed_buckets.add(graphed_bucket)
warmup_shape_count += 1
self.log_warmup(True, i, len(buckets), batch_size, seq_len)
with HabanaMemoryProfiler() as mem_prof:
for index in range(warmup_times):
self.warmup_prefill(seq_len, batch_size, batch)
synchronize(self.device)
used_mem = self.align_workers(
mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
)
if graphed_bucket in self.graphed_buckets:
available_mem -= used_mem
total_mem += used_mem
total_batch_seq += batch_seq
def ordering_function_max_bs(b):
return (-b[0], b[1])
self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
for i, (batch_size, block_num) in enumerate(
reversed(self.bucketing_ctx.decode_buckets)
):
buckets = list(
sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
)
free_mem = HabanaMemoryProfiler.current_free_device_memory()
total_batch_seq = 0.001
total_mem = 0
available_mem = free_mem - self.mem_reserved
for i, (batch_size, block_num) in enumerate(buckets):
if batch_size > block_num:
continue
log_master(
logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
)
# Graph memory usage is proportional to seq dimension in a batch
batch_seq = batch_size
mem_estimate = batch_seq / total_batch_seq * total_mem
graphed_bucket = (batch_size, block_num, False)
if not mem_estimate >= available_mem:
if graphed_bucket not in self.graphed_buckets:
self.graphed_buckets.add(graphed_bucket)
warmup_shape_count += 1
self.log_warmup(False, i, len(buckets), batch_size, block_num)
with HabanaMemoryProfiler() as mem_prof:
for index in range(warmup_times):
self.warmup_decode(batch_size, block_num, batch)
synchronize(self.device)
used_mem = self.align_workers(
mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
)
if graphed_bucket in self.graphed_buckets:
available_mem -= used_mem
total_mem += used_mem
total_batch_seq += batch_seq
log_master(
logger.info,
f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
)
def warmup_prefill(
self, prompt_len: int, batch_size: int, batch: FlashCausalLMBatch
@ -1643,7 +1848,9 @@ class FlashCausalLM(Model):
lm_head_indices = input_lengths - 1
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = self.limit_hpu_graphs
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
True, prompt_len, batch_size
)
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward(
@ -1696,7 +1903,9 @@ class FlashCausalLM(Model):
slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = False
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
False, hpu_attention_meta.block_list.shape[0], batch_size
)
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward(
input_ids=_async_h2d_tensor_copy(input_ids),
@ -1779,11 +1988,11 @@ class FlashCausalLM(Model):
# This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s)
if batch.prefill_cache_indices is not None:
slots_pad = torch.zeros_like(input_ids)
slots_pad = torch.zeros_like(input_ids, device=slots.device)
slots_pad[batch.prefill_cache_indices] = slots
slots = slots_pad
else:
slots_pad = torch.zeros_like(input_ids)
slots_pad = torch.zeros_like(input_ids, device=slots.device)
slots_pad[: slots.shape[0]] = slots
slots = slots_pad
seqlen = Seqlen(
@ -1792,12 +2001,18 @@ class FlashCausalLM(Model):
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = (
batch.prefilling if self.limit_hpu_graphs else False
batch_size = input_lengths.shape[0]
prompt_len = (
input_ids.shape[0] // batch_size
if batch.prefilling
else batch.hpu_attn_meta.block_list.shape[0]
)
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
batch.prefilling, prompt_len, batch_size
)
logits, speculative_logits = self.model.forward(
input_ids=_async_h2d_tensor_copy(input_ids),
input_ids=input_ids,
position_ids=_async_h2d_tensor_copy(position_ids),
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
kv_cache=kv_cache,
@ -1836,9 +2051,9 @@ class FlashCausalLM(Model):
accepted_ids,
speculative_ids,
) = batch.next_token_chooser(
_async_h2d_tensor_copy(
batch.all_input_ids_tensor[:, : batch.max_current_length]
),
batch.all_input_ids_tensor[
: batch.next_token_logits.shape[0], : batch.max_current_length
],
batch.next_token_logits,
speculate,
batch.speculative_ids,
@ -1852,15 +2067,29 @@ class FlashCausalLM(Model):
accepted_ids,
)
if batch.valid_indices is not None:
next_input_ids = next_input_ids.cpu()
next_token_logprobs = next_token_logprobs.cpu()
accepted_ids = accepted_ids.cpu()
batch.all_input_ids_tensor = batch.all_input_ids_tensor[
batch.valid_indices
]
next_input_ids = next_input_ids[batch.valid_indices]
next_token_logprobs = next_token_logprobs[batch.valid_indices]
accepted_ids = accepted_ids[batch.valid_indices]
# TODO speculative decoding handling missing
index = torch.arange(
0,
len(batch.valid_indices),
device=batch.all_input_ids_tensor.device,
)
batch.all_input_ids_tensor.index_copy_(
0, index, batch.all_input_ids_tensor[batch.valid_indices]
)
padded_total_bs = self.bucketing_ctx.get_padded_decode_batch_size(
len(batch.valid_indices)
)
next_input_ids.index_copy_(
0, index, next_input_ids[batch.valid_indices]
)
next_input_ids = next_input_ids[:padded_total_bs]
next_token_logprobs.index_copy_(
0, index, next_token_logprobs[batch.valid_indices]
)
accepted_ids.index_copy_(
0, index, accepted_ids[batch.valid_indices]
)
if speculative_ids is not None:
speculative_ids = speculative_ids[batch.valid_indices]
batch.top_n_tokens_tensor = batch.top_n_tokens_tensor[
@ -1894,16 +2123,16 @@ class FlashCausalLM(Model):
batch.position_ids = batch.position_ids[indices]
batch.slot_indices = batch.slot_indices[indices[: len(batch)]]
if batch.adapter_meta is not None:
batch.adapter_meta.adapter_indices = (
batch.adapter_meta.adapter_indices[indices]
)
# For each member of the batch
# Cumulative length
accepted_ids = accepted_ids.cpu()
if batch.speculative_logits is not None:
cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
next_input_ids = next_input_ids.cpu()
if batch.speculative_logits is not None:
for i in range(len(batch)):
batch.all_input_ids_tensor[
i,
@ -1912,20 +2141,8 @@ class FlashCausalLM(Model):
+ batch.input_lengths[i]
+ accepted_ids[i],
] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
else:
index = batch.cache_lengths_tensor + batch.input_lengths_tensor
index = index.to(batch.all_input_ids_tensor)
batch_idx = torch.arange(
0,
batch.all_input_ids_tensor.shape[0],
dtype=torch.long,
device=batch.all_input_ids_tensor.device,
)
batch.all_input_ids_tensor.index_put_(
(batch_idx, index.long()), next_input_ids
)
batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
batch.speculative_ids = speculative_ids
accepted_ids = accepted_ids.cpu()
if batch.position_ids.dim() == 2:
# Qwen2_vl case:
batch.position_ids += accepted_ids.unsqueeze(-1)
@ -1934,11 +2151,37 @@ class FlashCausalLM(Model):
batch.cache_lengths_tensor += (
batch.input_lengths_tensor + accepted_ids - 1
)
batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
batch.input_lengths_tensor = torch.ones_like(
batch.input_lengths_tensor
)
batch.slot_indices += accepted_ids[: len(batch)]
else:
index = batch.cache_lengths_tensor + batch.input_lengths_tensor
index = F.pad(
index, (0, next_input_ids.shape[0] - index.shape[0]), value=0
)
index = index.to(batch.all_input_ids_tensor.device)
batch_idx = torch.arange(
0,
index.shape[0],
dtype=torch.long,
device=batch.all_input_ids_tensor.device,
)
batch.all_input_ids_tensor.index_put_(
(batch_idx, index.long()), next_input_ids
)
batch.input_ids = next_input_ids
batch.position_ids += 1
batch.cache_lengths_tensor += batch.input_lengths_tensor
batch.input_lengths_tensor = torch.ones_like(
batch.input_lengths_tensor
)
batch.slot_indices += 1
batch.speculative_ids = speculative_ids
# Does a HPU <-> CPU sync internally
if prefill:
if prefill and batch.adapter_meta is not None:
# adjust segment lengths to account for all request lengths being 1 during decoding
adapter_segments, _ = find_segments(
batch.adapter_meta.adapter_indices
@ -2008,6 +2251,17 @@ class FlashCausalLM(Model):
htorch.core.mark_step()
# Stage 2. Prepare new batch for speculative scheduling
if len(batches) > 1:
if self.bucketing_ctx is not None:
total_batch_size = 0
for b in batches:
total_batch_size += len(b)
padded_total_bs = self.bucketing_ctx.get_padded_decode_batch_size(
total_batch_size
)
batch = self.batch_type.concatenate(
batches, padded_total_bs=padded_total_bs
)
else:
batch = self.batch_type.concatenate(batches)
else:
batch = batches[0]
@ -2019,16 +2273,22 @@ class FlashCausalLM(Model):
batch.max_input_length
),
self.bucketing_ctx.get_padded_prompt_batch_size(len(batch)),
self.max_total_tokens,
)
else:
batch.prepare_for_prefill(batch.max_input_length, len(batch))
batch.prepare_for_prefill(
batch.max_input_length, len(batch), self.max_total_tokens
)
else:
batch.prepare_for_decode(
self.dtype, self.use_contiguous_pa, self.bucketing_ctx
)
if hasattr(self, "set_inputs_embeds") and callable(self.set_inputs_embeds):
self.set_inputs_embeds(batch)
prefill_logprobs = batch.prefill_next_token_indices is not None
# Update adapter indices for speculative tokens (if present)
adapter_meta = batch.adapter_meta
if adapter_meta is not None:
if batch.speculative_ids is not None:
B, speculative_length = batch.speculative_ids.shape
new_length = speculative_length + 1
@ -2053,6 +2313,8 @@ class FlashCausalLM(Model):
prefill,
batch.prefill_head_indices,
)
else:
adapter_data = None
out, speculative_logits = self.forward(batch, adapter_data)

View File

@ -1,7 +1,7 @@
import torch
from PIL import Image
from io import BytesIO
from dataclasses import dataclass
from opentelemetry import trace
from typing import Iterable, Optional, Tuple, List, Type, Dict
@ -23,9 +23,11 @@ from text_generation_server.layers.attention import (
_async_h2d_tensor_copy,
)
import habana_frameworks.torch as htorch
import time
from text_generation_server.utils.import_utils import (
synchronize,
)
from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
tracer = trace.get_tracer(__name__)
@ -37,6 +39,33 @@ IDEFICS3_FAKE_IMAGE_TOKEN = "<fake_token_around_image>"
IDEFICS3_GLOBAL_IMG_TOKEN = "<global-img>"
def prompt_split_image_llama4(aspect_ratio, num_patches_per_chunk):
"""
Create a structured string representation of image tokens
Args:
num_patches: Number of patches in the image
Returns:
String with appropriate image tokens
"""
img_string = "<|image_start|>"
ratio_h, ratio_w = aspect_ratio
if ratio_h * ratio_w > 1:
for yy in range(ratio_h):
for xx in range(ratio_w):
img_string += "<|patch|>" * num_patches_per_chunk
if xx < ratio_w - 1:
img_string += "<|tile_x_separator|>"
img_string += "<|tile_y_separator|>"
img_string += "<|image|>"
img_string += "<|patch|>" * num_patches_per_chunk
img_string += "<|image_end|>"
return img_string
# copied from: https://github.com/huggingface/transformers/blob/02ed609285c2448b3b54c31e362f2c389fa952ab/src/transformers/models/idefics3/processing_idefics3.py#L44-L60
def _prompt_split_image(
*,
@ -90,17 +119,17 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
return height // patch_size, width // patch_size
def image_text_replacement(processor, image_input, config, image_id: int) -> str:
def image_text_replacement(processor, image_input, config) -> str:
if config.model_type == "idefics2":
image_seq_len = 64
image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}"
if processor.image_processor.do_image_splitting:
image_str *= 5
return image_str
return image_str, IDEFICS2_FAKE_TOKEN
if config.model_type == "idefics3":
# TODO: implement this in a more general way
n_rows = image_input["rows"][0][image_id]
n_cols = image_input["cols"][0][image_id]
n_rows = image_input["rows"][0][0]
n_cols = image_input["cols"][0][0]
image_seq_len = int(
((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
/ (config.scale_factor**2)
@ -113,35 +142,52 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
image_token=IDEFICS3_IMAGE_TOKEN,
global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,
)
return image_str
return image_str, IDEFICS3_FAKE_IMAGE_TOKEN
elif config.model_type == "llava_next":
height, width = image_input["image_sizes"][image_id]
height, width = image_input["image_sizes"][0]
num_features = get_number_of_features(height, width, config)
log_master(
logger.info,
f"Found {num_features} features in image of resolution {height}x{width}",
)
return "<image>" * num_features
return "<image>" * num_features, "<image>"
elif config.model_type == "paligemma":
return "<image>" * config.text_config.num_image_tokens
return "<image>" * config.text_config.num_image_tokens, "<image>"
elif config.model_type == "qwen2_vl":
grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
grid_t, grid_h, grid_w = image_input["image_grid_thw"][0]
num_pads = grid_t * grid_h * grid_w // 4
padding = "<|image_pad|>" * num_pads
return f"<|vision_start|>{padding}<|vision_end|>"
return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
elif config.model_type == "qwen2_5_vl":
grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
grid_t, grid_h, grid_w = image_input["image_grid_thw"][0]
num_pads = grid_t * grid_h * grid_w // 4
padding = "<|image_pad|>" * num_pads
return f"<|vision_start|>{padding}<|vision_end|>"
return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
elif config.model_type == "gemma3":
# TODO: get correct number of features via reviewing the Gemma3 architecture
# and calculating the number of image tokens
num_pads = 256
padding = "<image_soft_token>" * num_pads
return f"\n\n<start_of_image>{padding}<end_of_image>\n\n"
return f"\n\n<start_of_image>{padding}<end_of_image>\n\n", "<start_of_image>"
elif config.model_type == "llama4":
patch_size = config.vision_config.patch_size
pixel_shuffle_ratio = config.vision_config.pixel_shuffle_ratio
downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))
aspect_ratios = image_input["aspect_ratios"][0]
image_height, image_width = image_input["pixel_values"][0].shape[-2:]
num_patches_per_chunk = int(
(image_height // patch_size)
* (image_width // patch_size)
// downsample_ratio
)
tokens_for_this_image = prompt_split_image_llama4(
aspect_ratios, num_patches_per_chunk
)
return tokens_for_this_image, "<|image_start|>"
else:
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
@ -154,6 +200,27 @@ def image_text_replacement_fixup(config, text: str) -> str:
return text
def preprocess_text(config, text: str) -> str:
if config.model_type == "paligemma":
return "<bos>" + text + "\n"
return text
def preprocess_image(config, img):
model_type = config.model_type
if model_type in {"qwen2_vl", "qwen2_5_vl"} and img.width <= 20:
img = img.resize((img.width * 2, img.height * 2))
if model_type == "paligemma":
img = img.convert("RGB")
if model_type not in {"llava_next", "gemma3", "llama4"}:
# TODO: check if this is needed
img = [img]
return img
def get_unpadded_features(
original_height: int,
original_width: int,
@ -208,64 +275,115 @@ def get_number_of_features(height: int, width: int, config) -> int:
return unpadded_features + newline_features + base_features
def scatter_image_embeds(
embeds: torch.Tensor, is_embed: Optional[torch.Tensor]
) -> torch.Tensor:
if is_embed is None:
return embeds
placeholders = embeds.new_full(
(is_embed.shape[0], embeds.shape[-1]),
fill_value=torch.nan,
)
placeholders[is_embed.to(embeds.device)] = embeds
return placeholders
def gather_image_embeds(
embeds: torch.Tensor, is_embed: Optional[torch.Tensor]
) -> Optional[torch.Tensor]:
if is_embed is None:
return embeds
sel = embeds[is_embed.to(embeds.device)]
return sel if sel.numel() else None
@dataclass
class ImagePositions:
offset: int
length: int
id: int
num_placeholder_tokens: int
is_embed: Optional[torch.Tensor] = None
class FlashVlmCausalLMBatch(FlashCausalLMBatch):
image_inputs: Optional[List[List[Dict[str, torch.Tensor]]]]
image_positions: Optional[List[List[ImagePositions]]]
encoder_cache: Optional[List[Dict[int, torch.Tensor]]]
pixel_values: Optional[List[torch.Tensor]]
pixel_attention_mask: Optional[List[torch.Tensor]]
image_sizes: Optional[List[Tuple[int, int]]]
image_grid_thw: Optional[torch.Tensor]
cache_entries_to_free: List[Tuple[int, int]]
has_image_inputs: bool = False
inputs_embeds: Optional[torch.Tensor] = None
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches):
batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches)
def concatenate(cls, batches, padded_total_bs: int = 0):
batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches, padded_total_bs)
batch.image_inputs = []
batch.image_positions = []
batch.encoder_cache = []
for b in batches:
if b.image_inputs is not None:
batch.image_inputs.extend(b.image_inputs)
else:
batch.image_inputs.append(None)
if b.image_positions is not None:
batch.image_positions.extend(b.image_positions)
else:
batch.image_positions.append(None)
if b.encoder_cache is not None:
batch.encoder_cache.extend(b.encoder_cache)
else:
batch.encoder_cache.append(None)
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
batch.image_grid_thw = None
batch.inputs_embeds = None
# To be filled in prepare_for_prefill
batch.has_image_inputs = False
batch.cache_entries_to_free = []
return batch
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]):
if len(request_ids) == 0:
raise ValueError("Batch must have at least one request")
image_inputs = []
image_positions = []
encoder_cache = []
for request_id in request_ids:
idx = self.requests_idx_mapping[request_id]
image_inputs.append(self.image_inputs[idx])
image_positions.append(self.image_positions[idx])
encoder_cache.append(self.encoder_cache[idx])
batch = super().filter(request_ids)
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
batch.image_grid_thw = None
batch.inputs_embeds = None
batch.image_inputs = image_inputs
batch.image_positions = image_positions
batch.encoder_cache = encoder_cache
# To be filled in prepare_for_prefill
batch.has_image_inputs = False
batch.cache_entries_to_free = []
return batch
@classmethod
def batch_tokenized_inputs(
cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
):
# Process images first. We need all of them so that the processor
# can make the image splits the same size. And we need the final
# sizes to insert correct number of image tokens.
images = []
for r in requests:
for chunk in r.input_chunks.chunks:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
pass
elif chunk_type == "image":
image = Image.open(BytesIO(chunk.image.data))
# qwen2_vl expects images to be greater than 20 pixels, this is for warmup since the
# default warmup image is 20x20
if config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
if image.width <= 20:
w = image.width * 2
h = image.height * 2
image = image.resize((w, h))
if config.model_type == "llava_next":
images.append(image)
elif config.model_type == "gemma3":
images.append(image)
else:
images.append([image])
else:
raise RuntimeError(f"Invalid chunk type {chunk_type}")
if images:
kwargs = {}
if (
hasattr(processor, "image_processor_class")
@ -273,38 +391,143 @@ class FlashVlmCausalLMBatch(FlashCausalLMBatch):
):
kwargs["return_row_col_info"] = True
image_inputs = processor.image_processor(
images, return_tensors="pt", **kwargs
)
else:
image_inputs = None
batch_tokenized_inputs = []
max_length = 0
image_id = 0
vocab = tokenizer.get_vocab()
if not hasattr(config, "image_token_index"):
config.image_token_index = config.image_token_id
batch_tokenized_inputs: List[List[int]] = []
batch_image_inputs: List[Optional[List[dict]]] = []
batch_image_positions: List[Optional[List[ImagePositions]]] = []
for r in requests:
full_text = ""
text_parts = []
image_inputs = []
image_texts = []
image_id = 0
for chunk in r.input_chunks.chunks:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
full_text += chunk.text
text = preprocess_text(config, chunk.text)
text_parts.append(text)
elif chunk_type == "image":
full_text += image_text_replacement(
processor, image_inputs, config, image_id
)
image_id += 1
img = Image.open(BytesIO(chunk.image.data))
img = preprocess_image(config, img)
full_text = image_text_replacement_fixup(config, full_text)
image_input = processor.image_processor(
[img], return_tensors="pt", **kwargs
)
image_inputs.append(image_input)
img_text, img_start_token_str = image_text_replacement(
processor, image_input, config
)
text_parts.append(img_text)
image_texts.append([image_id, img_start_token_str, img_text])
image_id += 1
else:
raise RuntimeError(f"Invalid chunk type {chunk_type}")
full_text = image_text_replacement_fixup(config, "".join(text_parts))
input_ids = tokenizer(
full_text,
truncation=True,
max_length=r.truncate,
add_special_tokens=r.add_special_tokens,
add_special_tokens=(
r.add_special_tokens if config.model_type != "paligemma" else False
),
)["input_ids"]
max_length = max(max_length, len(input_ids))
batch_tokenized_inputs.append(input_ids)
return batch_tokenized_inputs, image_inputs
if len(image_inputs) > 0:
img_start_token = vocab[image_texts[0][1]]
image_positions = cls.get_image_positions(
input_ids, image_texts, img_start_token, config, tokenizer
)
else:
image_inputs = None
image_positions = None
batch_tokenized_inputs.append(input_ids)
batch_image_inputs.append(image_inputs)
batch_image_positions.append(image_positions)
return batch_tokenized_inputs, batch_image_inputs, batch_image_positions
@classmethod
def get_image_positions(
cls,
input_ids: List[int],
image_texts: List[Tuple[int, str, str]],
img_start_token: int,
config,
tokenizer: PreTrainedTokenizerBase,
) -> List[ImagePositions]:
image_positions = []
num_images = len(image_texts)
input_ids_t = torch.as_tensor(input_ids)
img_start_token_pos = torch.where(input_ids_t.eq(img_start_token))[0]
num_tokens = input_ids_t.numel()
last_pos = 0
for i in range(num_images):
image_id, img_start_token_str, img_text = image_texts[i]
img_text = image_text_replacement_fixup(config, img_text)
if config.model_type == "gemma3":
img_text = img_text.replace("\n\n", "")
tokens = tokenizer(img_text, add_special_tokens=False, return_tensors="pt")[
"input_ids"
][0]
length = tokens.numel()
assert (
length <= num_tokens
), f"{length} > {num_tokens} Image is truncated, try increasing --max-batch-prefill-tokens"
pos = torch.searchsorted(img_start_token_pos, last_pos, right=False)
index = img_start_token_pos[pos]
assert torch.equal(
input_ids_t[index : index + length], tokens
), "Image tokens not found in input_ids"
is_embed = tokens == config.image_token_index
num_placeholder_tokens = int(is_embed.sum())
if num_placeholder_tokens == length:
is_embed = None
pos = ImagePositions(
offset=index,
length=length,
id=image_id,
num_placeholder_tokens=num_placeholder_tokens,
is_embed=is_embed,
)
image_positions.append(pos)
last_pos = index + length
if (
config.model_type == "idefics2"
and i + 1 != num_images
and input_ids[last_pos] == config.image_token_index
):
fake_token = last_pos - 1
fake_token_index = torch.searchsorted(
img_start_token_pos, fake_token, right=False
)
img_start_token_pos[fake_token_index] = last_pos
image_texts[i + 1][2] = image_texts[i + 1][2][
len(img_start_token_str) :
]
return image_positions
@classmethod
def from_pb_processor(
@ -316,33 +539,164 @@ class FlashVlmCausalLMBatch(FlashCausalLMBatch):
dtype: torch.dtype,
device: torch.device,
) -> "FlashVlmCausalLMBatch":
batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
pb.requests, tokenizer, processor, config
batch_tokenized_inputs, image_inputs, image_positions = (
cls.batch_tokenized_inputs(pb.requests, tokenizer, processor, config)
)
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
if image_inputs is not None:
batch.pixel_values = image_inputs["pixel_values"].to(device=device)
if "pixel_attention_mask" in image_inputs:
batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
device=device
)
else:
batch.pixel_attention_mask = None
if "image_sizes" in image_inputs:
batch.image_sizes = image_inputs["image_sizes"].to(device=device)
else:
batch.image_sizes = None
if "image_grid_thw" in image_inputs:
batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device)
else:
batch.image_grid_thw = None
else:
batch.image_inputs = image_inputs
batch.image_positions = image_positions
batch.encoder_cache = [{} for _ in range(len(pb.requests))]
if len(image_inputs):
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
batch.image_grid_thw = None
return batch
def prepare_for_prefill(
self, max_padded_input_len, max_padded_bs, max_total_tokens
):
super().prepare_for_prefill(
max_padded_input_len, max_padded_bs, max_total_tokens
)
self.has_image_inputs = False
self.cache_entries_to_free = []
self.pixel_values = []
assert (
len(self.cache_lengths)
== len(self.input_lengths)
== len(self.prefilling_mask)
), "Mismatch in lengths of cache_lengths, input_lengths, and prefilling_mask"
for i, (
cache_length,
input_length,
request_prefilling,
) in enumerate(
zip(
self.cache_lengths,
self.input_lengths,
self.prefilling_mask,
)
):
if not request_prefilling or self.image_positions[i] is None:
continue
for image_position in self.image_positions[i]:
if image_position is None:
continue
start_pos = image_position.offset
length = image_position.length
if start_pos >= cache_length + input_length:
# No encoder input required at this step
break
if start_pos + length <= cache_length:
# The encode input is already processed
continue
self.has_image_inputs = True
if image_position.id not in self.encoder_cache[i]:
image_inputs = self.image_inputs[i][image_position.id]
self.pixel_values.append((i, image_position.id, image_inputs))
# Remove the image from the image_inputs
self.image_inputs[i][image_position.id] = None
if not self.has_image_inputs:
self.pixel_values = None
self.pixel_attention_mask = None
self.image_sizes = None
self.image_grid_thw = None
else:
image_grid_thw_list = [
x[2]["image_grid_thw"]
for x in self.pixel_values
if "image_grid_thw" in x[2]
]
if image_grid_thw_list:
self.image_grid_thw = torch.cat(image_grid_thw_list, dim=0)
else:
self.image_grid_thw = None
def update_encoder_cache(self, encoder_outputs, request_id, img_pos):
self.encoder_cache[request_id][img_pos.id] = scatter_image_embeds(
encoder_outputs, img_pos.is_embed
)
def gather_vision_embeds(self):
device = self.input_ids.device
chunks = []
for (
i,
cache_length,
input_length,
request_prefilling,
) in zip(
range(len(self.requests)),
self.cache_lengths,
self.input_lengths,
self.prefilling_mask,
):
if not request_prefilling or self.image_positions[i] is None:
continue
for image_position in self.image_positions[i]:
if image_position is None:
continue
start_pos = image_position.offset
length = image_position.length
if start_pos >= cache_length + input_length:
# No encoder input required at this step
break
if start_pos + length <= cache_length:
# The encode input is already processed
continue
start_idx = max(cache_length - start_pos, 0)
end_idx = min(cache_length - start_pos + input_length, length)
assert (
image_position.id in self.encoder_cache[i]
), f"image_id {image_position.id} not in encoder_cache {self.encoder_cache[i]}"
encoder_output = self.encoder_cache[i][image_position.id]
is_embed = image_position.is_embed
if is_embed is not None:
is_embed = is_embed[start_idx:end_idx]
from loguru import logger
logger.info(
f"image_id {image_position.id} start_idx {start_idx} end_idx {end_idx}, length {length}"
)
embeds = gather_image_embeds(
encoder_output[start_idx:end_idx],
is_embed=is_embed,
)
if embeds is not None:
chunks.append(embeds)
if end_idx == length:
self.cache_entries_to_free.append((i, image_position.id))
self.image_positions[i][image_position.id] = None
if len(chunks) == 0:
return None
return torch.cat(chunks, dim=0).to(device)
def free_encoder_cache(self):
for i, image_id in self.cache_entries_to_free:
self.encoder_cache[i].pop(image_id, None)
self.cache_entries_to_free = []
class FlashVlmCausalLM(FlashCausalLM):
def __init__(
@ -354,6 +708,7 @@ class FlashVlmCausalLM(FlashCausalLM):
batch_class=FlashVlmCausalLMBatch,
revision,
trust_remote_code: bool,
support_chunking: bool = False,
**kwargs,
):
if PREFIX_CACHING:
@ -371,8 +726,7 @@ class FlashVlmCausalLM(FlashCausalLM):
model_id=model_id,
revision=revision,
trust_remote_code=trust_remote_code,
# FIXME: VLM do not work with context chunking yet
support_chunking=False,
support_chunking=support_chunking,
**kwargs,
)
@ -423,9 +777,12 @@ class FlashVlmCausalLM(FlashCausalLM):
bucketing_ctx=None,
)
slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
inputs_embeds = self.get_inputs_embeds(
input_ids=input_ids.to(self.device),
)
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward(
input_ids=_async_h2d_tensor_copy(input_ids),
inputs_embeds=inputs_embeds,
position_ids=_async_h2d_tensor_copy(position_ids),
cu_seqlen_prefill=None,
kv_cache=self.kv_cache,
@ -433,27 +790,145 @@ class FlashVlmCausalLM(FlashCausalLM):
seqlen=trim_seqlen_metadata(seqlen),
hpu_attention_meta=hpu_attention_meta,
lm_head_indices=None,
pixel_values=None,
pixel_attention_mask=None,
image_sizes=None,
image_grid_thw=None,
attention_mask=None,
)
def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch):
free_mem = HabanaMemoryProfiler.current_free_device_memory()
graph_free_mem = free_mem - self.mem_reserved
graph_free_mem = self.align_workers(
graph_free_mem, torch.distributed.ReduceOp.MIN
)
decode_available_memory = graph_free_mem
msg = (
f"Using {format_bytes(graph_free_mem)}"
f"/{format_bytes(free_mem)} "
"of free device memory for HPUGraphs, "
f"{format_bytes(decode_available_memory)} for decode "
)
log_master(logger.info, msg)
start_time = time.time()
warmup_shape_count = 0
warmup_times = 3
# only warmup decode, for prefill, image pixal size may change, make the warmup useless
def ordering_function_max_bs(b):
return (-b[0], b[1])
self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
for i, (batch_size, block_num) in enumerate(
reversed(self.bucketing_ctx.decode_buckets)
):
buckets = list(
sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
)
total_batch_seq = 0.001
total_mem = 0
available_mem = decode_available_memory
for i, (batch_size, block_num) in enumerate(buckets):
if batch_size > block_num:
continue
log_master(
logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
)
# Graph memory usage is proportional to seq dimension in a batch
batch_seq = batch_size
mem_estimate = batch_seq / total_batch_seq * total_mem
graphed_bucket = (batch_size, block_num, False)
if not mem_estimate >= available_mem:
if graphed_bucket not in self.graphed_buckets:
self.graphed_buckets.add(graphed_bucket)
warmup_shape_count += 1
self.log_warmup(False, i, len(buckets), batch_size, block_num)
with HabanaMemoryProfiler() as mem_prof:
for index in range(warmup_times):
self.warmup_decode(batch_size, block_num, batch)
synchronize(self.device)
used_mem = self.align_workers(
mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
)
if graphed_bucket in self.graphed_buckets:
available_mem -= used_mem
total_mem += used_mem
total_batch_seq += batch_seq
log_master(
logger.info,
f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
)
def get_vision_embeds(
self,
pixel_values: torch.Tensor,
pixel_attention_mask: torch.Tensor,
image_sizes: torch.Tensor,
image_grid_thw: torch.Tensor,
):
embeds = self.model.get_vision_embeds(
pixel_values=pixel_values,
pixel_attention_mask=pixel_attention_mask,
image_sizes=image_sizes,
image_grid_thw=image_grid_thw,
)
return embeds
def get_inputs_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: Optional[torch.Tensor] = None,
):
return self.model.get_inputs_embeds(
input_ids=input_ids,
vision_embeds=vision_embeds,
)
def encode_images(self, batch):
if batch.pixel_values is not None:
device = batch.input_ids.device
for request_id, image_id, image_input in batch.pixel_values:
pixel_values = image_input["pixel_values"].to(device)
if "pixel_attention_mask" in image_input:
pixel_attention_mask = image_input["pixel_attention_mask"].to(
device
)
else:
pixel_attention_mask = None
if "image_sizes" in image_input:
image_sizes = image_input["image_sizes"].to(device)
else:
image_sizes = None
if "image_grid_thw" in image_input:
image_grid_thw = image_input["image_grid_thw"]
else:
image_grid_thw = None
encoder_outputs = self.get_vision_embeds(
pixel_values=pixel_values,
pixel_attention_mask=pixel_attention_mask,
image_sizes=image_sizes,
image_grid_thw=image_grid_thw,
)
batch.update_encoder_cache(
encoder_outputs,
request_id,
batch.image_positions[request_id][image_id],
)
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
def set_inputs_embeds(self, batch):
if batch.has_image_inputs:
self.encode_images(batch)
vision_embeds = batch.gather_vision_embeds()
batch.has_image_inputs = False
else:
vision_embeds = None
inputs_embeds = self.get_inputs_embeds(
batch.input_ids, vision_embeds=vision_embeds
)
batch.inputs_embeds = inputs_embeds
def forward(
self,
@ -502,6 +977,7 @@ class FlashVlmCausalLM(FlashCausalLM):
position_ids = new_position_ids
else:
input_ids = batch.input_ids
inputs_embeds = batch.inputs_embeds
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = self.kv_cache
@ -514,10 +990,25 @@ class FlashVlmCausalLM(FlashCausalLM):
if self.model.config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
if position_ids.dim() == 1 and batch.prefilling:
position_ids = self.model.get_position_ids(
input_ids, batch.image_grid_thw
input_ids.cpu(), batch.image_grid_thw
)
batch.position_ids = position_ids
attention_mask = None
attention_mask_forward = None
if self.model.config.model_type == "gemma3" and cu_seqlen_prefill is not None:
attention_mask = self.model.get_attention_mask(
input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True
)
min_dtype = torch.finfo(self.dtype).min
attention_mask_forward = torch.where(attention_mask, 0, min_dtype).to(
input_ids.device
)
attention_mask = attention_mask.reshape(-1)
if self.model.config.model_type == "llama4":
attention_mask = (input_ids != 0).long()
attention_mask_forward = attention_mask.view(input_lengths.shape[0], -1)
if cu_seqlen_prefill is None and self.max_past() is not None:
# In decode, not prefill, we're actually overwriting the KV-cache
# in a circular buffer mode.
@ -526,14 +1017,21 @@ class FlashVlmCausalLM(FlashCausalLM):
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = batch.prefilling
batch_size = input_lengths.shape[0]
seqlen = (
input_ids.shape[0] // batch_size
if batch.prefilling
else batch.hpu_attn_meta.block_list.shape[0]
)
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
batch.prefilling, seqlen, batch_size
)
if batch.prefill_cache_indices is not None:
slots_pad = torch.zeros_like(input_ids)
slots_pad = torch.zeros_like(input_ids, device=slots.device)
slots_pad[batch.prefill_cache_indices] = slots
slots = slots_pad
else:
slots_pad = torch.zeros_like(input_ids)
slots_pad = torch.zeros_like(input_ids, device=slots.device)
slots_pad[: slots.shape[0]] = slots
slots = slots_pad
@ -541,7 +1039,7 @@ class FlashVlmCausalLM(FlashCausalLM):
input_lengths=_async_h2d_tensor_copy(input_lengths),
)
logits, speculative_logits = self.model.forward(
input_ids=_async_h2d_tensor_copy(input_ids),
inputs_embeds=inputs_embeds,
position_ids=_async_h2d_tensor_copy(position_ids),
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
kv_cache=kv_cache,
@ -549,18 +1047,9 @@ class FlashVlmCausalLM(FlashCausalLM):
seqlen=trim_seqlen_metadata(seqlen),
hpu_attention_meta=batch.hpu_attn_meta,
lm_head_indices=_async_h2d_tensor_copy(lm_head_indices),
pixel_values=batch.pixel_values,
pixel_attention_mask=batch.pixel_attention_mask,
image_sizes=batch.image_sizes,
image_grid_thw=batch.image_grid_thw,
attention_mask=attention_mask_forward,
**kwargs,
)
if batch.pixel_values is not None:
batch.pixel_values = None
if batch.pixel_attention_mask is not None:
batch.pixel_attention_mask = None
if batch.image_sizes is not None:
batch.image_sizes = None
if batch.image_grid_thw is not None:
batch.image_grid_thw = None
batch.free_encoder_cache()
return logits, speculative_logits

View File

@ -1,156 +0,0 @@
import re
import torch
import torch.distributed
from transformers import (
PreTrainedTokenizerBase,
)
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
NextTokenChooser,
StoppingCriteria,
)
from text_generation_server.utils.chunks import concat_text_chunks
# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py
# we split individual characters inside special tokens like [START_DNA]
CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])")
# token added to implement a custom sequence tokenization. This token is added at
# corpus cleaning step and removed in pretokenization. The digits are added to increase the chance
# that they do not occur in the corpus. The digits are escaped so that the token does not appear
# literally in the source code in case we ever include it in the training data.
SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E"
def _insert_split_marker(m: re.Match):
"""
Applies split marker based on a regex match of special tokens such as
[START_DNA].
Parameters
----------
n : str
Input text to split
Returns
----------
str - the text with the split token added
"""
start_token, _, sequence, end_token = m.groups()
sequence = re.sub(r"(.)", rf"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}"
def escape_custom_split_sequence(text):
"""
Applies custom splitting to the text for GALILEO's tokenization
Parameters
----------
text : str
Input text to split
Returns
----------
str - the text with the split token added
"""
return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)
# END CREDIT
class GalacticaCausalLMBatch(CausalLMBatch):
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "GalacticaCausalLMBatch":
inputs = []
next_token_choosers = []
stopping_criterias = []
prefix_offsets = []
top_n_tokens = []
read_offsets = []
requests_idx_mapping = {}
# Parse batch
max_truncation = 0
padding_right_offset = 0
max_decode_tokens = 0
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
# Add escape_custom_split_sequence to the CausalLMBatch logic
inputs.append(
escape_custom_split_sequence(concat_text_chunks(r.input_chunks.chunks))
)
next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
max_truncation = max(max_truncation, r.truncate)
max_decode_tokens += stopping_criteria.max_new_tokens
padding_right_offset = max(
padding_right_offset, stopping_criteria.max_new_tokens
)
tokenized_inputs = tokenizer(
inputs,
return_tensors="pt",
padding=True,
return_token_type_ids=False,
truncation=True,
max_length=max_truncation,
).to(device)
for _ in pb.requests:
input_len = tokenized_inputs["input_ids"].shape[1]
prefix_offsets.append(0)
read_offsets.append(input_len)
input_lengths = tokenized_inputs["attention_mask"].sum(1)
max_input_length = input_lengths.max()
input_ids = tokenized_inputs["input_ids"]
# Allocate maximum attention_mask
attention_mask = input_ids.new_zeros(
(pb.size, max_input_length + padding_right_offset)
)
# Copy tokenizer attention_mask into fully allocated attention_mask
attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
max_tokens = len(inputs) * max_input_length + max_decode_tokens
return cls(
batch_id=pb.id,
requests=pb.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=None,
all_input_ids=list(all_input_ids),
input_lengths=input_lengths.tolist(),
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset,
max_tokens=max_tokens,
)

View File

@ -4,14 +4,14 @@ from loguru import logger
from text_generation_server.utils.log import log_master
REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
ATTENTION = os.getenv("ATTENTION", "default")
ATTENTION = os.getenv("ATTENTION", "paged")
# default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
PREFIX_CACHING = os.getenv("PREFIX_CACHING", "0").lower() in {
"1",
"true",
}
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
_expected = {"paged", "default"}
_expected = {"paged"}
assert (
ATTENTION in _expected
), f"Attention is not valid {ATTENTION}, expected {_expected}"

View File

@ -1,882 +0,0 @@
from io import BytesIO
from PIL import Image
import torch
import time
from dataclasses import dataclass
from opentelemetry import trace
from transformers import (
AutoConfig,
AutoProcessor,
AutoTokenizer,
PreTrainedTokenizerBase,
ProcessorMixin,
)
from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.models import Model
from text_generation_server.models.types import (
Batch,
Tokens,
Generation,
GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
import torch.distributed
from text_generation_server.models.custom_modeling.idefics_modeling import (
IdeficsForVisionText2Text,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.quantization import get_loader
tracer = trace.get_tracer(__name__)
@dataclass
class IdeficsCausalLMBatch(Batch):
batch_id: int
requests: List[generate_pb2.Request]
requests_idx_mapping: Dict[int, int]
# Decoder values
input_ids: torch.Tensor
attention_mask: torch.Tensor
position_ids: torch.Tensor
pixel_values: Optional[torch.Tensor]
image_hidden_states: Optional[torch.Tensor]
image_attention_mask: Optional[torch.Tensor]
past_key_values: Optional[List[Tuple]]
# All tokens
all_input_ids: List[torch.Tensor]
# Lengths of all generations present in the batch
input_lengths: List[int]
prefix_offsets: List[int]
read_offsets: List[int]
# Generation helpers
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
# Metadata used for padding
max_input_length: int
padding_right_offset: int
# Maximum number of tokens this batch will grow to
max_tokens: int
# Past metadata
keys_head_dim_last: bool = True
def to_pb(self) -> generate_pb2.CachedBatch:
return generate_pb2.CachedBatch(
id=self.batch_id,
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=self.max_tokens,
)
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "IdeficsCausalLMBatch":
raise NotImplementedError
@classmethod
def from_pb_processor(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
processor: ProcessorMixin, # Hack
config,
dtype: torch.dtype,
device: torch.device,
) -> "IdeficsCausalLMBatch":
inputs = []
next_token_choosers = []
stopping_criterias = []
prefix_offsets = []
read_offsets = []
requests_idx_mapping = {}
# Parse batch
max_truncation = 0
padding_right_offset = 0
max_decode_tokens = 0
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(r.input_chunks.chunks)
next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
max_truncation = max(max_truncation, r.truncate)
max_decode_tokens += stopping_criteria.max_new_tokens
padding_right_offset = max(
padding_right_offset, stopping_criteria.max_new_tokens
)
# TODO Check impact on idefics
prompts = []
for inp in inputs:
# Each input is encoded into a list, where each element of this input list is either a string or a URL
prompt = []
for chunk in inp:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
prompt.append(chunk.text)
elif chunk_type == "image":
image = Image.open(BytesIO(chunk.image.data))
prompt.append(image)
else:
raise RuntimeError(f"Invalid chunk type {chunk_type}")
prompts.append(prompt)
# The processor replaces the call to tokenizer, and
# a/ takes care of fetching images from the URL
# b/ generate the correct input_ids, attention_mask, pixel_values, image_attention_mask to feed to the model
tokenized_inputs = processor(
prompts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=max_truncation,
# TODO Check impact on idefics
# add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token
).to(device)
for _ in pb.requests:
input_len = tokenized_inputs["input_ids"].shape[1]
prefix_offsets.append(
input_len - 5
) # To decode without potential fallbacks errors
read_offsets.append(
input_len
) # To decode without potential fallbacks errors
input_lengths = tokenized_inputs["attention_mask"].sum(1)
max_input_length = input_lengths.max()
input_ids = tokenized_inputs["input_ids"]
pixel_values = tokenized_inputs.get("pixel_values", None)
image_hidden_states = None
# Allocate maximum attention_mask
attention_mask = input_ids.new_zeros(
(pb.size, max_input_length + padding_right_offset)
)
# Copy tokenizer attention_mask into fully allocated attention_mask
attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
# Do the same for image_attention_mask
if pixel_values is None:
image_attention_mask = None
else:
image_attention_mask = input_ids.new_zeros(
(
pb.size,
max_input_length + padding_right_offset,
pixel_values.size(1),
)
)
image_attention_mask[:, :max_input_length, :] = tokenized_inputs[
"image_attention_mask"
]
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].T.split(
1, dim=1
) # It's input_ids but splitted into a tuple of tensors where each tensor is (seq_len, 1) size. It is then transformed into a list
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
return cls(
batch_id=pb.id,
requests=pb.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
pixel_values=pixel_values,
image_hidden_states=image_hidden_states,
image_attention_mask=image_attention_mask,
past_key_values=None,
all_input_ids=list(all_input_ids),
input_lengths=input_lengths.tolist(),
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset,
max_tokens=max_tokens,
)
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]) -> Optional["IdeficsCausalLMBatch"]:
# It deletes requests from the batch. For instance when client lost connection
if len(request_ids) == 0:
raise ValueError("Batch must have at least one request")
if len(request_ids) == len(self):
return self
keep_indices = []
# New values after filtering
requests_idx_mapping = {}
requests = []
input_lengths = []
prefix_offsets = []
read_offsets = []
all_input_ids = []
max_input_length = 0
next_token_choosers = []
stopping_criterias = []
total_remaining_decode_tokens = 0
new_padding_right_offset = 0
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
requests_idx_mapping[request_id] = i
keep_indices.append(idx)
requests.append(self.requests[idx])
prefix_offsets.append(self.prefix_offsets[idx])
read_offsets.append(self.read_offsets[idx])
all_input_ids.append(self.all_input_ids[idx])
request_input_length = self.input_lengths[idx]
input_lengths.append(request_input_length)
max_input_length = max(max_input_length, request_input_length)
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
remaining_decode_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
total_remaining_decode_tokens += remaining_decode_tokens
new_padding_right_offset = max(
new_padding_right_offset, remaining_decode_tokens
)
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
input_ids = self.input_ids[keep_indices]
position_ids = self.position_ids[keep_indices]
self.attention_mask = self.attention_mask[
keep_indices,
-(self.padding_right_offset + max_input_length) : (
self.attention_mask.shape[1] - self.padding_right_offset
)
+ new_padding_right_offset,
]
# Do the same for pixel_values and image_attention_mask
pixel_values = self.pixel_values[keep_indices]
self.image_attention_mask = self.image_attention_mask[
keep_indices,
-(self.padding_right_offset + max_input_length) : (
self.image_attention_mask.shape[1] - self.padding_right_offset
)
+ new_padding_right_offset,
:,
]
if self.image_hidden_states is None:
image_hidden_states = None
else:
image_hidden_states = self.image_hidden_states[keep_indices]
# Ensure that past_key_values tensors can be updated in-place
if type(self.past_key_values[0]) is tuple:
self.past_key_values = [list(layer) for layer in self.past_key_values]
# Update tensors in-place to allow incremental garbage collection
past_kv_length = max_input_length - 1
for layer in self.past_key_values:
past_keys, past_values = layer
if len(past_keys.shape) == 3:
# Force past to be of dim [self_size, num_heads, ...] for easy indexing
past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
if self.keys_head_dim_last:
layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]
else:
layer[0] = past_keys[keep_indices, :, :, -past_kv_length:]
del past_keys
layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
del past_values
max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
self.requests = requests
self.requests_idx_mapping = requests_idx_mapping
self.input_ids = input_ids
self.pixel_values = pixel_values
self.image_hidden_states = image_hidden_states
self.position_ids = position_ids
self.all_input_ids = all_input_ids
self.input_lengths = input_lengths
self.prefix_offsets = prefix_offsets
self.read_offsets = read_offsets
self.next_token_choosers = next_token_choosers
self.stopping_criterias = stopping_criterias
self.max_input_length = max_input_length
self.padding_right_offset = new_padding_right_offset
self.max_tokens = max_tokens
return self
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(
cls, batches: List["IdeficsCausalLMBatch"]
) -> "IdeficsCausalLMBatch":
# It adds new requests to the batch
# Used for padding
total_batch_size = 0
max_input_length = 0
max_num_images = 0
padding_right_offset = 0
for batch in batches:
total_batch_size += len(batch)
max_input_length = max(max_input_length, batch.max_input_length)
max_num_images = max(max_num_images, batch.pixel_values.size(1))
padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
# Batch attributes
requests = []
requests_idx_mapping = {}
input_lengths = []
prefix_offsets = []
read_offsets = []
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
max_tokens = 0
# Batch tensors
input_ids = None
attention_mask = None
position_ids = None
pixel_values = None
image_hidden_states = None
image_attention_mask = None
past_key_values = []
# Used for slicing correctly inside the tensors
# Equivalent to a cumsum on batch sizes
start_index = 0
for i, batch in enumerate(batches):
requests.extend(batch.requests)
input_lengths.extend(batch.input_lengths)
prefix_offsets.extend(batch.prefix_offsets)
read_offsets.extend(batch.read_offsets)
all_input_ids.extend(batch.all_input_ids)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
else:
# We need to offset the mapping for each batch by the cumulative batch size
for k, v in batch.requests_idx_mapping.items():
requests_idx_mapping[k] = v + start_index
# Slicing end index for this batch
end_index = start_index + len(batch)
# We only concatenate batches that did at least one step
if batch.past_key_values is None:
raise ValueError("only concatenate prefilled batches")
# Create empty tensor
# input_ids is always of shape [batch_size, 1]
# We do not need to pad it
if input_ids is None:
input_ids = batch.input_ids.new_empty((total_batch_size, 1))
# Copy to correct indices
input_ids[start_index:end_index] = batch.input_ids
# Create padded tensor
if attention_mask is None:
attention_mask = batch.attention_mask.new_zeros(
(total_batch_size, max_input_length + padding_right_offset),
)
curr_batch_max_num_images = batch.pixel_values.size(1)
if pixel_values is None:
pixel_values = batch.pixel_values.new_zeros(
(total_batch_size, max_num_images, 3, 224, 224)
)
pixel_values[start_index:end_index, :curr_batch_max_num_images] = (
batch.pixel_values
)
if image_attention_mask is None:
image_attention_mask = batch.image_attention_mask.new_zeros(
(
total_batch_size,
max_input_length + padding_right_offset,
max_num_images,
)
)
# We need to slice the attention mask to remove padding from previous steps
# and to remove unused allocated space
left_offset = max_input_length - batch.max_input_length
batch_left_offset = (
batch.attention_mask.shape[1]
- batch.max_input_length
- batch.padding_right_offset
)
attention_mask[
start_index:end_index,
left_offset:-padding_right_offset,
] = batch.attention_mask[
:,
batch_left_offset : -batch.padding_right_offset,
]
image_attention_mask[
start_index:end_index,
left_offset:-padding_right_offset,
:curr_batch_max_num_images,
] = batch.image_attention_mask[
:, batch_left_offset : -batch.padding_right_offset, :
]
# Create empty tensor
# position_ids is always of shape [batch_size, 1]
if position_ids is None:
position_ids = batch.position_ids.new_empty((total_batch_size, 1))
position_ids[start_index:end_index] = batch.position_ids
# Shenanigans to get dimensions because BLOOM outputs a past with a different shape
# BLOOM Keys: [batch_size * num_heads, head_dim, seq_length]
# BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
# And ensure that we can update tensors in-place
if isinstance(batch.past_key_values[0], tuple):
batch.past_key_values = [
[t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
for layer in batch.past_key_values
]
elif len(batch.past_key_values[0][0].shape) == 3:
for layer in batch.past_key_values:
for k, t in enumerate(layer):
layer[k] = t.view(len(batch), -1, *t.shape[-2:])
# Add eventual padding tokens that were added while concatenating
max_tokens += batch.max_tokens + (
max_input_length - batch.max_input_length
) * len(batch)
start_index = end_index
first_past_kvs = batches[0].past_key_values
_, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape
padded_past_values_shape = (
total_batch_size,
num_heads,
max_input_length - 1,
head_dim,
)
if batches[0].keys_head_dim_last:
padded_past_keys_shape = padded_past_values_shape
else:
# seq_length is last for BLOOM
padded_past_keys_shape = (
total_batch_size,
num_heads,
head_dim,
max_input_length - 1,
)
# Iterate over attention layers
# Concatenate past key values layer by layer to allow incremental garbage collection
for j in range(len(first_past_kvs)):
padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape)
start_index = 0
for batch in batches:
past_keys = batch.past_key_values[j][0]
# Clear reference to the original tensor
batch.past_key_values[j][0] = None
# Slicing end index for this batch
end_index = start_index + len(batch)
# We slice the keys to remove the padding from previous batches
past_seq_len = batch.max_input_length - 1
if batch.keys_head_dim_last:
padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
past_keys[:, :, -past_seq_len:, :]
)
else:
# BLOOM case
padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
past_keys[:, :, :, -past_seq_len:]
)
del past_keys
start_index = end_index
padded_past_values = first_past_kvs[j][1].new_zeros(
padded_past_values_shape
)
start_index = 0
for batch in batches:
past_values = batch.past_key_values[j][1]
# Clear reference to the original tensor
batch.past_key_values[j][1] = None
# Slicing end index for this batch
end_index = start_index + len(batch)
# We slice the past values to remove the padding from previous batches
past_seq_len = batch.max_input_length - 1
padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
past_values[:, :, -past_seq_len:, :]
)
del past_values
# Update values
start_index = end_index
past_key_values.append([padded_past_keys, padded_past_values])
return cls(
batch_id=batches[0].batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
pixel_values=pixel_values,
image_hidden_states=image_hidden_states,
image_attention_mask=image_attention_mask,
past_key_values=past_key_values,
all_input_ids=all_input_ids,
input_lengths=input_lengths,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
max_input_length=max_input_length,
padding_right_offset=padding_right_offset,
keys_head_dim_last=batches[0].keys_head_dim_last,
max_tokens=max_tokens,
)
def __len__(self):
return len(self.requests)
class IdeficsCausalLM(Model):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.quantize = quantize
self.process_group, rank, world_size = initialize_torch_distributed()
device = torch.device("hpu")
dtype = torch.bfloat16 if dtype is None else dtype
self.device, self.dtype = device, dtype
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
config.speculator = speculator
config.vision_config.quantize = quantize
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
self.processor = AutoProcessor.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
weights_loader = get_loader(
quantize=quantize, model_id=model_id, revision=revision
)
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames,
device=device,
dtype=dtype,
process_group=self.process_group,
weights_loader=weights_loader,
)
model = IdeficsForVisionText2Text(config, weights)
self.config = config
torch.distributed.barrier(group=self.process_group)
super().__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@property
def batch_type(self) -> Type[IdeficsCausalLMBatch]:
return IdeficsCausalLMBatch
def forward(
self,
input_ids,
attention_mask,
position_ids,
pixel_values,
image_hidden_states,
image_attention_mask,
past_key_values: Optional = None,
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
# Model Forward
kwargs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"pixel_values": pixel_values,
"image_hidden_states": image_hidden_states,
"image_attention_mask": image_attention_mask,
"past_key_values": past_key_values,
"use_cache": True,
"return_dict": True,
}
if self.has_position_ids:
kwargs["position_ids"] = position_ids
outputs, speculative_logits = self.model.forward(**kwargs)
return (
outputs.logits,
speculative_logits,
outputs.past_key_values,
outputs.image_hidden_states,
)
@tracer.start_as_current_span("generate_token")
def generate_token(
self, batch: IdeficsCausalLMBatch
) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch], Tuple[int, int]]:
start = time.time_ns()
# slice the attention mask to the correct shape
attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
if batch.image_attention_mask is None:
image_attention_mask = None
else:
if batch.input_ids.size(1) == 1:
# THIS is a hack: when calling idefics.generate, the first time, we need the whole image_attention_mask (size bs x max_seq_len x max_num_images),
# but the subsequent times, we only need the last attention mask along the `max_seq_len` dimension
# this is due to the nature IDEFICS: it's an encoder decoder, and so when decoding, only the currently generated
# token need to attend to the encoder hidden states (i.e. the vision encoder)
# Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic
image_attention_mask = batch.image_attention_mask[
:, -(batch.padding_right_offset + 1)
].unsqueeze(1)
else:
image_attention_mask = batch.image_attention_mask[
:, : -batch.padding_right_offset
]
logits, speculative_logits, past, image_hidden_states = self.forward(
input_ids=batch.input_ids,
attention_mask=attention_mask,
position_ids=batch.position_ids,
pixel_values=batch.pixel_values,
image_hidden_states=batch.image_hidden_states,
image_attention_mask=image_attention_mask,
past_key_values=batch.past_key_values,
)
# Hardcoded remove image tokens
logits[:, 32000:32001] = torch.finfo(logits.dtype).min
start_decode = time.time_ns()
# Results
generations: List[Generation] = []
stopped = True
# Zipped iterator
iterator = zip(
batch.requests,
batch.input_lengths,
batch.prefix_offsets,
batch.read_offsets,
logits,
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
)
# For each member of the batch
for i, (
request,
input_length,
prefix_offset,
read_offset,
logits,
next_token_chooser,
stopping_criteria,
all_input_ids,
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(
all_input_ids.view(1, -1), logits[-1:, :]
)
# Append next token to all tokens
all_input_ids = torch.cat([all_input_ids, next_token_id])
new_input_length = input_length + 1
# Generated token
next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze()
next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids[:, 0], prefix_offset, read_offset
)
# Evaluate stopping criteria
stop, reason = stopping_criteria(
next_token_id_squeezed,
next_token_text,
)
if not stop:
stopped = False
# Shard generations
# All generations will be appended in the rust sharded client
if i % self.world_size == self.rank:
if stop:
# Decode generated tokens
output_text, _, _ = self.decode_token(
all_input_ids[:, 0],
prefix_offset=len(all_input_ids)
- stopping_criteria.current_tokens
- 1,
read_offset=len(all_input_ids)
- stopping_criteria.current_tokens,
skip_special_tokens=True,
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, reason, seed
)
else:
generated_text = None
# Prefill
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
# Remove generated token to only have prefill and add nan for first prompt token
prefill_logprobs = [float("nan")] + torch.log_softmax(
logits, -1
).gather(1, all_input_ids[1:]).squeeze(1)[
-new_input_length:-1
].tolist()
prefill_token_ids = all_input_ids[-new_input_length:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = Tokens(
prefill_token_ids,
prefill_logprobs,
prefill_texts,
is_special=[],
)
else:
prefill_tokens = None
top_tokens = None
generation = Generation(
request.id,
prefill_tokens,
Tokens(
[next_token_id_squeezed],
[next_token_logprob],
[next_token_text],
[next_token_id_squeezed.item() in self.all_special_ids],
),
generated_text,
top_tokens,
)
generations.append(generation)
# Update values
batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
next_token_id_squeezed.item()
)
batch.input_ids[i, 0] = next_token_id
batch.all_input_ids[i] = all_input_ids
batch.input_lengths[i] = new_input_length
batch.prefix_offsets[i] = prefix_offset
batch.read_offsets[i] = read_offset
batch.max_input_length = max(batch.max_input_length, new_input_length)
# We finished all generations in the batch; there is no next batch
if stopped:
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, None, (forward_ns, decode_ns)
# Slice unused values from prefill
batch.input_ids = batch.input_ids[:, :1]
# Update attention_mask as we added a new token to input_ids
batch.attention_mask[:, -batch.padding_right_offset] = 1
batch.image_attention_mask[:, -batch.padding_right_offset, :] = (
batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :]
)
# Decrease right offset
batch.padding_right_offset -= 1
# Update position_ids
batch.position_ids = batch.position_ids[:, -1:] + 1
# Update past key values
batch.past_key_values = past
batch.image_hidden_states = image_hidden_states
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, batch, (forward_ns, decode_ns)

View File

@ -1,814 +0,0 @@
import torch
import torch.distributed
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from typing import Optional
from text_generation_server.models.custom_modeling.mamba_modeling import (
MambaConfig,
)
from loguru import logger
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.models.globals import CUDA_GRAPHS, MEM_POOL
import time
from text_generation_server.models.custom_modeling.mamba_modeling import (
MambaModel,
InferenceParams,
)
from text_generation_server.models import Model
from typing import Any, List, Tuple, Type, Dict
from text_generation_server.models.types import (
Batch,
Tokens,
Generation,
GeneratedText,
)
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.tokens import batch_top_tokens, Sampling
from dataclasses import dataclass
from text_generation_server.utils import NextTokenChooser, StoppingCriteria
def new_inference_params(
n_blocks: int,
batch_size: int,
d_inner: int,
d_conv: int,
d_state: int,
seqlen_offset: int,
dtype: torch.dtype,
device: torch.device,
):
max_seqlen = 0
conv_states = torch.zeros(
(
n_blocks,
batch_size,
d_inner,
d_conv,
),
device=device,
dtype=dtype,
)
ssm_states = torch.zeros(
(
n_blocks,
batch_size,
d_inner,
d_state,
),
device=device,
dtype=dtype,
)
inference_params = InferenceParams(
max_seqlen=max_seqlen,
max_batch_size=batch_size,
seqlen_offset=seqlen_offset,
conv_states=conv_states,
ssm_states=ssm_states,
)
return inference_params
@dataclass
class MambaBatch(Batch):
batch_id: int
requests: List[generate_pb2.Request]
requests_idx_mapping: Dict[int, int]
# Decoder values
input_ids: torch.Tensor
# All tokens
all_input_ids: List[torch.Tensor]
# Lengths of all generations present in the batch
input_lengths: List[int]
prefix_offsets: List[int]
read_offsets: List[int]
# Generation helpers
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
top_n_tokens: List[int]
top_n_tokens_tensor: torch.Tensor
# Metadata used for padding
max_input_length: int
padding_right_offset: int
# Maximum number of tokens this batch will grow to
max_tokens: int
# Past metadata
keys_head_dim_last: bool = True
# Inference params
inference_params: Optional[Dict[str, Any]] = None
def to_pb(self) -> generate_pb2.CachedBatch:
return generate_pb2.CachedBatch(
id=self.batch_id,
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=self.max_tokens,
)
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "MambaBatch":
inputs = []
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
prefix_offsets = []
read_offsets = []
requests_idx_mapping = {}
# Parse batch
max_truncation = 0
padding_right_offset = 0
max_decode_tokens = 0
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(concat_text_chunks(r.input_chunks.chunks))
next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
max_truncation = max(max_truncation, r.truncate)
max_decode_tokens += stopping_criteria.max_new_tokens
padding_right_offset = max(
padding_right_offset, stopping_criteria.max_new_tokens
)
tokenized_inputs = tokenizer(
inputs,
return_tensors="pt",
padding=True,
return_token_type_ids=False,
truncation=True,
max_length=max_truncation,
).to(device)
for _ in pb.requests:
input_len = tokenized_inputs["input_ids"].shape[1]
prefix_offsets.append(input_len - 5)
read_offsets.append(input_len)
input_lengths = tokenized_inputs["attention_mask"].sum(1)
max_input_length = input_lengths.max()
input_ids = tokenized_inputs["input_ids"]
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
return cls(
batch_id=pb.id,
requests=pb.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
# past_input_ids=None,
all_input_ids=list(all_input_ids),
input_lengths=input_lengths.tolist(),
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset,
max_tokens=max_tokens,
)
def filter(self, request_ids: List[int]) -> Optional["MambaBatch"]:
if len(request_ids) == 0:
raise ValueError("Batch must have at least one request")
if len(request_ids) == len(self):
return self
keep_indices = []
# New values after filtering
requests_idx_mapping = {}
requests = []
input_lengths = []
prefix_offsets = []
read_offsets = []
all_input_ids = []
max_input_length = 0
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
total_remaining_decode_tokens = 0
new_padding_right_offset = 0
indices = []
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
requests_idx_mapping[request_id] = i
keep_indices.append(idx)
requests.append(self.requests[idx])
prefix_offsets.append(self.prefix_offsets[idx])
read_offsets.append(self.read_offsets[idx])
all_input_ids.append(self.all_input_ids[idx])
request_input_length = self.input_lengths[idx]
input_lengths.append(request_input_length)
max_input_length = max(max_input_length, request_input_length)
indices.append(idx)
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(self.top_n_tokens[idx])
remaining_decode_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
total_remaining_decode_tokens += remaining_decode_tokens
new_padding_right_offset = max(
new_padding_right_offset, remaining_decode_tokens
)
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
input_ids = self.input_ids[keep_indices]
top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
self.requests = requests
self.requests_idx_mapping = requests_idx_mapping
self.input_ids = input_ids
self.all_input_ids = all_input_ids
self.input_lengths = input_lengths
self.prefix_offsets = prefix_offsets
self.read_offsets = read_offsets
self.next_token_choosers = next_token_choosers
self.stopping_criterias = stopping_criterias
self.top_n_tokens = top_n_tokens
self.top_n_tokens_tensor = top_n_tokens_tensor
self.max_input_length = max_input_length
self.padding_right_offset = new_padding_right_offset
self.max_tokens = max_tokens
# TODO
# Kept it simple by just updating the state, maybe updating the other CPU values is necessary.
self.inference_params.conv_states = self.inference_params.conv_states[
:, indices
]
self.inference_params.ssm_states = self.inference_params.ssm_states[:, indices]
return self
@classmethod
def concatenate(cls, batches: List["MambaBatch"]) -> "MambaBatch":
# Used for padding
total_batch_size = 0
max_input_length = 0
padding_right_offset = 0
for batch in batches:
total_batch_size += len(batch)
max_input_length = max(max_input_length, batch.max_input_length)
padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
# Batch attributes
requests = []
requests_idx_mapping = {}
input_lengths = []
prefix_offsets = []
read_offsets = []
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
max_tokens = 0
seqlen_offset = 0
(n_blocks, _, d_inner, d_conv) = batches[0].inference_params.conv_states.shape
(_, _, _, d_state) = batches[0].inference_params.ssm_states.shape
dtype = batches[0].inference_params.conv_states.dtype
device = batches[0].inference_params.conv_states.device
inference_params = new_inference_params(
n_blocks=n_blocks,
batch_size=total_batch_size,
d_state=d_state,
d_conv=d_conv,
d_inner=d_inner,
seqlen_offset=seqlen_offset,
device=device,
dtype=dtype,
)
# Batch tensors
input_ids = None
top_n_tokens_tensor = None
# Used for slicing correctly inside the tensors
# Equivalent to a cumsum on batch sizes
start_index = 0
for i, batch in enumerate(batches):
requests.extend(batch.requests)
input_lengths.extend(batch.input_lengths)
prefix_offsets.extend(batch.prefix_offsets)
read_offsets.extend(batch.read_offsets)
all_input_ids.extend(batch.all_input_ids)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
top_n_tokens.extend(batch.top_n_tokens)
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
else:
# We need to offset the mapping for each batch by the cumulative batch size
for k, v in batch.requests_idx_mapping.items():
requests_idx_mapping[k] = v + start_index
# Slicing end index for this batch
end_index = start_index + len(batch)
# Create empty tensor
# input_ids is always of shape [batch_size, 1]
# We do not need to pad it
if input_ids is None:
input_ids = batch.input_ids.new_empty((total_batch_size, 1))
# Copy to correct indices
input_ids[start_index:end_index] = batch.input_ids
if top_n_tokens_tensor is None:
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size,
)
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
# Add eventual padding tokens that were added while concatenating
max_tokens += batch.max_tokens + (
max_input_length - batch.max_input_length
) * len(batch)
inference_params.max_seqlen = max(
inference_params.max_seqlen, batch.inference_params.max_seqlen
)
assert batch.inference_params.seqlen_offset != 0, "Invalid seqlen offset"
inference_params.seqlen_offset = max(
inference_params.seqlen_offset, batch.inference_params.seqlen_offset
)
inference_params.conv_states[:, start_index:end_index] = (
batch.inference_params.conv_states
)
inference_params.ssm_states[:, start_index:end_index] = (
batch.inference_params.ssm_states
)
start_index = end_index
return cls(
batch_id=batches[0].batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
all_input_ids=all_input_ids,
input_lengths=input_lengths,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length,
padding_right_offset=padding_right_offset,
keys_head_dim_last=batches[0].keys_head_dim_last,
max_tokens=max_tokens,
inference_params=inference_params,
)
def __len__(self):
return len(self.requests)
class Mamba(Model):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.quantize = quantize
self.process_group, _rank, world_size = initialize_torch_distributed()
if world_size > 1:
raise RuntimeError("Mamba does not support Tensor Parallelism (TP)")
self.cuda_graphs = {}
if torch.cuda.is_available():
device = torch.device("cuda")
# Bf16 is important. In f16 accumulations in the matmul are causing
# differences while the server is under load.
# This is detectable by the integration load test
dtype = torch.bfloat16 if dtype is None else dtype
else:
if quantize:
raise ValueError("quantization is not available on CPU")
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = AutoTokenizer.from_pretrained(
"EleutherAI/gpt-neox-20b",
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = MambaConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
tokenizer.bos_token_id = config.bos_token_id
tokenizer.eos_token_id = config.eos_token_id
tokenizer.pad_token = tokenizer.eos_token
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
weights_loader = get_loader(
quantize=quantize, model_id=model_id, revision=revision
)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames,
device,
dtype,
process_group=self.process_group,
weights_loader=weights_loader,
)
model = MambaModel(config, weights)
torch.distributed.barrier(group=self.process_group)
super(Mamba, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
)
@property
def batch_type(self) -> Type[MambaBatch]:
return MambaBatch
def warmup(self, batch) -> Optional[int]:
# TODO: implement warmup for Mamba if needed
if CUDA_GRAPHS:
if self.speculate is None or self.speculate == 0:
try:
logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
# Warmup cuda graphs
for bs in CUDA_GRAPHS:
self.cuda_graph_warmup(bs)
except Exception:
logger.exception("Decode cuda graph warmup failed")
else:
logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
return None
def cuda_graph_warmup(self, batch_size: int):
input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device)
n_blocks = len(self.model.blocks)
d_state = self.model.config.d_state
d_conv = self.model.config.d_conv
# Inner takes the expand multiplication
d_inner = self.model.config.d_inner
# Important seqlen_offset to go through the update mecanism with the state
seqlen_offset = 1
inference_params = new_inference_params(
n_blocks=n_blocks,
batch_size=batch_size,
d_state=d_state,
d_conv=d_conv,
d_inner=d_inner,
seqlen_offset=seqlen_offset,
device=self.device,
dtype=self.dtype,
)
graph = torch.cuda.CUDAGraph()
torch.cuda.synchronize()
# Run once outside to warmup
self.model.forward(input_ids=input_ids, inference_params=inference_params)
torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL):
logits, speculative_logits = self.model.forward(
input_ids=input_ids, inference_params=inference_params
)
torch.cuda.synchronize()
graph_dict = {
"input_ids": input_ids,
"inference_params": inference_params,
"graph": graph,
"logits": logits,
"speculative_logits": speculative_logits,
}
self.cuda_graphs[batch_size] = graph_dict
def tunableop_warmup(self, batch_size: int, seqlen: int):
input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device)
n_blocks = len(self.model.blocks)
d_state = self.model.config.d_state
d_conv = self.model.config.d_conv
# Inner takes the expand multiplication
d_inner = self.model.config.d_inner
# Important seqlen_offset to go through the update mecanism with the state
seqlen_offset = 1
inference_params = new_inference_params(
n_blocks=n_blocks,
batch_size=seqlen,
d_state=d_state,
d_conv=d_conv,
d_inner=d_inner,
seqlen_offset=seqlen_offset,
device=self.device,
dtype=self.dtype,
)
self.model.forward(input_ids=input_ids, inference_params=inference_params)
def forward(
self, input_ids: torch.Tensor, inference_params: Any
) -> Tuple[torch.Tensor, torch.Tensor]:
bs = input_ids.shape[0]
padded_bs = bs
if bs == 3:
padded_bs = 4
elif 3 < bs <= 8:
padded_bs = 8
elif bs > 8:
padded_bs = (bs + 7) // 8 * 8
# Try to find an associated cuda graph
cuda_graph = self.cuda_graphs.get(padded_bs, None)
is_prefill = inference_params is None or inference_params.seqlen_offset == 0
if is_prefill or cuda_graph is None:
return self.model(
input_ids,
inference_params=inference_params,
)
# Copy inputs to the static inputs of the cuda graph
# Static inputs are potentially padded
cuda_graph["input_ids"][:bs] = input_ids
cuda_graph["inference_params"].conv_states[
:, :bs
] = inference_params.conv_states
cuda_graph["inference_params"].ssm_states[:, :bs] = inference_params.ssm_states
# Replay the graph
cuda_graph["graph"].replay()
inference_params.conv_states.copy_(
cuda_graph["inference_params"].conv_states[:, :bs]
)
inference_params.ssm_states.copy_(
cuda_graph["inference_params"].ssm_states[:, :bs]
)
# Slice output to the correct shape
speculative_logits = (
cuda_graph["speculative_logits"][:bs]
if cuda_graph["speculative_logits"] is not None
else None
)
logits = cuda_graph["logits"][:bs]
return logits, speculative_logits
def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
start = time.time_ns()
input_ids = (
batch.input_ids
) # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids
batch_size, max_seqlen = input_ids.shape
# Inference params
if batch.inference_params is None:
# 0 is important here
seqlen_offset = 0
n_blocks = len(self.model.blocks)
d_state = self.model.config.d_state
d_conv = self.model.config.d_conv
d_inner = self.model.config.d_inner
inference_params = new_inference_params(
n_blocks=n_blocks,
batch_size=batch_size,
d_state=d_state,
d_conv=d_conv,
d_inner=d_inner,
seqlen_offset=seqlen_offset,
device=self.device,
dtype=self.dtype,
)
batch.inference_params = inference_params
# Forward pass
logits, speculative_logits = self.forward(
input_ids, inference_params=batch.inference_params
)
# batch.inference_params = new_inference_params
# Results
generations: List[Generation] = []
stopped = True
# Speculation is not active for causal
accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens,
batch.top_n_tokens_tensor,
torch.log_softmax(logits[:, -1], -1),
accepted_ids,
)
start_decode = time.time_ns()
# Zipped iterator
iterator = zip(
batch.requests,
batch.input_lengths,
batch.prefix_offsets,
batch.read_offsets,
logits,
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
batch.top_n_tokens,
batch_top_token_ids,
batch_top_token_logprobs,
)
# For each member of the batch
for i, (
request,
input_length,
prefix_offset,
read_offset,
logits,
next_token_chooser,
stopping_criteria,
all_input_ids,
top_n_tokens,
top_token_ids,
top_token_logprobs,
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(
all_input_ids.view(1, -1), logits[-1:, :]
)
# Append next token to all tokens
all_input_ids = torch.cat([all_input_ids, next_token_id])
new_input_length = input_length + 1
# Generated token
next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze()
next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids[:, 0], prefix_offset, read_offset
)
# Evaluate stopping criteria
stop, reason = stopping_criteria(
next_token_id_squeezed,
next_token_text,
)
if not stop:
stopped = False
# Shard generations
# All generations will be appended in the rust sharded client
if i % self.world_size == self.rank:
if stop:
# Decode generated tokens
output_text, _, _ = self.decode_token(
all_input_ids[:, 0],
prefix_offset=len(all_input_ids)
- stopping_criteria.current_tokens
- 1,
read_offset=len(all_input_ids)
- stopping_criteria.current_tokens,
skip_special_tokens=True,
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, reason, seed
)
else:
generated_text = None
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
# Remove generated token to only have prefill and add nan for first prompt token
prefill_logprobs = [float("nan")] + torch.log_softmax(
logits, -1
).gather(1, all_input_ids[1:]).squeeze(1)[
-new_input_length:-1
].tolist()
prefill_token_ids = all_input_ids[-new_input_length:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = Tokens(
prefill_token_ids,
prefill_logprobs,
prefill_texts,
is_special=[],
)
else:
prefill_tokens = None
if top_n_tokens > 0:
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = Tokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
special_toptokens,
)
else:
top_tokens = None
generation = Generation(
request.id,
prefill_tokens,
Tokens(
[next_token_id_squeezed],
[next_token_logprob],
[next_token_text],
[next_token_id_squeezed.item() in self.all_special_ids],
),
generated_text,
top_tokens,
)
generations.append(generation)
# Update values
batch.next_token_choosers[i] = batch.next_token_choosers[
i
].advance_grammar(next_token_id_squeezed.item())
batch.input_ids[i, 0] = next_token_id
batch.all_input_ids[i] = all_input_ids
batch.input_lengths[i] = new_input_length
batch.prefix_offsets[i] = prefix_offset
batch.read_offsets[i] = read_offset
batch.max_input_length = max(batch.max_input_length, new_input_length)
# We finished all generations in the batch; there is no next batch
if stopped:
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, None, (forward_ns, decode_ns)
# Slice unused values from prefill
batch.input_ids = batch.input_ids[:, :1]
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, batch, (forward_ns, decode_ns)

View File

@ -32,6 +32,9 @@ from text_generation_server.utils.import_utils import (
)
import torch.nn.functional as F
from text_generation_server.utils.log import log_master
import time
import os
from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
tracer = trace.get_tracer(__name__)
@ -43,10 +46,17 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
aspect_ratio_mask: Optional[torch.Tensor] = None
cross_attention_states: Optional[torch.Tensor] = None
def prepare_for_prefill(
self, max_padded_input_len, max_padded_bs, max_total_tokens
):
super(FlashVlmCausalLMBatch, self).prepare_for_prefill(
max_padded_input_len, max_padded_bs, max_total_tokens
)
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches):
batch = super().concatenate(batches)
def concatenate(cls, batches, padded_total_bs: int = 0):
batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches, padded_total_bs)
batch.pixel_values = None
batch.pixel_attention_mask = None
@ -70,7 +80,7 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]):
assert self.image_indices is not None
batch = super().filter(request_ids)
batch = super(FlashVlmCausalLMBatch, self).filter(request_ids)
assert self.image_indices is not None
indices = []
for i, request_id in enumerate(request_ids):
@ -96,6 +106,7 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
]
else:
batch.cross_attention_states = None
batch.pixel_values = None
return batch
@classmethod
@ -187,7 +198,7 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
input_ids = np.concatenate(batch.input_ids, dtype=np.int64)
else:
input_ids = batch.input_ids[0]
batch.input_ids = torch.tensor(input_ids, dtype=torch.int64)
batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)
@ -225,6 +236,10 @@ def generate_cross_attention_states(
class FlashMllamaCausalLM(FlashVlmCausalLM):
def set_inputs_embeds(self, batch):
# Set the input embeddings to None, as we are using the input_ids for the model
batch.inputs_embeds = None
def warmup_decode(
self, batch_size: int, block_num: int, batch: FlashMllamaCausalLMBatch
):
@ -267,6 +282,11 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
cross_attention_states, image_indices, input_lengths, 1, False
)
slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
False, hpu_attention_meta.block_list.shape[0], batch_size
)
self.model.forward(
input_ids=_async_h2d_tensor_copy(input_ids),
position_ids=_async_h2d_tensor_copy(position_ids),
@ -280,6 +300,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
cross_attention_states=cross_attention_states,
indices=_async_h2d_tensor_copy(indices),
cross_attention_len=_async_h2d_tensor_copy(cross_attention_len),
**kwargs,
)
def warmup_prefill(
@ -325,7 +346,9 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
)
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = self.limit_hpu_graphs
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
True, prompt_len, batch_size
)
self.model.forward(
input_ids=_async_h2d_tensor_copy(input_ids),
position_ids=_async_h2d_tensor_copy(position_ids),
@ -343,26 +366,103 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
)
def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch):
prompt_graph_mem_ratio = float(os.environ.get("VLLM_GRAPH_PROMPT_RATIO", "0.3"))
free_mem = HabanaMemoryProfiler.current_free_device_memory()
graph_free_mem = free_mem - self.mem_reserved
graph_free_mem = self.align_workers(
graph_free_mem, torch.distributed.ReduceOp.MIN
)
prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem
decode_available_memory = graph_free_mem - prompt_available_memory
msg = (
f"Using {format_bytes(graph_free_mem)}"
f"/{format_bytes(free_mem)} "
"of free device memory for HPUGraphs, "
f"{format_bytes(prompt_available_memory)} for prompt and "
f"{format_bytes(decode_available_memory)} for decode "
f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})"
)
log_master(logger.info, msg)
start_time = time.time()
warmup_shape_count = 0
warmup_times = 3
self.bucketing_ctx.generate_prompt_buckets()
for i, (batch_size, seq_len) in enumerate(
reversed(self.bucketing_ctx.prompt_buckets)
def ordering_function_min_tokens(b):
return (b[0] * b[1], b[1], b[0])
buckets = list(
sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
)
graph_free_mem
total_batch_seq = 0.001
total_mem = 0
available_mem = prompt_available_memory
for i, (batch_size, seq_len) in enumerate(buckets):
if batch_size * seq_len > self.max_batch_prefill_tokens:
continue
# Graph memory usage is proportional to seq dimension in a batch
batch_seq = batch_size * seq_len
mem_estimate = batch_seq / total_batch_seq * total_mem
graphed_bucket = (batch_size, seq_len, True)
if not (
mem_estimate >= available_mem or batch_seq > self.max_seq_len_to_capture
):
log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
if graphed_bucket not in self.graphed_buckets:
self.graphed_buckets.add(graphed_bucket)
warmup_shape_count += 1
self.log_warmup(True, i, len(buckets), batch_size, seq_len)
with HabanaMemoryProfiler() as mem_prof:
for index in range(warmup_times):
self.warmup_prefill(seq_len, batch_size, batch)
synchronize(self.device)
used_mem = self.align_workers(
mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
)
if graphed_bucket in self.graphed_buckets:
available_mem -= used_mem
total_mem += used_mem
total_batch_seq += batch_seq
def ordering_function_max_bs(b):
return (-b[0], b[1])
self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
for i, (batch_size, block_num) in enumerate(
reversed(self.bucketing_ctx.decode_buckets)
):
buckets = list(
sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
)
free_mem = HabanaMemoryProfiler.current_free_device_memory()
total_batch_seq = 0.001
total_mem = 0
available_mem = free_mem - self.mem_reserved
for i, (batch_size, block_num) in enumerate(buckets):
if batch_size > block_num:
continue
log_master(
logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
)
# Graph memory usage is proportional to seq dimension in a batch
batch_seq = batch_size
mem_estimate = batch_seq / total_batch_seq * total_mem
graphed_bucket = (batch_size, block_num, False)
if not mem_estimate >= available_mem:
if graphed_bucket not in self.graphed_buckets:
self.graphed_buckets.add(graphed_bucket)
warmup_shape_count += 1
self.log_warmup(False, i, len(buckets), batch_size, block_num)
with HabanaMemoryProfiler() as mem_prof:
for index in range(warmup_times):
self.warmup_decode(batch_size, block_num, batch)
synchronize(self.device)
used_mem = self.align_workers(
mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
)
if graphed_bucket in self.graphed_buckets:
available_mem -= used_mem
total_mem += used_mem
total_batch_seq += batch_seq
log_master(
logger.info,
f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
)
def forward(
self,
@ -438,15 +538,22 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = (
batch.prefilling if self.limit_hpu_graphs else False
batch_size = input_lengths.shape[0]
seqlen = (
input_ids.shape[0] // batch_size
if batch.prefilling
else batch.hpu_attn_meta.block_list.shape[0]
)
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
batch.prefilling, seqlen, batch_size
)
if batch.prefill_cache_indices is not None:
slots_pad = torch.zeros_like(input_ids)
slots_pad = torch.zeros_like(input_ids, device=slots.device)
slots_pad[batch.prefill_cache_indices] = slots
slots = slots_pad
else:
slots_pad = torch.zeros_like(input_ids)
slots_pad = torch.zeros_like(input_ids, device=slots.device)
slots_pad[: slots.shape[0]] = slots
slots = slots_pad
orig_bs = len(batch)
@ -475,7 +582,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
input_lengths=_async_h2d_tensor_copy(input_lengths),
)
logits, speculative_logits = self.model.forward(
input_ids=_async_h2d_tensor_copy(input_ids),
input_ids=input_ids,
position_ids=_async_h2d_tensor_copy(position_ids),
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
kv_cache=kv_cache,

View File

@ -1,71 +0,0 @@
from io import BytesIO
from PIL import Image
import torch
import torch.distributed
from opentelemetry import trace
from typing import Iterable
from text_generation_server.models.flash_vlm_causal_lm import (
FlashVlmCausalLMBatch,
image_text_replacement,
)
from text_generation_server.pb.generate_pb2 import Request
tracer = trace.get_tracer(__name__)
class PaliGemmaBatch(FlashVlmCausalLMBatch):
@classmethod
def batch_tokenized_inputs(
cls, requests: Iterable[Request], tokenizer, processor, config
):
batch_inputs = []
image_inputs = []
max_truncation = 0
for r in requests:
full_text = ""
image_id = 0
for chunk in r.input_chunks.chunks:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
full_text += "<bos>" + chunk.text + "\n"
elif chunk_type == "image":
image = Image.open(BytesIO(chunk.image.data))
# TODO do_convert_RGB should be on by default ?
image = image.convert("RGB")
image_input = processor.image_processor(image, return_tensors="pt")
full_text += image_text_replacement(
processor, image_input, config, image_id
)
image_inputs.append(image_input)
else:
raise RuntimeError(f"Invalid chunk type {chunk_type}")
batch_inputs.append(full_text)
max_truncation = max(max_truncation, r.truncate)
batch_tokenized_inputs = tokenizer(
batch_inputs,
truncation=True,
max_length=max_truncation,
add_special_tokens=False,
)["input_ids"]
if image_inputs:
image_input = image_inputs[0]
new_image_inputs = {
"pixel_values": torch.cat(
[img["pixel_values"] for img in image_inputs], dim=0
),
}
if "pixel_attention_mask" in image_input:
new_image_inputs["pixel_attention_mask"] = torch.cat(
[img["pixel_attention_mask"] for img in image_inputs], dim=0
)
if "image_sizes" in image_input:
new_image_inputs["image_sizes"] = torch.cat(
[img["image_sizes"] for img in image_inputs], dim=0
)
image_inputs = new_image_inputs
else:
image_inputs = None
return batch_tokenized_inputs, image_inputs

View File

@ -1,47 +0,0 @@
import torch
from dataclasses import dataclass
from typing import List, Optional, Type
from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
@dataclass
class StarCoderCausalLMBatch(CausalLMBatch):
past_key_values: Optional[List[torch.Tensor]]
def detach_kv_cache(self):
past_keys = []
past_values = []
last_dim = int(self.past_key_values[0].size(dim=-1) / 2)
for key_value in self.past_key_values:
past_keys.append(key_value.split((last_dim, last_dim), dim=-1)[0])
past_values.append(key_value.split((last_dim, last_dim), dim=-1)[1])
del self.past_key_values
return past_keys, past_values
def attach_kv_cache(self, past_keys, past_values):
self.past_key_values = [
torch.cat((key, value), dim=-1)
for key, value in zip(past_keys, past_values)
]
class StarCoder(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
):
super(StarCoder, self).__init__(
model_id=model_id,
revision=revision,
dtype=dtype,
)
@property
def batch_type(self) -> Type[CausalLMBatch]:
return StarCoderCausalLMBatch

View File

@ -23,26 +23,8 @@ from text_generation_server.models.globals import set_adapter_to_index
from text_generation_server.utils.adapter import AdapterInfo
from text_generation_server.utils.tokens import make_tokenizer_optional
from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens
from text_generation_server.models import VLM_BATCH_TYPES
try:
from text_generation_server.models.pali_gemma import PaliGemmaBatch
from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
from text_generation_server.models.vlm_causal_lm import (
VlmCausalLMBatch,
)
from text_generation_server.models.flash_vlm_causal_lm import (
FlashVlmCausalLMBatch,
)
VLM_BATCH_TYPES = {
PaliGemmaBatch,
VlmCausalLMBatch,
FlashVlmCausalLMBatch,
FlashMllamaCausalLMBatch,
}
except (ImportError, NotImplementedError):
# These imports can fail on CPU/Non flash.
VLM_BATCH_TYPES = set()
from text_generation_server.utils.version import (
is_driver_compatible,
MIN_TGI_GAUDI_SYNAPSE_VERSION,
@ -224,6 +206,7 @@ def serve(
quantize: Optional[str],
speculate: Optional[int],
dtype: Optional[str],
kv_cache_dtype: Optional[str],
trust_remote_code: bool,
uds_path: Path,
max_input_tokens: int,
@ -236,6 +219,7 @@ def serve(
quantize: Optional[str] = None,
speculate: Optional[int] = None,
dtype: Optional[str] = None,
kv_cache_dtype: Optional[str] = None,
trust_remote_code: bool = False,
):
if not is_driver_compatible():
@ -279,6 +263,7 @@ def serve(
quantize,
speculate,
data_type,
kv_cache_dtype,
trust_remote_code,
max_input_tokens,
adapter_to_index,
@ -326,6 +311,7 @@ def serve(
quantize,
speculate,
dtype,
kv_cache_dtype,
trust_remote_code,
)
)

View File

@ -31,6 +31,7 @@ def main(args):
trust_remote_code=args.trust_remote_code,
uds_path=args.uds_path,
max_input_tokens=args.max_input_tokens,
kv_cache_dtype="auto",
)

View File

@ -4,8 +4,8 @@ import os
import glob
import time
from optimum.habana.utils import to_gb_rounded
import habana_frameworks.torch as htorch
import numpy as np
START_TS = None
DBG_TRACE_FILENAME = os.environ.get("DBG_TRACE_FILENAME")
@ -14,6 +14,19 @@ if "GRAPH_VISUALIZATION" in os.environ:
os.remove(f)
def to_gb_rounded(mem: float) -> float:
"""
Rounds and converts to GB.
Args:
mem (float): memory in bytes
Returns:
float: memory in GB rounded to the second decimal
"""
return np.round(mem / 1024**3, 2)
def count_hpu_graphs():
return len(glob.glob(".graph_dumps/*PreGraph*"))

View File

@ -1,18 +1,9 @@
import torch
from loguru import logger
def get_hpu_free_memory(device, memory_fraction):
from habana_frameworks.torch.hpu import memory_stats
device_id = device.index
mem_stats = memory_stats(device_id)
logger.info(f"mem_stats: {mem_stats}")
total_free_memory = mem_stats["Limit"] - mem_stats["MaxInUse"]
free_memory = max(
0, int(total_free_memory - (1 - memory_fraction) * mem_stats["Limit"])
)
return free_memory
free_hpu_memory, _ = torch.hpu.mem_get_info()
return free_hpu_memory
def synchronize_hpu(device):

View File

@ -1,7 +1,7 @@
import json
import os
from dataclasses import dataclass
from typing import Optional
from typing import Optional, List
from huggingface_hub import hf_hub_download
from text_generation_server.utils.weights import (
@ -18,6 +18,8 @@ class _QuantizerConfig:
groupsize: int
quant_method: str
sym: bool
weight_block_size: Optional[List[int]]
modules_to_not_convert: List[str]
@dataclass
@ -25,7 +27,20 @@ class _FP8QuantizerConfig:
activation_scale_ub: float
# We should probably do this with Pytantic JSON deserialization,
def _get_config_json(model_id: str, revision: Optional[str], filename: str):
if os.path.exists(
os.path.join(
model_id,
)
):
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(model_id, filename=filename, revision=revision)
with open(filename, "r") as f:
return json.load(f)
# We should probably do this with Pydantic JSON deserialization,
# but for now we'll stay close to the old _set_gptq_params.
def _get_quantizer_config(model_id, revision):
bits = 4
@ -34,21 +49,18 @@ def _get_quantizer_config(model_id, revision):
checkpoint_format = None
sym = False
desc_act = False
weight_block_size = None
modules_to_not_convert = []
filename = "config.json"
try:
if os.path.exists(os.path.join(model_id, filename)):
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(model_id, filename=filename, revision=revision)
with open(filename, "r") as f:
data = json.load(f)
data = _get_config_json(model_id, revision, filename)
# FP8 config
if data["quantization_config"]["quant_method"] == "fbgemm_fp8":
return _FP8QuantizerConfig(
activation_scale_ub=data["quantization_config"]["activation_scale_ub"]
)
weight_block_size = data["quantization_config"].get("weight_block_size", None)
if "zero_point" in data["quantization_config"]:
sym = not data["quantization_config"]["zero_point"]
@ -61,18 +73,16 @@ def _get_quantizer_config(model_id, revision):
# Order is important here, desc_act is missing on some real models
quant_method = data["quantization_config"]["quant_method"]
checkpoint_format = data["quantization_config"].get("checkpoint_format")
desc_act = data["quantization_config"]["desc_act"]
desc_act = data["quantization_config"].get("desc_act", False)
modules_to_not_convert = data["quantization_config"].get(
"modules_to_not_convert", []
)
if modules_to_not_convert is None:
modules_to_not_convert = []
except Exception:
filename = "quantize_config.json"
try:
if os.path.exists(os.path.join(model_id, filename)):
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(
model_id, filename=filename, revision=revision
)
with open(filename, "r") as f:
data = json.load(f)
data = _get_config_json(model_id, revision, filename)
bits = data["bits"]
groupsize = data["group_size"]
@ -88,14 +98,7 @@ def _get_quantizer_config(model_id, revision):
except Exception:
filename = "quant_config.json"
try:
if os.path.exists(os.path.join(model_id, filename)):
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(
model_id, filename=filename, revision=revision
)
with open(filename, "r") as f:
data = json.load(f)
data = _get_config_json(model_id, revision, filename)
bits = data["w_bit"]
groupsize = data["q_group_size"]
desc_act = data["desc_act"]
@ -111,12 +114,21 @@ def _get_quantizer_config(model_id, revision):
checkpoint_format=checkpoint_format,
sym=sym,
desc_act=desc_act,
weight_block_size=weight_block_size,
modules_to_not_convert=modules_to_not_convert,
)
def get_loader(
quantize: Optional[str], model_id: str, revision: Optional[str]
) -> WeightsLoader:
if quantize == "compressed-tensors":
config = _get_config_json(model_id, revision, "config.json")
from text_generation_server.layers.compressed_tensors import (
CompressedTensorsLoader,
)
return CompressedTensorsLoader(config)
quantizer_config = _get_quantizer_config(model_id, revision)
if quantize in {"awq", "gptq"}:
from text_generation_server.layers.gptq import GPTQWeightsLoader
@ -134,6 +146,7 @@ def get_loader(
quant_method=quantizer_config.quant_method,
quantize=quantize,
sym=quantizer_config.sym,
modules_to_not_convert=quantizer_config.modules_to_not_convert,
)
elif quantize == "fp8" or quantize is None:
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
@ -141,9 +154,14 @@ def get_loader(
# Since the default for the quantize config is _QuantizerConfig,
# we need to add this check to not get an attribute error
activation_scale_ub = None
weight_block_size = quantizer_config.weight_block_size
if isinstance(quantizer_config, _FP8QuantizerConfig):
activation_scale_ub = quantizer_config.activation_scale_ub
return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8")
return HybridFP8UnquantLoader(
activation_scale_ub,
to_fp8=quantize == "fp8",
weight_block_size=weight_block_size,
)
else:
raise ValueError(f"Unknown quantization method: {quantize}")

View File

@ -1,5 +1,30 @@
from optimum.habana.utils import get_driver_version
from packaging.version import Version
from packaging import version
import subprocess
def get_driver_version():
"""
Returns the driver version.
"""
# Enable console printing for `hl-smi` check
output = subprocess.run(
"hl-smi",
shell=True,
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env={"ENABLE_CONSOLE": "true"},
)
if output.returncode == 0 and output.stdout:
return version.parse(
output.stdout.split("\n")[2]
.replace(" ", "")
.split(":")[1][:-1]
.split("-")[0]
)
return None
MIN_TGI_GAUDI_SYNAPSE_VERSION = Version("1.19.0")

View File

@ -62,6 +62,14 @@ class WeightsLoader(ABC):
"""
...
@abstractmethod
def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
"""
Get the weights at the given prefixes, column-split them for tensor
parallelim, and then concatenate the weights along the given dimension.
"""
...
@abstractmethod
def get_weights_row(self, weights: "Weights", prefix: str):
"""
@ -130,6 +138,10 @@ class DefaultWeightsLoader(WeightsLoader):
weights.get_sharded(f"{prefix}.weight", dim=1),
)
def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
w = [weights.get_tensor(f"{p}.weight") for p in prefixes]
return self.weight_class(torch.cat(w, dim=dim))
class Weights:
def __init__(
@ -303,7 +315,7 @@ class Weights:
world_size = self.process_group.size()
rank = self.process_group.rank()
tensors = []
tensors_slices = []
block_offset = 0
for block_size in block_sizes:
assert (
@ -312,15 +324,18 @@ class Weights:
shard_block_size = block_size // world_size
start = rank * shard_block_size
stop = (rank + 1) * shard_block_size
if dim == 0:
tensor = slice_[block_offset + start : block_offset + stop]
elif dim == 1:
tensor = slice_[:, block_offset + start : block_offset + stop]
else:
raise NotImplementedError("Currently only dim=0 or dim=1 is supported")
tensors.append(tensor)
tensors_slices += range(block_offset + start, block_offset + stop)
block_offset += block_size
tensor = torch.cat(tensors, dim=dim)
if dim == 0:
tensor = slice_[tensors_slices, ...]
elif dim == 1 or dim == -2:
tensor = slice_[:, tensors_slices, ...]
elif dim == 2 or dim == -1:
tensor = slice_[..., tensors_slices]
else:
raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported")
tensor = tensor.to(device=self.device)
# Avoid casting quantizer dtypes.
@ -390,6 +405,9 @@ class Weights:
def get_weights_row(self, prefix: str):
return self.weights_loader.get_weights_row(self, prefix)
def get_multi_weights(self, prefixes: List[str], dim: int):
return self.weights_loader.get_multi_weights(self, prefixes, dim)
@contextmanager
def use_loader(self, weights_loader: WeightsLoader):
"""

View File

@ -7,7 +7,8 @@ from typing import List, Optional, Tuple
import torch
from loguru import logger
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from optimum.neuron.configuration_utils import NeuronConfig
from transformers.generation import GenerationConfig
from optimum.neuron import NeuronModelForCausalLM
@ -175,6 +176,12 @@ class Slot:
self._generation_config.top_p = request.parameters.top_p
if request.parameters.typical_p != 0:
self._generation_config.typical_p = request.parameters.typical_p
else:
# Set the sampling parameters to emulate greedy decoding when using on-device sampling
self._generation_config.temperature = 1.0
self._generation_config.top_k = 1
self._generation_config.top_p = 1.0
self._generation_config.typical_p = 1.0
if request.parameters.repetition_penalty != 0:
self._generation_config.repetition_penalty = (
request.parameters.repetition_penalty
@ -211,19 +218,11 @@ class Slot:
self._mask = attention_mask.clone()
self._selector = selector
def pause(self, reset_on_pause: bool):
def pause(self):
"""Mark the current slot as paused for generation.
Note that the KV cache for this slot will still be filled.
"""
if reset_on_pause:
# Drop the last token as it will be added back when resuming the slot
self._generated_tokens -= 1
# Since generated tokens are now part of the prefill, we need to reevaluate
# max_new_tokens for the next generation
self._generation_config.max_new_tokens = (
self._max_new_tokens - self._generated_tokens
)
self._state = Slot.State.PAUSE
def resume(self):
@ -340,16 +339,27 @@ class NeuronGenerator(Generator):
tokenizer: PreTrainedTokenizerBase,
):
self.model = model
self.rebuild_cache_on_prefill = not self.model.continuous_batching
if not isinstance(self.model, NeuronModelForCausalLM):
raise ValueError("The model must be a NeuronModelForCausalLM.")
if not model.neuron_config.continuous_batching:
raise ValueError(
"The neuron model must be compiled with continuous_batching=True."
)
# Specify padding and truncation options for decoder-only architecture
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"
tokenizer.truncation_side = "left"
self.tokenizer = tokenizer
self.special_tokens = self.tokenizer.all_special_ids
self.slots = [Slot(i, tokenizer) for i in range(self.model.batch_size)]
self.slots = [
Slot(i, tokenizer) for i in range(self.model.neuron_config.batch_size)
]
self.batch_id = 0
@property
def on_device_sampling(self) -> bool:
return getattr(self.model.neuron_config, "on_device_sampling", False)
@property
def info(self) -> InfoResponse:
"""Returns the expected InfoResponse."""
@ -371,14 +381,22 @@ class NeuronGenerator(Generator):
The maximum number of tokens the model supports.
"""
# Just check that the warmup request parameters match the model capacity
batch_size = self.model.batch_size
batch_size = self.model.neuron_config.batch_size
if len(batch.requests) > batch_size:
raise ValueError(
f"Inconsistent batch_size configuration: Please make sure the batch_size in the compiled model (currently {batch_size}) matches the batch_size passed to TGI. The compiled model batch_size is usually in the neuron section of the model config.json file. You may also have passed it into optimum-cli during the compilation process. The batch size for TGI is usually set in the environment as MAX_BATCH_SIZE."
f"Inconsistent batch_size configuration: Please make sure the batch_size in the compiled model (currently {batch_size}) matches the batch_size passed to TGI. The compiled model.neuron_config.batch_size is usually in the neuron section of the model config.json file. You may also have passed it into optimum-cli during the compilation process. The batch size for TGI is usually set in the environment as MAX_BATCH_SIZE."
)
self.prefill(batch)
self.clear()
return self.model.batch_size * self.model.max_length
return (
self.model.neuron_config.batch_size
* self.model.neuron_config.sequence_length
)
def max_prefill_length(self) -> int:
if hasattr(self.model.neuron_config, "max_context_length"):
return self.model.neuron_config.max_context_length
return self.model.neuron_config.sequence_length
def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
"""Prefill new requests.
@ -398,7 +416,7 @@ class NeuronGenerator(Generator):
if len(empty_slots) < len(batch.requests):
raise ValueError(
f"Cannot prefill {len(batch.requests)} new request(s) with only {len(empty_slots)} empty slots."
f" Please align max_batch_size with the static batch size: {self.model.batch_size}."
f" Please align max_batch_size with the static batch size: {self.model.neuron_config.batch_size}."
)
# Assign each request to an empty slot
logger.debug(
@ -412,12 +430,6 @@ class NeuronGenerator(Generator):
logger.debug(
f"Request {slot.request_id} assigned to slot {slot.id} with and max_new_tokens {slot.max_new_tokens}"
)
if self.rebuild_cache_on_prefill:
# We will clear pending slots and prefill all slots
prefill_slots = self.slots
seq_ids = None
else:
# We only need to pass inputs for the new requests
prefill_slots = new_slots
seq_ids = torch.tensor([slot.id for slot in prefill_slots])
# Reconstruct the full inputs (without padding) as seen by the model.
@ -431,8 +443,10 @@ class NeuronGenerator(Generator):
inputs.append(slot.cached_text)
# Apply truncation, making sure we fit into static dimensions
if slot.truncate == 0:
max_length = self.model.max_length
elif slot.truncate > max_length and slot.truncate < self.model.max_length:
max_length = self.max_prefill_length()
elif (
slot.truncate > max_length and slot.truncate < self.max_prefill_length()
):
max_length = slot.truncate
# Tokenize with padding and truncation
padded_inputs = self.tokenizer(
@ -444,13 +458,12 @@ class NeuronGenerator(Generator):
)
input_ids = padded_inputs.input_ids
attention_mask = padded_inputs.attention_mask
sampling_params = (
torch.zeros(input_ids.shape[0], 3) if self.on_device_sampling else None
)
# Pause previously active slots during generation
next_tokens = []
for slot in active_slots:
slot.pause(reset_on_pause=self.rebuild_cache_on_prefill)
if self.rebuild_cache_on_prefill:
# The slot will be reset, so we need to store its next token
next_tokens.append(slot.next_token)
slot.pause()
# Each slot must be reset with the padded inputs and masks
for i, slot in enumerate(prefill_slots):
if slot.state != slot.state.EMPTY:
@ -464,29 +477,33 @@ class NeuronGenerator(Generator):
slot_input_ids,
slot.generation_config,
self.model,
self.model.max_length,
self.model.neuron_config.sequence_length,
tokenizer=self.tokenizer,
seed=slot.seed,
)
slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64)
slot_attention_mask = attention_mask[i]
slot.reset(slot_input_ids, slot_attention_mask, selector)
if sampling_params is not None:
sampling_params[i, 0] = slot.generation_config.top_k
sampling_params[i, 1] = slot.generation_config.top_p
sampling_params[i, 2] = slot.generation_config.temperature
# Note: when rebuilding cache on prefill, the new tokens on paused slots will be ignored,
# as they have already been generated and sent back in the last decode.
model_inputs = self.model.prepare_inputs_for_prefill(
input_ids, attention_mask, seq_ids
input_ids,
attention_mask=attention_mask,
seq_ids=seq_ids,
sampling_params=sampling_params,
)
logits = self.model(**model_inputs)[0]
tokens_or_logits = self.model(**model_inputs)[0]
generation, next_batch = self._generate_token(
prefill_slots, self.batch_id, logits, input_ids
prefill_slots, self.batch_id, tokens_or_logits, input_ids
)
self.batch_id += 1
# Reactivate previously active slots for the next decode
for i, slot in enumerate(active_slots):
slot.resume()
if self.rebuild_cache_on_prefill:
# Append back the next token
slot.append(next_tokens[i])
logger.debug("Model ready for decoding")
if next_batch is not None:
logger.debug(
@ -530,12 +547,8 @@ class NeuronGenerator(Generator):
raise ValueError(
"Unable to decode tokens for non-prefilled batches (probably due to a previous failure)"
)
if self.model.continuous_batching:
decode_slots = active_slots
seq_ids = torch.tensor([slot.id for slot in decode_slots])
else:
decode_slots = self.slots
seq_ids = None
# Reconstruct input_ids and attention_mask from decode slots
n_slots = len(decode_slots)
input_ids = torch.full(
@ -545,22 +558,32 @@ class NeuronGenerator(Generator):
for slot in decode_slots:
max_length = max(max_length, slot.attention_mask.size(-1))
attention_mask = torch.zeros([n_slots, max_length], dtype=torch.int64)
sampling_params = torch.zeros(n_slots, 3) if self.on_device_sampling else None
for i, slot in enumerate(decode_slots):
if slot.state != Slot.State.EMPTY:
# input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached)
input_ids[i, 0] = slot.next_token
attention_mask[i, : slot.attention_mask.size(-1)] = slot.attention_mask
if sampling_params is not None:
sampling_params[i, 0] = slot.generation_config.top_k
sampling_params[i, 1] = slot.generation_config.top_p
sampling_params[i, 2] = slot.generation_config.temperature
model_inputs = self.model.prepare_inputs_for_decode(
input_ids, attention_mask, seq_ids
input_ids,
attention_mask=attention_mask,
seq_ids=seq_ids,
sampling_params=sampling_params,
)
tokens_or_logits = self.model(**model_inputs)[0]
return self._generate_token(
decode_slots, next_batch_id, tokens_or_logits, input_ids
)
logits = self.model(**model_inputs)[0]
return self._generate_token(decode_slots, next_batch_id, logits, input_ids)
def _generate_token(
self,
slots: List[Slot],
next_batch_id: int,
logits: torch.Tensor,
tokens_or_logits: torch.Tensor,
input_ids: torch.LongTensor,
) -> Tuple[List[Generation], CachedBatch]:
generations = []
@ -569,8 +592,11 @@ class NeuronGenerator(Generator):
if slot.state != Slot.State.READY:
continue
request_id = slot.request_id
next_token_logits = logits[i : i + 1, -1, :]
slot_input_ids = input_ids[i : i + 1, :]
if self.on_device_sampling:
next_token = tokens_or_logits[i]
else:
next_token_logits = tokens_or_logits[i : i + 1, -1, :]
next_token = slot.select(slot_input_ids, next_token_logits)
next_token_text = slot.append(next_token)
generated_text = None
@ -622,7 +648,7 @@ class NeuronGenerator(Generator):
def _cached_batch(self, batch_id: int, request_ids: List):
size = len(request_ids)
max_tokens = size * self.model.max_length
max_tokens = size * self.model.neuron_config.sequence_length
return CachedBatch(
id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens
)
@ -671,8 +697,16 @@ class NeuronGenerator(Generator):
Returns:
A NeuronGenerator.
"""
config = AutoConfig.from_pretrained(model_id)
neuron_config = getattr(config, "neuron", None)
try:
neuron_config = NeuronConfig.from_pretrained(model_id, revision=revision)
except Exception as e:
logger.debug(
"NeuronConfig.from_pretrained failed for model %s, revision %s: %s",
model_id,
revision,
e,
)
neuron_config = None
start = time.time()
if neuron_config is None:
export_kwargs = get_export_kwargs_from_env()

View File

@ -6,10 +6,12 @@ from typing import Optional
from huggingface_hub import snapshot_download
from huggingface_hub.constants import HF_HUB_CACHE
from loguru import logger
from transformers import AutoConfig
from optimum.neuron import NeuronModelForCausalLM
from optimum.neuron.utils import get_hub_cached_entries
from optimum.neuron.cache import get_hub_cached_entries
from optimum.neuron.configuration_utils import NeuronConfig
from .tgi_env import check_env_and_neuron_config_compatibility
def get_export_kwargs_from_env():
@ -24,7 +26,6 @@ def get_export_kwargs_from_env():
num_cores = int(num_cores)
auto_cast_type = os.environ.get("HF_AUTO_CAST_TYPE", None)
return {
"task": "text-generation",
"batch_size": batch_size,
"sequence_length": sequence_length,
"num_cores": num_cores,
@ -32,20 +33,15 @@ def get_export_kwargs_from_env():
}
def is_cached(model_id, neuron_config):
def is_cached(model_id):
# Look for cached entries for the specified model
in_cache = False
entries = get_hub_cached_entries(model_id, "inference")
entries = get_hub_cached_entries(model_id)
# Look for compatible entries
for entry in entries:
compatible = True
for key, value in neuron_config.items():
# Only weights can be different
if key in ["checkpoint_id", "checkpoint_revision"]:
continue
if entry[key] != value:
compatible = False
if compatible:
if check_env_and_neuron_config_compatibility(
entry, check_compiler_version=True
):
in_cache = True
break
return in_cache
@ -87,8 +83,16 @@ def fetch_model(
revision = None
# Download the model from the Hub (HUGGING_FACE_HUB_TOKEN must be set for a private or gated model)
# Note that the model may already be present in the cache.
config = AutoConfig.from_pretrained(model_id, revision=revision)
neuron_config = getattr(config, "neuron", None)
try:
neuron_config = NeuronConfig.from_pretrained(model_id, revision=revision)
except Exception as e:
logger.debug(
"NeuronConfig.from_pretrained failed for model %s, revision %s: %s",
model_id,
revision,
e,
)
neuron_config = None
if neuron_config is not None:
if os.path.isdir(model_id):
return model_id
@ -99,16 +103,11 @@ def fetch_model(
log_cache_size()
return snapshot_download(model_id, revision=revision, ignore_patterns="*.bin")
# Model needs to be exported: look for compatible cached entries on the hub
export_kwargs = get_export_kwargs_from_env()
export_config = NeuronModelForCausalLM.get_export_config(
model_id, config, revision=revision, **export_kwargs
)
neuron_config = export_config.neuron
if not is_cached(model_id, neuron_config):
if not is_cached(model_id):
hub_cache_url = "https://huggingface.co/aws-neuron/optimum-neuron-cache"
neuron_export_url = "https://huggingface.co/docs/optimum-neuron/main/en/guides/export_model#exporting-neuron-models-using-neuronx-tgi"
error_msg = (
f"No cached version found for {model_id} with {neuron_config}."
f"No cached version found for {model_id} with {get_export_kwargs_from_env()}."
f"You can start a discussion to request it on {hub_cache_url}"
f"Alternatively, you can export your own neuron model as explained in {neuron_export_url}"
)
@ -121,8 +120,10 @@ def fetch_model(
# Prefetch weights, tokenizer and generation config so that they are in cache
log_cache_size()
start = time.time()
snapshot_download(model_id, revision=revision, ignore_patterns="*.bin")
snapshot_path = snapshot_download(
model_id, revision=revision, ignore_patterns="*.bin"
)
end = time.time()
logger.info(f"Model weights fetched in {end - start:.2f} s.")
log_cache_size()
return model_id
return snapshot_path

View File

@ -6,12 +6,11 @@ import os
import sys
from typing import Any, Dict, List, Optional
from huggingface_hub import constants
from transformers import AutoConfig
from optimum.neuron.modeling_decoder import get_available_cores
from optimum.neuron.utils import get_hub_cached_entries
from optimum.neuron.cache import get_hub_cached_entries
from optimum.neuron.configuration_utils import NeuronConfig
from optimum.neuron.utils.version_utils import get_neuronxcc_version
from optimum.neuron.utils import map_torch_dtype
logger = logging.getLogger(__name__)
@ -24,15 +23,9 @@ tgi_router_env_vars = [
]
tgi_server_env_vars = ["HF_NUM_CORES", "HF_AUTO_CAST_TYPE"]
env_config_peering = [
("MAX_BATCH_SIZE", "batch_size"),
("MAX_TOTAL_TOKENS", "sequence_length"),
("HF_AUTO_CAST_TYPE", "auto_cast_type"),
("HF_NUM_CORES", "num_cores"),
]
# By the end of this script all env var should be specified properly
env_vars = tgi_server_env_vars + tgi_router_env_vars
tgi_env_vars = tgi_server_env_vars + tgi_router_env_vars
available_cores = get_available_cores()
neuronxcc_version = get_neuronxcc_version()
@ -93,9 +86,17 @@ def parse_cmdline_and_set_env(argv: List[str] = None) -> argparse.Namespace:
def neuron_config_to_env(neuron_config):
if isinstance(neuron_config, NeuronConfig):
neuron_config = neuron_config.to_dict()
with open(os.environ["ENV_FILEPATH"], "w") as f:
for env_var, config_key in env_config_peering:
f.write("export {}={}\n".format(env_var, neuron_config[config_key]))
f.write("export MAX_BATCH_SIZE={}\n".format(neuron_config["batch_size"]))
f.write("export MAX_TOTAL_TOKENS={}\n".format(neuron_config["sequence_length"]))
f.write("export HF_NUM_CORES={}\n".format(neuron_config["tp_degree"]))
config_key = (
"auto_cast_type" if "auto_cast_type" in neuron_config else "torch_dtype"
)
auto_cast_type = neuron_config[config_key]
f.write("export HF_AUTO_CAST_TYPE={}\n".format(auto_cast_type))
max_input_tokens = os.getenv("MAX_INPUT_TOKENS")
if not max_input_tokens:
max_input_tokens = int(neuron_config["sequence_length"]) // 2
@ -111,7 +112,7 @@ def neuron_config_to_env(neuron_config):
def sort_neuron_configs(dictionary):
return -dictionary["num_cores"], -dictionary["batch_size"]
return -dictionary["tp_degree"], -dictionary["batch_size"]
def lookup_compatible_cached_model(
@ -119,7 +120,7 @@ def lookup_compatible_cached_model(
) -> Optional[Dict[str, Any]]:
# Reuse the same mechanic as the one in use to configure the tgi server part
# The only difference here is that we stay as flexible as possible on the compatibility part
entries = get_hub_cached_entries(model_id, "inference")
entries = get_hub_cached_entries(model_id)
logger.debug(
"Found %d cached entries for model %s, revision %s",
@ -155,15 +156,15 @@ def lookup_compatible_cached_model(
def check_env_and_neuron_config_compatibility(
neuron_config: Dict[str, Any], check_compiler_version: bool
neuron_config_dict: Dict[str, Any], check_compiler_version: bool
) -> bool:
logger.debug(
"Checking the provided neuron config %s is compatible with the local setup and provided environment",
neuron_config,
neuron_config_dict,
)
# Local setup compat checks
if neuron_config["num_cores"] > available_cores:
if neuron_config_dict["tp_degree"] > available_cores:
logger.debug(
"Not enough neuron cores available to run the provided neuron config"
)
@ -171,33 +172,65 @@ def check_env_and_neuron_config_compatibility(
if (
check_compiler_version
and neuron_config["compiler_version"] != neuronxcc_version
and neuron_config_dict["neuronxcc_version"] != neuronxcc_version
):
logger.debug(
"Compiler version conflict, the local one (%s) differs from the one used to compile the model (%s)",
neuronxcc_version,
neuron_config["compiler_version"],
neuron_config_dict["neuronxcc_version"],
)
return False
for env_var, config_key in env_config_peering:
neuron_config_value = str(neuron_config[config_key])
env_value = os.getenv(env_var, str(neuron_config_value))
batch_size = os.getenv("MAX_BATCH_SIZE", None)
if batch_size is not None and neuron_config_dict["batch_size"] < int(batch_size):
logger.debug(
"The provided MAX_BATCH_SIZE (%s) is higher than the neuron config batch size (%s)",
os.getenv("MAX_BATCH_SIZE"),
neuron_config_dict["batch_size"],
)
return False
max_total_tokens = os.getenv("MAX_TOTAL_TOKENS", None)
if max_total_tokens is not None and neuron_config_dict["sequence_length"] < int(
max_total_tokens
):
logger.debug(
"The provided MAX_TOTAL_TOKENS (%s) is higher than the neuron config sequence length (%s)",
max_total_tokens,
neuron_config_dict["sequence_length"],
)
return False
num_cores = os.getenv("HF_NUM_CORES", None)
if num_cores is not None and neuron_config_dict["tp_degree"] < int(num_cores):
logger.debug(
"The provided HF_NUM_CORES (%s) is higher than the neuron config tp degree (%s)",
num_cores,
neuron_config_dict["tp_degree"],
)
return False
auto_cast_type = os.getenv("HF_AUTO_CAST_TYPE", None)
if auto_cast_type is not None:
config_key = (
"auto_cast_type"
if "auto_cast_type" in neuron_config_dict
else "torch_dtype"
)
neuron_config_value = map_torch_dtype(str(neuron_config_dict[config_key]))
env_value = map_torch_dtype(auto_cast_type)
if env_value != neuron_config_value:
logger.debug(
"The provided env var '%s' and the neuron config '%s' param differ (%s != %s)",
env_var,
config_key,
"The provided auto cast type and the neuron config param differ (%s != %s)",
env_value,
neuron_config_value,
)
return False
max_input_tokens = int(
os.getenv("MAX_INPUT_TOKENS", os.getenv("MAX_INPUT_LENGTH", 0))
)
if max_input_tokens > 0:
sequence_length = neuron_config["sequence_length"]
if hasattr(neuron_config_dict, "max_context_length"):
sequence_length = neuron_config_dict["max_context_length"]
else:
sequence_length = neuron_config_dict["sequence_length"]
if max_input_tokens >= sequence_length:
logger.debug(
"Specified max input tokens is not compatible with config sequence length ( %s >= %s)",
@ -211,38 +244,29 @@ def check_env_and_neuron_config_compatibility(
def get_env_dict() -> Dict[str, str]:
d = {}
for k in env_vars:
for k in tgi_env_vars:
d[k] = os.getenv(k)
return d
def main():
"""
This script determines proper default TGI env variables for the neuron precompiled models to
work properly
:return:
"""
args = parse_cmdline_and_set_env()
for env_var in env_vars:
if not os.getenv(env_var):
break
else:
logger.info(
"All env vars %s already set, skipping, user know what they are doing",
env_vars,
def get_neuron_config_for_model(
model_name_or_path: str, revision: Optional[str] = None
) -> NeuronConfig:
try:
neuron_config = NeuronConfig.from_pretrained(
model_name_or_path, revision=revision
)
sys.exit(0)
cache_dir = constants.HF_HUB_CACHE
logger.info("Cache dir %s, model %s", cache_dir, args.model_id)
config = AutoConfig.from_pretrained(args.model_id, revision=args.revision)
neuron_config = getattr(config, "neuron", None)
except Exception as e:
logger.debug(
"NeuronConfig.from_pretrained failed for model %s, revision %s: %s",
model_name_or_path,
revision,
e,
)
neuron_config = None
if neuron_config is not None:
compatible = check_env_and_neuron_config_compatibility(
neuron_config, check_compiler_version=False
neuron_config.to_dict(), check_compiler_version=False
)
if not compatible:
env_dict = get_env_dict()
@ -252,17 +276,6 @@ def main():
logger.error(msg)
raise Exception(msg)
else:
neuron_config = lookup_compatible_cached_model(args.model_id, args.revision)
neuron_config = lookup_compatible_cached_model(model_name_or_path, revision)
if not neuron_config:
msg = (
"No compatible neuron config found. Provided env {}, available cores {}, neuronxcc version {}"
).format(get_env_dict(), available_cores, neuronxcc_version)
logger.error(msg)
raise Exception(msg)
neuron_config_to_env(neuron_config)
if __name__ == "__main__":
main()
return neuron_config

View File

@ -4,14 +4,12 @@ import subprocess
import sys
from tempfile import TemporaryDirectory
import huggingface_hub
import os
import pytest
from transformers import AutoTokenizer
from optimum.neuron import NeuronModelForCausalLM
from optimum.neuron.utils import synchronize_hub_cache
from optimum.neuron.version import __sdk_version__ as sdk_version
from optimum.neuron.version import __version__ as version
from optimum.neuron.cache import synchronize_hub_cache
logging.basicConfig(
@ -21,30 +19,14 @@ logging.basicConfig(
)
logger = logging.getLogger(__file__)
OPTIMUM_CACHE_REPO_ID = "optimum-internal-testing/neuron-testing-cache"
# All model configurations below will be added to the neuron_model_config fixture
MODEL_CONFIGURATIONS = {
"gpt2": {
"model_id": "gpt2",
"export_kwargs": {
"batch_size": 4,
"sequence_length": 1024,
"num_cores": 2,
"auto_cast_type": "fp16",
},
},
"llama": {
"model_id": "NousResearch/Hermes-2-Theta-Llama-3-8B",
"export_kwargs": {
"batch_size": 4,
"sequence_length": 2048,
"num_cores": 2,
"auto_cast_type": "fp16",
},
},
"mistral": {
"model_id": "optimum/mistral-1.1b-testing",
"model_id": "unsloth/Llama-3.2-1B-Instruct",
"export_kwargs": {
"batch_size": 4,
"sequence_length": 4096,
@ -58,7 +40,7 @@ MODEL_CONFIGURATIONS = {
"batch_size": 4,
"sequence_length": 4096,
"num_cores": 2,
"auto_cast_type": "fp16",
"auto_cast_type": "bf16",
},
},
"granite": {
@ -73,12 +55,6 @@ MODEL_CONFIGURATIONS = {
}
def get_hub_neuron_model_id(config_name: str):
return (
f"optimum-internal-testing/neuron-testing-{version}-{sdk_version}-{config_name}"
)
def export_model(model_id, export_kwargs, neuron_model_path):
export_command = [
"optimum-cli",
@ -104,57 +80,35 @@ def export_model(model_id, export_kwargs, neuron_model_path):
def neuron_model_config(request):
"""Expose a pre-trained neuron model
The fixture first makes sure the following model artifacts are present on the hub:
- exported neuron model under optimum-internal-testing/neuron-testing-<version>-<name>,
- cached artifacts under optimum-internal-testing/neuron-testing-cache.
If not, it will export the model and push it to the hub.
It then fetches the model locally and return a dictionary containing:
The fixture exports a model locally and returns a dictionary containing:
- a configuration name,
- the original model id,
- the export parameters,
- the neuron model id,
- the neuron model local path.
For each exposed model, the local directory is maintained for the duration of the
test session and cleaned up afterwards.
The hub model artifacts are never cleaned up and persist accross sessions.
They must be cleaned up manually when the optimum-neuron version changes.
"""
config_name = request.param
model_config = copy.deepcopy(MODEL_CONFIGURATIONS[request.param])
model_id = model_config["model_id"]
export_kwargs = model_config["export_kwargs"]
neuron_model_id = get_hub_neuron_model_id(config_name)
with TemporaryDirectory() as neuron_model_path:
hub = huggingface_hub.HfApi()
if hub.repo_exists(neuron_model_id):
logger.info(f"Fetching {neuron_model_id} from the HuggingFace hub")
hub.snapshot_download(neuron_model_id, local_dir=neuron_model_path)
else:
export_model(model_id, export_kwargs, neuron_model_path)
synchronize_hub_cache(cache_repo_id=OPTIMUM_CACHE_REPO_ID)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.save_pretrained(neuron_model_path)
del tokenizer
# Create the test model on the hub
hub.create_repo(neuron_model_id, private=True)
hub.upload_folder(
folder_path=neuron_model_path,
repo_id=neuron_model_id,
ignore_patterns=[NeuronModelForCausalLM.CHECKPOINT_DIR + "/*"],
)
# Make sure it is cached
synchronize_hub_cache(cache_repo_id=OPTIMUM_CACHE_REPO_ID)
# Add dynamic parameters to the model configuration
model_config["neuron_model_path"] = neuron_model_path
model_config["neuron_model_id"] = neuron_model_id
# Also add model configuration name to allow tests to adapt their expectations
model_config["name"] = config_name
# Yield instead of returning to keep a reference to the temporary directory.
# It will go out of scope and be released only once all tests needing the fixture
# have been completed.
logger.info(f"{config_name} ready for testing ...")
os.environ["CUSTOM_CACHE_REPO"] = OPTIMUM_CACHE_REPO_ID
yield model_config
logger.info(f"Done with {config_name}")

View File

@ -0,0 +1,42 @@
import os
import pytest
from text_generation_server.generator import NeuronGenerator
from text_generation_server.model import fetch_model, is_cached
@pytest.fixture(scope="module")
def cached_model_id(neuron_model_config) -> str:
"""
Fixture to provide a cached model ID for testing.
This assumes the model is already cached in the local environment.
"""
export_kwargs = neuron_model_config["export_kwargs"]
os.environ["MAX_BATCH_SIZE"] = str(export_kwargs["batch_size"])
os.environ["MAX_TOTAL_TOKENS"] = str(export_kwargs["sequence_length"])
os.environ["HF_AUTO_CAST_TYPE"] = export_kwargs["auto_cast_type"]
os.environ["HF_NUM_CORES"] = str(export_kwargs["num_cores"])
yield neuron_model_config["model_id"]
os.environ.pop("MAX_BATCH_SIZE", None)
os.environ.pop("MAX_TOTAL_TOKENS", None)
os.environ.pop("HF_AUTO_CAST_TYPE", None)
os.environ.pop("HF_NUM_CORES", None)
def test_model_is_cached(cached_model_id):
assert is_cached(cached_model_id), f"Model {cached_model_id} is not cached"
def test_fetch_cached_model(cached_model_id: str):
model_path = fetch_model(cached_model_id)
assert os.path.exists(
model_path
), f"Model {cached_model_id} was not fetched successfully"
assert os.path.isdir(model_path), f"Model {cached_model_id} is not a directory"
def test_generator_from_cached_model(cached_model_id: str):
generator = NeuronGenerator.from_pretrained(model_id=cached_model_id)
assert generator is not None, "Generator could not be created from cached model"
assert generator.model is not None, "Generator model is not initialized"
assert generator.tokenizer is not None, "Generator tokenizer is not initialized"

View File

@ -9,13 +9,13 @@ def test_continuous_batching_two_requests(neuron_model_config):
"""
neuron_model_path = neuron_model_config["neuron_model_path"]
generator = NeuronGenerator.from_pretrained(neuron_model_path)
assert generator.model.batch_size > 1
assert generator.model.neuron_config.batch_size > 1
input_text = "Once upon a time"
max_new_tokens = 20
# Prefill a single request, remembering the generated token
tokens = {0: [], 1: []}
request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens)
max_length = generator.model.max_length
max_length = generator.model.neuron_config.sequence_length
batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
generations, next_batch = generator.prefill(batch)
assert next_batch.size == 1

View File

@ -23,7 +23,7 @@ def _test_decode(config_name, generator, do_sample):
request = create_request(
id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=do_sample
)
max_length = generator.model.max_length
max_length = generator.model.neuron_config.sequence_length
batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
generations, next_batch = generator.prefill(batch)
# We already generated one token: call decode max_new_tokens - 1 times
@ -40,19 +40,15 @@ def _test_decode(config_name, generator, do_sample):
assert output.finish_reason == 0
if do_sample:
expected_text = {
"gpt2": " The sun was set",
"llama": "George Orwell, 1984",
"mistral": "The sky was",
"qwen2": " A young woman with",
"llama": " I sat alone in the café",
"qwen2": " The air was so still",
"granite": "1984, George Orwell",
}[config_name]
assert expected_text in output.text
else:
print(output.text)
expected_text = {
"gpt2": '\n\n"I\'m going to go to bed," I said.\n\n"I\'m going',
"llama": " George Orwells classic dystopian novel, 1984, begins with this ominous sentence. The story",
"mistral": "\nThe clocks were striking thirteen.\nThe clocks were striking thirteen.",
"llama": " The world was holding its breath as the world's top scientists and engineers gathered at the secret underground facility",
"qwen2": " I was sitting in my room, staring at the ceiling, when the door opened and in came a",
"granite": "\n\nThis opening line from George Orwell's dystopian novel \"198",
}[config_name]

View File

@ -9,7 +9,7 @@ def test_prefill(neuron_model_config):
neuron_model_path = neuron_model_config["neuron_model_path"]
generator = NeuronGenerator.from_pretrained(neuron_model_path)
max_batch_size = 4
assert generator.model.batch_size >= max_batch_size
assert generator.model.neuron_config.batch_size >= max_batch_size
for num_requests in [1, max_batch_size]:
for do_sample in [True, False]:
mode = "sample" if do_sample else "greedy"
@ -34,7 +34,7 @@ def _test_prefill(config_name, generator, batch_size, do_sample):
)
)
# Let's be pessimistic when estimating max_tokens
max_length = generator.model.max_length
max_length = generator.max_prefill_length()
batch = Batch(
id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length
)
@ -46,17 +46,13 @@ def _test_prefill(config_name, generator, batch_size, do_sample):
assert len(generations) == batch_size
if do_sample:
expectations = {
"gpt2": [383, " The"],
"llama": [10058, " George"],
"mistral": [450, " The"],
"qwen2": [362, " A"],
"llama": [358, " I"],
"qwen2": [576, " The"],
"granite": [308, " ("],
}[config_name]
else:
expectations = {
"gpt2": [198, "\n"],
"llama": [10058, " George"],
"mistral": [13, "\n"],
"llama": [578, " The"],
"qwen2": [358, " I"],
"granite": [203, "\n"],
}[config_name]
@ -70,7 +66,7 @@ def test_prefill_truncate(neuron_model_config):
config_name = neuron_model_config["name"]
neuron_model_path = neuron_model_config["neuron_model_path"]
generator = NeuronGenerator.from_pretrained(neuron_model_path)
batch_size = generator.model.batch_size
batch_size = generator.model.neuron_config.batch_size
# We apply truncation to all requests but the first one
truncate = [
None,
@ -83,7 +79,7 @@ def test_prefill_truncate(neuron_model_config):
requests = []
for i in range(batch_size):
requests.append(create_request(id=i, inputs=input_text, truncate=truncate[i]))
max_length = generator.model.max_length
max_length = generator.max_prefill_length()
batch = Batch(
id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length
)
@ -91,12 +87,12 @@ def test_prefill_truncate(neuron_model_config):
# Even if the input text is identical for all requests, the first generated token might
# be different because of the truncation
expectations = {
"gpt2": [" He", " He", "\n", " He"],
"llama": ["", " The", " He", " He"],
"mistral": [" He", "\n", " He", " He"],
"llama": [" He", "iens", "\x08", " He"],
"qwen2": [" He", " The", " He", " He"],
"granite": ["\n", "\n", " I", " He"],
}[config_name]
for i, g in enumerate(generations):
tokens = g.tokens
assert tokens.texts[0] == expectations[i]
assert (
tokens.texts[0] == expectations[i]
), f"Request {i} expected [{expectations[i]}], got [{tokens.texts[0]}]"

View File

@ -0,0 +1,63 @@
import os
import pytest
from tempfile import TemporaryDirectory
from optimum.neuron.models.inference.nxd.backend.config import NxDNeuronConfig
from optimum.neuron.utils import map_torch_dtype
from text_generation_server.tgi_env import (
get_neuron_config_for_model,
lookup_compatible_cached_model,
neuron_config_to_env,
)
def test_get_neuron_config_for_model(neuron_model_config):
neuron_model_path = neuron_model_config["neuron_model_path"]
export_kwargs = neuron_model_config["export_kwargs"]
os.environ["MAX_BATCH_SIZE"] = str(export_kwargs["batch_size"])
os.environ["MAX_TOTAL_TOKENS"] = str(export_kwargs["sequence_length"])
os.environ["HF_AUTO_CAST_TYPE"] = export_kwargs["auto_cast_type"]
os.environ["HF_NUM_CORES"] = str(export_kwargs["num_cores"])
neuron_config = get_neuron_config_for_model(neuron_model_path)
assert neuron_config is not None
assert neuron_config.batch_size == export_kwargs["batch_size"]
assert neuron_config.sequence_length == export_kwargs["sequence_length"]
assert neuron_config.tp_degree == export_kwargs["num_cores"]
if isinstance(neuron_config, NxDNeuronConfig):
assert map_torch_dtype(neuron_config.torch_dtype) == map_torch_dtype(
export_kwargs["auto_cast_type"]
)
else:
assert map_torch_dtype(neuron_config.auto_cast_type) == map_torch_dtype(
export_kwargs["auto_cast_type"]
)
@pytest.mark.parametrize("model_id", ["unsloth/Llama-3.2-1B-Instruct"])
def test_lookup_compatible_cached_model(model_id: str):
neuron_config = lookup_compatible_cached_model(model_id, None)
assert neuron_config is not None
def test_neuron_config_to_env(neuron_model_config) -> None:
neuron_model_path = neuron_model_config["neuron_model_path"]
neuron_config = get_neuron_config_for_model(neuron_model_path)
with TemporaryDirectory() as temp_dir:
os.environ["ENV_FILEPATH"] = os.path.join(temp_dir, "env.sh")
neuron_config_to_env(neuron_config)
with open(os.environ["ENV_FILEPATH"], "r") as env_file:
env_content = env_file.read()
assert f"export MAX_BATCH_SIZE={neuron_config.batch_size}" in env_content
assert (
f"export MAX_TOTAL_TOKENS={neuron_config.sequence_length}"
in env_content
)
assert f"export HF_NUM_CORES={neuron_config.tp_degree}" in env_content
if hasattr(neuron_config, "torch_dtype"):
auto_cast_type = str(map_torch_dtype(neuron_config.torch_dtype)).split(
"."
)[-1]
else:
auto_cast_type = neuron_config.auto_cast_type
assert f"export HF_AUTO_CAST_TYPE={auto_cast_type}" in env_content

View File

@ -9,7 +9,7 @@ touch $ENV_FILEPATH
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
${SCRIPT_DIR}/tgi_env.py $@
${SCRIPT_DIR}/tgi_entry_point.py $@
source $ENV_FILEPATH

View File

@ -0,0 +1,53 @@
#!/usr/bin/env python
import logging
import os
import sys
from text_generation_server.tgi_env import (
available_cores,
get_env_dict,
get_neuron_config_for_model,
neuron_config_to_env,
neuronxcc_version,
parse_cmdline_and_set_env,
tgi_env_vars,
)
logger = logging.getLogger(__name__)
def main():
"""
This script determines proper default TGI env variables for the neuron precompiled models to
work properly
:return:
"""
args = parse_cmdline_and_set_env()
for env_var in tgi_env_vars:
if not os.getenv(env_var):
break
else:
logger.info(
"All env vars %s already set, skipping, user know what they are doing",
tgi_env_vars,
)
sys.exit(0)
neuron_config = get_neuron_config_for_model(args.model_id, args.revision)
if not neuron_config:
msg = (
"No compatible neuron config found. Provided env {}, available cores {}, neuronxcc version {}"
).format(get_env_dict(), available_cores, neuronxcc_version)
logger.error(msg)
raise Exception(msg)
neuron_config_to_env(neuron_config)
if __name__ == "__main__":
main()

View File

@ -162,6 +162,11 @@ impl Allocator for SimpleAllocator {
tokens: u32,
_prefill_tokens: Option<Arc<Vec<u32>>>,
) -> Option<BlockAllocation> {
let mut tokens = tokens;
if self.is_hpu_device {
// need 1 slot for ping-pong optimization
tokens += 1;
}
// Apply window size
let (required_blocks, repeats) = {
let (tokens, repeats) = match self.window_size {
@ -176,8 +181,7 @@ impl Allocator for SimpleAllocator {
let required_blocks = tokens.div_ceil(self.block_size);
(required_blocks, repeats)
};
let mut tokens = tokens as usize;
let tokens = tokens as usize;
if required_blocks > self.free_blocks.len() as u32 {
None
} else {
@ -189,8 +193,6 @@ impl Allocator for SimpleAllocator {
.split_off(self.free_blocks.len() - required_blocks as usize);
if self.is_hpu_device {
blocks.sort();
// need 1 slot for ping-pong optimization
tokens += 1;
}
let mut slots =
Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);

View File

@ -8,6 +8,7 @@ use std::cmp::max;
use std::collections::VecDeque;
use text_generation_router::infer::InferError;
use text_generation_router::infer::InferStreamResponse;
use text_generation_router::usage_stats::Env;
use text_generation_router::validation::{
Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
ValidStoppingParameters,
@ -185,6 +186,9 @@ struct State {
/// Paged Attention Block Allocation
block_allocator: Option<BlockAllocator>,
/// indicate if it's hpu device, the hpu device needs padding to generate first token.
is_hpu_device: bool,
}
impl State {
@ -214,6 +218,7 @@ impl State {
speculate,
support_chunking,
block_allocator,
is_hpu_device: Env::new().is_hpu_device(),
}
}
@ -368,6 +373,21 @@ impl State {
}
}
if self.is_hpu_device {
//HPU needs to pad for the prefill
max_input_length = max_input_length.max(entry.request.input_length);
let actual_prefill_tokens_for_hpu =
(batch.len() + 1) as u32 * max_input_length;
if actual_prefill_tokens_for_hpu > prefill_token_budget {
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: prefill_tokens={actual_prefill_tokens_for_hpu} > {prefill_token_budget}");
self.entries.push_front((id, entry));
break 'entry_loop;
}
}
prefill_tokens += postfix_len;
Some(block_allocation)

View File

@ -10,7 +10,7 @@
"name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
},
"version": "3.3.0-dev0"
"version": "3.3.2-dev0"
},
"paths": {
"/": {

View File

@ -20,7 +20,7 @@ hf_token=YOUR_HF_ACCESS_TOKEN
docker run --runtime=habana --cap-add=sys_nice --ipc=host \
-p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
ghcr.io/huggingface/text-generation-inference:3.3.0-gaudi \
ghcr.io/huggingface/text-generation-inference:3.3.2-gaudi \
--model-id $model
```
@ -52,7 +52,7 @@ hf_token=YOUR_ACCESS_TOKEN
docker run --runtime=habana --cap-add=sys_nice --ipc=host \
-p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
ghcr.io/huggingface/text-generation-inference:3.3.0-gaudi \
ghcr.io/huggingface/text-generation-inference:3.3.2-gaudi \
--model-id $model
<text-generation-inference-launcher-arguments>
```
@ -115,7 +115,7 @@ docker run -p 8080:80 \
-e BATCH_BUCKET_SIZE=256 \
-e PREFILL_BATCH_BUCKET_SIZE=4 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
ghcr.io/huggingface/text-generation-inference:3.3.0-gaudi \
ghcr.io/huggingface/text-generation-inference:3.3.2-gaudi \
--model-id $model \
--sharded true --num-shard 8 \
--max-input-tokens 1024 --max-total-tokens 2048 \
@ -141,7 +141,7 @@ docker run -p 8080:80 \
-v $volume:/data \
-e PREFILL_BATCH_BUCKET_SIZE=1 \
-e BATCH_BUCKET_SIZE=1 \
ghcr.io/huggingface/text-generation-inference:3.3.0-gaudi \
ghcr.io/huggingface/text-generation-inference:3.3.2-gaudi \
--model-id $model \
--max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
--max-total-tokens 8192 --max-batch-size 4
@ -208,7 +208,7 @@ docker run --runtime=habana --ipc=host --cap-add=sys_nice \
-e PROF_PATH=/tmp/hpu_profile \
-e PROF_RANKS=0 \
-e PROF_RECORD_SHAPES=True \
ghcr.io/huggingface/text-generation-inference:3.3.0-gaudi \
ghcr.io/huggingface/text-generation-inference:3.3.2-gaudi \
--model-id $model
```

View File

@ -31,7 +31,7 @@ deployment instructions in the model card:
The service is launched simply by running the text-generation-inference container with two sets of parameters:
```
docker run <system_parameters> ghcr.io/huggingface/text-generation-inference:3.3.0-neuron <service_parameters>
docker run <system_parameters> ghcr.io/huggingface/text-generation-inference:3.3.2-neuron <service_parameters>
```
- system parameters are used to map ports, volumes and devices between the host and the service,

View File

@ -19,6 +19,6 @@ docker run --gpus all \
--shm-size 1g \
-e HF_TOKEN=$token \
-p 8080:80 \
-v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.0 \
-v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.2 \
--model-id $model
```

View File

@ -19,7 +19,7 @@ bitsandbytes is a library used to apply 8-bit and 4-bit quantization to models.
In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below 👇
```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.0 --model-id $model --quantize bitsandbytes
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.2 --model-id $model --quantize bitsandbytes
```
4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load.
@ -27,7 +27,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingf
In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below 👇
```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.0 --model-id $model --quantize bitsandbytes-nf4
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.2 --model-id $model --quantize bitsandbytes-nf4
```
You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes).
@ -48,7 +48,7 @@ $$({\hat{W}_{l}}^{*} = argmin_{\hat{W_{l}}} ||W_{l}X-\hat{W}_{l}X||^{2}_{2})$$
TGI allows you to both run an already GPTQ quantized model (see available models [here](https://huggingface.co/models?search=gptq)) or quantize a model of your choice using quantization script. You can run a quantized model by simply passing --quantize like below 👇
```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.0 --model-id $model --quantize gptq
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.2 --model-id $model --quantize gptq
```
Note that TGI's GPTQ implementation doesn't use [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) under the hood. However, models quantized using AutoGPTQ or Optimum can still be served by TGI.

Some files were not shown because too many files have changed in this diff Show More