Merge branch 'huggingface:main' into qwen3_moe

commit 1791c855f0
Yuan Wu, 2025-06-13 10:02:18 +08:00 (committed by GitHub)
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
89 changed files with 1789 additions and 7317 deletions

Cargo.lock (generated)

@ -4650,7 +4650,7 @@ dependencies = [
[[package]]
name = "text-generation-backends-trtllm"
version = "3.3.1-dev0"
version = "3.3.2-dev0"
dependencies = [
"async-trait",
"clap 4.5.32",
@ -4671,7 +4671,7 @@ dependencies = [
[[package]]
name = "text-generation-benchmark"
version = "3.3.1-dev0"
version = "3.3.2-dev0"
dependencies = [
"average",
"clap 4.5.32",
@ -4691,7 +4691,7 @@ dependencies = [
[[package]]
name = "text-generation-client"
version = "3.3.1-dev0"
version = "3.3.2-dev0"
dependencies = [
"async-trait",
"base64 0.22.1",
@ -4709,7 +4709,7 @@ dependencies = [
[[package]]
name = "text-generation-launcher"
version = "3.3.1-dev0"
version = "3.3.2-dev0"
dependencies = [
"clap 4.5.32",
"ctrlc",
@ -4730,7 +4730,7 @@ dependencies = [
[[package]]
name = "text-generation-router"
version = "3.3.1-dev0"
version = "3.3.2-dev0"
dependencies = [
"anyhow",
"async-stream",
@ -4782,7 +4782,7 @@ dependencies = [
[[package]]
name = "text-generation-router-llamacpp"
version = "3.3.1-dev0"
version = "3.3.2-dev0"
dependencies = [
"async-trait",
"bindgen 0.71.1",
@ -4800,7 +4800,7 @@ dependencies = [
[[package]]
name = "text-generation-router-v2"
version = "3.3.1-dev0"
version = "3.3.2-dev0"
dependencies = [
"async-stream",
"async-trait",
@ -4849,7 +4849,7 @@ dependencies = [
[[package]]
name = "text-generation-router-v3"
version = "3.3.1-dev0"
version = "3.3.2-dev0"
dependencies = [
"async-stream",
"async-trait",


@ -21,7 +21,7 @@ default-members = [
resolver = "2"
[workspace.package]
version = "3.3.1-dev0"
version = "3.3.2-dev0"
edition = "2021"
authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference"


@ -5,7 +5,7 @@ RUN mkdir -p /tgi
# Fetch the optimum-neuron sources directly to avoid relying on pypi deployments
FROM alpine AS optimum-neuron
RUN mkdir -p /optimum-neuron
ADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.1.0.tar.gz /optimum-neuron/sources.tar.gz
ADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.2.0.tar.gz /optimum-neuron/sources.tar.gz
RUN tar -C /optimum-neuron -xf /optimum-neuron/sources.tar.gz --strip-components=1
# Build cargo components (adapted from TGI original Dockerfile)
@ -108,10 +108,10 @@ RUN wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEU
# Install neuronx packages
RUN apt-get update -y \
&& apt-get install -y --no-install-recommends \
aws-neuronx-dkms=2.19.64.0 \
aws-neuronx-collectives=2.23.135.0-3e70920f2 \
aws-neuronx-runtime-lib=2.23.112.0-9b5179492 \
aws-neuronx-tools=2.20.204.0 \
aws-neuronx-dkms=2.20.28.0 \
aws-neuronx-collectives=2.24.59.0-838c7fc8b \
aws-neuronx-runtime-lib=2.24.53.0-f239092cc \
aws-neuronx-tools=2.22.61.0 \
libxml2 \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
@ -125,11 +125,10 @@ RUN pip3 install \
--index-url https://download.pytorch.org/whl/cpu
RUN pip3 install \
neuronx-cc==2.16.372.0 \
torch-neuronx==2.5.1.2.4.0 \
transformers-neuronx==0.13.322 \
neuronx-distributed==0.10.1 \
libneuronxla==2.1.681.0 \
neuronx-cc==2.17.194.0 \
torch-neuronx==2.5.1.2.6.0 \
neuronx-distributed==0.11.0 \
libneuronxla==2.2.1630.0 \
--extra-index-url=https://pip.repos.neuron.amazonaws.com
# Install HuggingFace packages
@ -160,7 +159,7 @@ RUN pip install dist/text_generation_server*.tar.gz
# Final image
FROM neuron
COPY backends/neuron/tgi_env.py /tgi_env.py
COPY backends/neuron/tgi_entry_point.py /tgi_entry_point.py
COPY backends/neuron/tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh


@ -57,7 +57,7 @@ ARG PYTORCH_VERSION
FROM vault.habana.ai/gaudi-docker/${HABANA_VERSION}/ubuntu22.04/habanalabs/pytorch-installer-${PYTORCH_VERSION}:latest AS base
ENV ATTENTION=default
ENV ATTENTION=paged
ENV PREFIX_CACHING=0
ENV PREFILL_CHUNKING=0
ENV PT_HPU_LAZY_MODE=1


@ -84,7 +84,7 @@ model=HuggingFaceH4/zephyr-7b-beta
volume=$PWD/data
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.3.1 --model-id $model
ghcr.io/huggingface/text-generation-inference:3.3.2 --model-id $model
```
And then you can make requests like
@ -121,7 +121,7 @@ curl localhost:8080/v1/chat/completions \
**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.1-rocm --model-id $model` instead of the command above.
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.2-rocm --model-id $model` instead of the command above.
To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
```
@ -152,7 +152,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
token=<your cli READ token>
docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.3.1 --model-id $model
ghcr.io/huggingface/text-generation-inference:3.3.2 --model-id $model
```
### A note on Shared Memory (shm)


@ -1058,199 +1058,6 @@ files = [
{file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"},
]
[[package]]
name = "nvidia-cublas-cu12"
version = "12.4.5.8"
description = "CUBLAS native runtime libraries"
optional = false
python-versions = ">=3"
groups = ["main"]
markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
files = [
{file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0f8aa1706812e00b9f19dfe0cdb3999b092ccb8ca168c0db5b8ea712456fd9b3"},
{file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b"},
{file = "nvidia_cublas_cu12-12.4.5.8-py3-none-win_amd64.whl", hash = "sha256:5a796786da89203a0657eda402bcdcec6180254a8ac22d72213abc42069522dc"},
]
[[package]]
name = "nvidia-cuda-cupti-cu12"
version = "12.4.127"
description = "CUDA profiling tools runtime libs."
optional = false
python-versions = ">=3"
groups = ["main"]
markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
files = [
{file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:79279b35cf6f91da114182a5ce1864997fd52294a87a16179ce275773799458a"},
{file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb"},
{file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:5688d203301ab051449a2b1cb6690fbe90d2b372f411521c86018b950f3d7922"},
]
[[package]]
name = "nvidia-cuda-nvrtc-cu12"
version = "12.4.127"
description = "NVRTC native runtime libraries"
optional = false
python-versions = ">=3"
groups = ["main"]
markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
files = [
{file = "nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0eedf14185e04b76aa05b1fea04133e59f465b6f960c0cbf4e37c3cb6b0ea198"},
{file = "nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a178759ebb095827bd30ef56598ec182b85547f1508941a3d560eb7ea1fbf338"},
{file = "nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:a961b2f1d5f17b14867c619ceb99ef6fcec12e46612711bcec78eb05068a60ec"},
]
[[package]]
name = "nvidia-cuda-runtime-cu12"
version = "12.4.127"
description = "CUDA Runtime native Libraries"
optional = false
python-versions = ">=3"
groups = ["main"]
markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
files = [
{file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:961fe0e2e716a2a1d967aab7caee97512f71767f852f67432d572e36cb3a11f3"},
{file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5"},
{file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:09c2e35f48359752dfa822c09918211844a3d93c100a715d79b59591130c5e1e"},
]
[[package]]
name = "nvidia-cudnn-cu12"
version = "9.1.0.70"
description = "cuDNN runtime libraries"
optional = false
python-versions = ">=3"
groups = ["main"]
markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
files = [
{file = "nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f"},
{file = "nvidia_cudnn_cu12-9.1.0.70-py3-none-win_amd64.whl", hash = "sha256:6278562929433d68365a07a4a1546c237ba2849852c0d4b2262a486e805b977a"},
]
[package.dependencies]
nvidia-cublas-cu12 = "*"
[[package]]
name = "nvidia-cufft-cu12"
version = "11.2.1.3"
description = "CUFFT native runtime libraries"
optional = false
python-versions = ">=3"
groups = ["main"]
markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
files = [
{file = "nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5dad8008fc7f92f5ddfa2101430917ce2ffacd86824914c82e28990ad7f00399"},
{file = "nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9"},
{file = "nvidia_cufft_cu12-11.2.1.3-py3-none-win_amd64.whl", hash = "sha256:d802f4954291101186078ccbe22fc285a902136f974d369540fd4a5333d1440b"},
]
[package.dependencies]
nvidia-nvjitlink-cu12 = "*"
[[package]]
name = "nvidia-curand-cu12"
version = "10.3.5.147"
description = "CURAND native runtime libraries"
optional = false
python-versions = ">=3"
groups = ["main"]
markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
files = [
{file = "nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1f173f09e3e3c76ab084aba0de819c49e56614feae5c12f69883f4ae9bb5fad9"},
{file = "nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a88f583d4e0bb643c49743469964103aa59f7f708d862c3ddb0fc07f851e3b8b"},
{file = "nvidia_curand_cu12-10.3.5.147-py3-none-win_amd64.whl", hash = "sha256:f307cc191f96efe9e8f05a87096abc20d08845a841889ef78cb06924437f6771"},
]
[[package]]
name = "nvidia-cusolver-cu12"
version = "11.6.1.9"
description = "CUDA solver native runtime libraries"
optional = false
python-versions = ">=3"
groups = ["main"]
markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
files = [
{file = "nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d338f155f174f90724bbde3758b7ac375a70ce8e706d70b018dd3375545fc84e"},
{file = "nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260"},
{file = "nvidia_cusolver_cu12-11.6.1.9-py3-none-win_amd64.whl", hash = "sha256:e77314c9d7b694fcebc84f58989f3aa4fb4cb442f12ca1a9bde50f5e8f6d1b9c"},
]
[package.dependencies]
nvidia-cublas-cu12 = "*"
nvidia-cusparse-cu12 = "*"
nvidia-nvjitlink-cu12 = "*"
[[package]]
name = "nvidia-cusparse-cu12"
version = "12.3.1.170"
description = "CUSPARSE native runtime libraries"
optional = false
python-versions = ">=3"
groups = ["main"]
markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
files = [
{file = "nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9d32f62896231ebe0480efd8a7f702e143c98cfaa0e8a76df3386c1ba2b54df3"},
{file = "nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1"},
{file = "nvidia_cusparse_cu12-12.3.1.170-py3-none-win_amd64.whl", hash = "sha256:9bc90fb087bc7b4c15641521f31c0371e9a612fc2ba12c338d3ae032e6b6797f"},
]
[package.dependencies]
nvidia-nvjitlink-cu12 = "*"
[[package]]
name = "nvidia-cusparselt-cu12"
version = "0.6.2"
description = "NVIDIA cuSPARSELt"
optional = false
python-versions = "*"
groups = ["main"]
markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
files = [
{file = "nvidia_cusparselt_cu12-0.6.2-py3-none-manylinux2014_aarch64.whl", hash = "sha256:067a7f6d03ea0d4841c85f0c6f1991c5dda98211f6302cb83a4ab234ee95bef8"},
{file = "nvidia_cusparselt_cu12-0.6.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:df2c24502fd76ebafe7457dbc4716b2fec071aabaed4fb7691a201cde03704d9"},
{file = "nvidia_cusparselt_cu12-0.6.2-py3-none-win_amd64.whl", hash = "sha256:0057c91d230703924c0422feabe4ce768841f9b4b44d28586b6f6d2eb86fbe70"},
]
[[package]]
name = "nvidia-nccl-cu12"
version = "2.21.5"
description = "NVIDIA Collective Communication Library (NCCL) Runtime"
optional = false
python-versions = ">=3"
groups = ["main"]
markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
files = [
{file = "nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0"},
]
[[package]]
name = "nvidia-nvjitlink-cu12"
version = "12.4.127"
description = "Nvidia JIT LTO Library"
optional = false
python-versions = ">=3"
groups = ["main"]
markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
files = [
{file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83"},
{file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"},
{file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1"},
]
[[package]]
name = "nvidia-nvtx-cu12"
version = "12.4.127"
description = "NVIDIA Tools Extension"
optional = false
python-versions = ">=3"
groups = ["main"]
markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
files = [
{file = "nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7959ad635db13edf4fc65c06a6e9f9e55fc2f92596db928d169c0bb031e88ef3"},
{file = "nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a"},
{file = "nvidia_nvtx_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:641dccaaa1139f3ffb0d3164b4b84f9d253397e38246a4f2f36728b48566d485"},
]
[[package]]
name = "opentelemetry-api"
version = "1.32.0"
@ -2650,63 +2457,6 @@ files = [
{file = "tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff"},
]
[[package]]
name = "torch"
version = "2.6.0"
description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
optional = false
python-versions = ">=3.9.0"
groups = ["main"]
files = [
{file = "torch-2.6.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:6860df13d9911ac158f4c44031609700e1eba07916fff62e21e6ffa0a9e01961"},
{file = "torch-2.6.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:c4f103a49830ce4c7561ef4434cc7926e5a5fe4e5eb100c19ab36ea1e2b634ab"},
{file = "torch-2.6.0-cp310-cp310-win_amd64.whl", hash = "sha256:56eeaf2ecac90da5d9e35f7f35eb286da82673ec3c582e310a8d1631a1c02341"},
{file = "torch-2.6.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:09e06f9949e1a0518c5b09fe95295bc9661f219d9ecb6f9893e5123e10696628"},
{file = "torch-2.6.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:7979834102cd5b7a43cc64e87f2f3b14bd0e1458f06e9f88ffa386d07c7446e1"},
{file = "torch-2.6.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:ccbd0320411fe1a3b3fec7b4d3185aa7d0c52adac94480ab024b5c8f74a0bf1d"},
{file = "torch-2.6.0-cp311-cp311-win_amd64.whl", hash = "sha256:46763dcb051180ce1ed23d1891d9b1598e07d051ce4c9d14307029809c4d64f7"},
{file = "torch-2.6.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:94fc63b3b4bedd327af588696559f68c264440e2503cc9e6954019473d74ae21"},
{file = "torch-2.6.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:2bb8987f3bb1ef2675897034402373ddfc8f5ef0e156e2d8cfc47cacafdda4a9"},
{file = "torch-2.6.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:b789069020c5588c70d5c2158ac0aa23fd24a028f34a8b4fcb8fcb4d7efcf5fb"},
{file = "torch-2.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:7e1448426d0ba3620408218b50aa6ada88aeae34f7a239ba5431f6c8774b1239"},
{file = "torch-2.6.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:9a610afe216a85a8b9bc9f8365ed561535c93e804c2a317ef7fabcc5deda0989"},
{file = "torch-2.6.0-cp313-cp313-manylinux1_x86_64.whl", hash = "sha256:4874a73507a300a5d089ceaff616a569e7bb7c613c56f37f63ec3ffac65259cf"},
{file = "torch-2.6.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:a0d5e1b9874c1a6c25556840ab8920569a7a4137afa8a63a32cee0bc7d89bd4b"},
{file = "torch-2.6.0-cp313-cp313-win_amd64.whl", hash = "sha256:510c73251bee9ba02ae1cb6c9d4ee0907b3ce6020e62784e2d7598e0cfa4d6cc"},
{file = "torch-2.6.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:ff96f4038f8af9f7ec4231710ed4549da1bdebad95923953a25045dcf6fd87e2"},
{file = "torch-2.6.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:9ea955317cfcd3852b1402b62af258ce735c2edeee42ca9419b6bc889e5ae053"},
{file = "torch-2.6.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:bb2c6c3e65049f081940f5ab15c9136c7de40d3f01192541c920a07c7c585b7e"},
{file = "torch-2.6.0-cp39-cp39-win_amd64.whl", hash = "sha256:683410f97984103148e31b38a8631acf31c3034c020c0f4d26171e7626d8317a"},
{file = "torch-2.6.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:265f70de5fd45b864d924b64be1797f86e76c8e48a02c2a3a6fc7ec247d2226c"},
]
[package.dependencies]
filelock = "*"
fsspec = "*"
jinja2 = "*"
networkx = "*"
nvidia-cublas-cu12 = {version = "12.4.5.8", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
nvidia-cuda-cupti-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
nvidia-cuda-nvrtc-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
nvidia-cuda-runtime-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
nvidia-cudnn-cu12 = {version = "9.1.0.70", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
nvidia-cufft-cu12 = {version = "11.2.1.3", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
nvidia-curand-cu12 = {version = "10.3.5.147", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
nvidia-cusolver-cu12 = {version = "11.6.1.9", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
nvidia-cusparse-cu12 = {version = "12.3.1.170", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
nvidia-cusparselt-cu12 = {version = "0.6.2", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
nvidia-nccl-cu12 = {version = "2.21.5", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
nvidia-nvjitlink-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
nvidia-nvtx-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
setuptools = {version = "*", markers = "python_version >= \"3.12\""}
sympy = {version = "1.13.1", markers = "python_version >= \"3.9\""}
triton = {version = "3.2.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
typing-extensions = ">=4.10.0"
[package.extras]
opt-einsum = ["opt-einsum (>=3.3)"]
optree = ["optree (>=0.13.0)"]
[[package]]
name = "tqdm"
version = "4.67.1"


@ -22,10 +22,9 @@ opentelemetry-instrumentation-grpc = "^0.53b0"
hf-transfer = "^0.1.9"
sentencepiece = "^0.2.0"
peft = "^0.15"
optimum-habana = "1.17"
transformers = "^4.49"
transformers = "^4.52.4"
numpy = "^1.26"
accelerate = "^0.33"
accelerate = "^1.7.0"
outlines= { version = "^0.0.36", optional = true }
prometheus-client = "^0.21.1"
py-cpuinfo = "^9.0.0"


@ -1,4 +1,4 @@
accelerate==0.33.0 ; python_version >= "3.9" and python_version < "3.13"
accelerate==1.7.0 ; python_version >= "3.9" and python_version < "3.13"
annotated-types==0.7.0 ; python_version >= "3.9" and python_version < "3.13"
attrs==25.3.0 ; python_version >= "3.9" and python_version < "3.13"
certifi==2025.1.31 ; python_version >= "3.9" and python_version < "3.13"
@ -36,19 +36,6 @@ nest-asyncio==1.6.0 ; python_version >= "3.9" and python_version < "3.13"
networkx==3.2.1 ; python_version >= "3.9" and python_version < "3.13"
numba==0.60.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
nvidia-cublas-cu12==12.4.5.8 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cuda-cupti-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cuda-nvrtc-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cuda-runtime-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cudnn-cu12==9.1.0.70 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cufft-cu12==11.2.1.3 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-curand-cu12==10.3.5.147 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cusolver-cu12==11.6.1.9 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cusparse-cu12==12.3.1.170 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cusparselt-cu12==0.6.2 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-nccl-cu12==2.21.5 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-nvjitlink-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-nvtx-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
opentelemetry-api==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-common==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
@ -59,7 +46,6 @@ opentelemetry-instrumentation==0.53b0 ; python_version >= "3.9" and python_versi
opentelemetry-proto==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.53b0 ; python_version >= "3.9" and python_version < "3.13"
optimum-habana==1.17.0 ; python_version >= "3.9" and python_version < "3.13"
optimum==1.24.0 ; python_version >= "3.9" and python_version < "3.13"
outlines==0.0.36 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.2 ; python_version >= "3.9" and python_version < "3.13"
@ -88,9 +74,8 @@ shellingham==1.5.4 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
threadpoolctl==3.6.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.21.1 ; python_version >= "3.9" and python_version < "3.13"
torch==2.6.0 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.67.1 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.49.0 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.52.4 ; python_version >= "3.9" and python_version < "3.13"
triton==3.2.0 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
typer==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.13.2 ; python_version >= "3.9" and python_version < "3.13"


@ -1,6 +1,4 @@
import os
import psutil
import signal
import sys
import typer
@ -115,80 +113,19 @@ def serve(
raise RuntimeError(
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
)
logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype))
if sharded and os.getenv("ATTENTION", "default") not in {"paged"}:
tgi_file = Path(__file__).resolve().parent / "tgi_service.py"
num_shard = int(os.getenv("WORLD_SIZE", "1"))
logger.info("CLI SHARDED = {}".format(num_shard))
import subprocess
cmd = (
f"deepspeed --num_nodes 1 --num_gpus {num_shard} --no_local_rank {tgi_file}"
)
cmd += f" --model_id {model_id} --revision {revision} --sharded {sharded}"
cmd += f" --dtype {dtype} --trust_remote_code {trust_remote_code} --uds_path {uds_path}"
cmd += f" --quantize {quantize} --max_input_tokens {max_input_tokens}"
if speculate is not None:
cmd += f"--speculate {speculate}"
logger.info("CLI server start deepspeed ={} ".format(cmd))
sys.stdout.flush()
sys.stderr.flush()
with subprocess.Popen(cmd, shell=True, executable="/bin/bash") as proc:
do_terminate = False
current_handler = signal.getsignal(signal.SIGTERM)
def terminate_handler(sig, frame):
nonlocal do_terminate
do_terminate = True
if callable(current_handler):
current_handler(sig, frame)
signal.signal(signal.SIGTERM, terminate_handler)
finished = False
while not finished:
try:
if do_terminate:
parent = psutil.Process(proc.pid)
all_procs = parent.children(recursive=True) + [parent]
for p in all_procs:
try:
p.terminate()
except psutil.NoSuchProcess:
pass
_, alive = psutil.wait_procs(all_procs, timeout=30)
for p in alive:
p.kill()
do_terminate = False
proc.wait(timeout=3)
except subprocess.TimeoutExpired:
pass
else:
finished = True
sys.stdout.flush()
sys.stderr.flush()
if proc.returncode != 0:
logger.error(f"{cmd} exited with status = {proc.returncode}")
return proc.returncode
else:
server.serve(
model_id,
lora_adapters,
revision,
sharded,
quantize,
speculate,
dtype,
kv_cache_dtype,
trust_remote_code,
uds_path,
max_input_tokens,
)
server.serve(
model_id,
lora_adapters,
revision,
sharded,
quantize,
speculate,
dtype,
kv_cache_dtype,
trust_remote_code,
uds_path,
max_input_tokens,
)
@app.command()


@ -1,53 +0,0 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
import os
import habana_frameworks.torch as htorch
quant_config = os.getenv("QUANT_CONFIG", "")
is_quantization_enabled = quant_config != ""
if is_quantization_enabled:
os.environ.setdefault("ENABLE_EXPERIMENTAL_FLAGS", "true")
os.environ.setdefault("USE_DEFAULT_QUANT_PARAM", "true")
os.environ.setdefault("UPDATE_GRAPH_OUTPUT_MME", "false")
os.environ.setdefault("ENABLE_CALC_DYNAMIC_RANGE", "false")
os.environ.setdefault("UPDATE_MME_OUTPUT_PRECISION_FILTER", "v_proj,matmul_av")
os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE")
def patch_scoped_linear_all_reduce(model):
from deepspeed.module_inject.layers import LinearAllreduce
from optimum.habana.transformers.models.modeling_all_models import (
ScopedLinearAllReduce,
)
for name, module in model.named_children():
if type(module) is LinearAllreduce:
SL = ScopedLinearAllReduce(mod=module)
setattr(model, name, SL)
patch_scoped_linear_all_reduce(module)
def setup_quantization(model):
if is_quantization_enabled:
htorch.core.quantization._mark_params_as_const(model)
htorch.core.quantization._check_params_as_const(model)
htorch.core.hpu_initialize(model)
return model
def prepare_model_for_quantization(model):
if is_quantization_enabled:
if model.config.model_type in [
"llama",
"falcon",
"qwen2",
"starcoder2",
"gemma",
]:
patch_scoped_linear_all_reduce(model)
from neural_compressor.torch.quantization import FP8Config, convert
config = FP8Config.from_json_file(quant_config)
model = convert(model, config)
return model


@ -11,6 +11,7 @@ from .hpu import (
attention,
paged_attention,
paged_attention_mla,
set_block_mapping,
)
@ -22,6 +23,7 @@ __all__ = [
"get_kv_scales",
"paged_attention",
"paged_attention_mla",
"set_block_mapping",
"SUPPORTS_WINDOWING",
"KVCache",
"KVCompressCache",


@ -8,6 +8,7 @@ from habana_frameworks.torch.hpex.kernels import FusedSDPA
from vllm_hpu_extension.utils import ModuleFusedSDPA
import os
from text_generation_server.models.globals import BLOCK_SIZE
import math
SUPPORTS_WINDOWING = False
@ -106,6 +107,21 @@ def attention(
return attn_output
def set_block_mapping(hpu_attention_meta: HPUPagedAttentionMetadata, batch_size):
block_mapping = torch.nn.functional.one_hot(
hpu_attention_meta.block_groups, num_classes=batch_size
)
dtype = hpu_attention_meta.block_usage.dtype
device = hpu_attention_meta.block_usage.device
mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
mask = mask >= hpu_attention_meta.block_usage.unsqueeze(-1)
attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
hpu_attention_meta = hpu_attention_meta._replace(
attn_bias=attn_bias, block_mapping=block_mapping.to(dtype)
)
return hpu_attention_meta
def paged_attention(
query: torch.Tensor,
kv_cache: KVCache,
@ -176,4 +192,10 @@ def paged_attention_mla(
return output.view(batch_size, head_num, -1)
__all__ = ["SUPPORTS_WINDOWING", "attention", "paged_attention", "paged_attention_mla"]
__all__ = [
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",
"paged_attention_mla",
"set_block_mapping",
]
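
A minimal, self-contained sketch of what the new `set_block_mapping` helper computes. `Meta` is a stand-in for `HPUPagedAttentionMetadata` and `BLOCK_SIZE` is illustrative; only the field names and the tensor math come from the diff above, everything else is assumed.

```
# Stand-in for HPUPagedAttentionMetadata; the real metadata carries more fields.
import math
from collections import namedtuple

import torch

BLOCK_SIZE = 4  # illustrative; the real value comes from models.globals
Meta = namedtuple("Meta", ["block_groups", "block_usage", "attn_bias", "block_mapping"])


def set_block_mapping(meta, batch_size):
    # One-hot matrix mapping each KV-cache block to the sequence (group) that owns it.
    block_mapping = torch.nn.functional.one_hot(meta.block_groups, num_classes=batch_size)
    dtype = meta.block_usage.dtype
    device = meta.block_usage.device
    # Slots beyond the used portion of each block get a -inf bias so attention ignores them.
    mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
    mask = mask >= meta.block_usage.unsqueeze(-1)
    attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
    return meta._replace(attn_bias=attn_bias, block_mapping=block_mapping.to(dtype))


# Two sequences sharing three blocks: blocks 0-1 belong to sequence 0, block 2 to sequence 1.
meta = Meta(
    block_groups=torch.tensor([0, 0, 1]),
    block_usage=torch.tensor([4.0, 2.0, 3.0]),  # tokens actually stored in each block
    attn_bias=None,
    block_mapping=None,
)
meta = set_block_mapping(meta, batch_size=2)
print(meta.block_mapping)  # (num_blocks, batch_size) one-hot, cast to float
print(meta.attn_bias)      # 0.0 where a slot holds a token, -inf where it is padding
```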


@ -36,9 +36,7 @@ class PositionRotaryEmbedding(nn.Module):
self._sin_k_cached = None
self.scaling_factor = scaling_factor
self.dynamic_args = None
self._update_cos_sin_cache(
torch.float32, inv_freq.device, max_position_embeddings
)
self.max_position_embeddings = max_position_embeddings
def forward(
self,
@ -270,7 +268,9 @@ class PositionRotaryEmbedding(nn.Module):
self._sin_cached = torch.sin(freqs).to(dtype)
def get_cos_sin(self, position_ids: torch.Tensor):
self._update_cos_sin_cache(
torch.float32, position_ids.device, seqlen=self.max_position_embeddings
)
cos = torch.index_select(self._cos_cached, 0, position_ids)
sin = torch.index_select(self._sin_cached, 0, position_ids)
@ -298,9 +298,6 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
self._cos_k_cached = None
self._sin_k_cached = None
self.dynamic_args = None
self._update_cos_sin_cache(
torch.float32, short_inv_freq.device, max_position_embeddings
)
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
@ -354,9 +351,6 @@ class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
self._cos_k_cached = None
self._sin_k_cached = None
self.dynamic_args = None
self._update_cos_sin_cache(
torch.float32, short_inv_freq.device, max_position_embeddings
)
def _update_cos_sin_cache(self, dtype, device, seqlen):
if (
@ -598,6 +592,9 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
position_ids: torch.Tensor,
):
slen = position_ids.shape[0]
self._update_cos_sin_cache(
torch.float32, position_ids.device, seqlen=self.max_position_embeddings
)
cos = self._cos_cached[position_ids].gather(1, self._sections[:slen])
sin = self._sin_cached[position_ids].gather(1, self._sections[:slen])
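
The diff above removes the eager `_update_cos_sin_cache` call from `__init__` and instead refreshes the cache inside `get_cos_sin`, keyed on the stored `max_position_embeddings`. A reduced sketch of that lazy-cache pattern follows; this is an illustrative class under assumed parameters, not the real `PositionRotaryEmbedding`.

```
import torch


class LazyRotaryCache:
    def __init__(self, inv_freq: torch.Tensor, max_position_embeddings: int):
        self.inv_freq = inv_freq
        self.max_position_embeddings = max_position_embeddings
        self._seq_len_cached = -1
        self._cos_cached = None
        self._sin_cached = None
        # Note: no cache build here any more; it happens lazily on first use below.

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        # Rebuild only if the requested length, device, or dtype changed.
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq.to(device))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)

    def get_cos_sin(self, position_ids: torch.Tensor):
        self._update_cos_sin_cache(
            torch.float32, position_ids.device, seqlen=self.max_position_embeddings
        )
        cos = torch.index_select(self._cos_cached, 0, position_ids)
        sin = torch.index_select(self._sin_cached, 0, position_ids)
        return cos, sin


rope = LazyRotaryCache(
    inv_freq=1.0 / (10000 ** (torch.arange(0, 64, 2) / 64)),
    max_position_embeddings=2048,
)
cos, sin = rope.get_cos_sin(torch.tensor([0, 1, 2, 3]))
print(cos.shape, sin.shape)  # torch.Size([4, 32]) torch.Size([4, 32])
```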


@ -5,7 +5,6 @@ import os
from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download, HfApi
from typing import Optional
from pathlib import Path
@ -36,14 +35,10 @@ __all__ = [
"Seq2SeqLM",
"get_model_with_lora_adapters",
]
from text_generation_server.models.globals import ATTENTION
VLM_BATCH_TYPES = set()
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
FLASH_ATTENTION = False
if ATTENTION == "paged":
FLASH_ATTENTION = True
FLASH_ATTENTION = True
try:
from text_generation_server.models.flash_causal_lm import FlashCausalLM
@ -83,9 +78,6 @@ try:
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
FlashGPTNeoXForCausalLM,
)
from text_generation_server.models.pali_gemma import (
PaliGemmaBatch,
)
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
PaliGemmaForConditionalGeneration,
)
@ -156,7 +148,6 @@ if FLASH_ATTENTION:
)
VLM_BATCH_TYPES = {
PaliGemmaBatch,
FlashVlmCausalLMBatch,
FlashMllamaCausalLMBatch,
}
@ -642,6 +633,7 @@ def get_model(
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
support_chunking=False,
)
elif model_type == BAICHUAN:
return FlashCausalLM(
@ -791,6 +783,8 @@ def get_model(
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
# TODO: Fix bug in rust image_text_replacement implementation
support_chunking=False,
)
elif model_type == QWEN2_5_VL:
return FlashVlmCausalLM(
@ -806,6 +800,8 @@ def get_model(
lora_adapter_ids=lora_adapter_ids,
config_class=Qwen2_5_VLConfig,
processor_class=Qwen2_5_VLProcessor,
# TODO: Fix bug in rust image_text_replacement implementation
support_chunking=False,
)
elif model_type == QWEN3:
return FlashCausalLM(
@ -843,6 +839,7 @@ def get_model(
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
support_chunking=False,
)
elif model_type == IDEFICS2:
return FlashVlmCausalLM(
@ -887,7 +884,6 @@ def get_model(
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
batch_class=PaliGemmaBatch,
)
elif model_type == LLAVA_NEXT:
return FlashVlmCausalLM(
@ -901,72 +897,6 @@ def get_model(
trust_remote_code=trust_remote_code,
)
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
from text_generation_server.models.custom_modeling.mllama import (
MllamaForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.llava_next import (
LlavaNextForConditionalGeneration,
)
from text_generation_server.models.vlm_causal_lm import (
VlmCausalLMBatch,
)
VLM_BATCH_TYPES.add(VlmCausalLMBatch)
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
adapt_transformers_to_gaudi()
if SDP_ON_BF16 == 1:
torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
if model_type == "gpt_bigcode":
from text_generation_server.models.starcoder import StarCoder
return StarCoder(model_id=model_id, revision=revision, dtype=dtype)
if model_type == "bloom":
from text_generation_server.models.bloom import BLOOM
return BLOOM(
model_id=model_id,
revision=revision,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if model_type == "llava_next":
return VlmCausalLM(
model_class=LlavaNextForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=None,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if model_type == "mllama":
return VlmCausalLM(
model_class=MllamaForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=None,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
raise ValueError(f"Unsupported model type {model_type}")


@ -1,52 +0,0 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
import torch
from typing import Optional, Type
from transformers import PreTrainedTokenizerBase
from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
class BloomCausalLMBatch(CausalLMBatch):
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "CausalLMBatch":
batch = super().from_pb(
pb=pb,
tokenizer=tokenizer,
dtype=dtype,
device=device,
)
batch.keys_head_dim_last = False
return batch
class BLOOM(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
super(BLOOM, self).__init__(
model_id=model_id,
revision=revision,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@property
def batch_type(self) -> Type[CausalLMBatch]:
return BloomCausalLMBatch


@ -28,6 +28,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@ -415,6 +416,10 @@ class FlashCohereModel(torch.nn.Module):
seqlen: torch.Tensor,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward


@ -26,6 +26,7 @@ from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@ -678,6 +679,10 @@ class DbrxModel(torch.nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward


@ -33,6 +33,7 @@ from text_generation_server.layers.attention import (
Seqlen,
attention,
paged_attention,
set_block_mapping,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
@ -569,6 +570,10 @@ class DeepseekV2Model(torch.nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward


@ -34,6 +34,7 @@ from text_generation_server.layers.attention import (
Seqlen,
attention,
paged_attention_mla,
set_block_mapping,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
@ -645,6 +646,10 @@ class DeepseekV3Model(torch.nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward


@ -28,6 +28,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@ -466,6 +467,10 @@ class FlashGemma2Model(torch.nn.Module):
adapter_data: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward


@ -28,6 +28,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@ -388,6 +389,10 @@ class FlashGemmaModel(torch.nn.Module):
adapter_data: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward


@ -27,6 +27,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@ -383,6 +384,10 @@ class FlashGPT2Model(torch.nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds
residual = None


@ -28,6 +28,7 @@ from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@ -324,6 +325,10 @@ class FlashGPTJModel(torch.nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.wte(input_ids)
# Get rotary cos and sin for this forward


@ -43,12 +43,12 @@ from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.attention import (
KVCache,
paged_attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaAttention,
LlamaMLP,
)
@ -444,7 +444,7 @@ class Llama4TextDecoderLayer(nn.Module):
if self.is_moe_layer: # the 128E model interleaves dense / sparse
self.feed_forward = Llama4TextMoe(f"{prefix}.feed_forward", config, weights)
else:
self.feed_forward = LlamaMLP(f"{prefix}.feed_forward", config, weights)
self.feed_forward = Llama4TextMLP(f"{prefix}.feed_forward", config, weights)
self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm",
@ -549,6 +549,10 @@ class Llama4TextModel(nn.Module):
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds
bs = seqlen.input_lengths.shape[0]
@ -1352,55 +1356,36 @@ class Llama4ForConditionalGeneration(nn.Module):
hidden_state = self.vision_model(pixel_values)
return hidden_state
def forward(
def get_vision_embeds(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor,
pixel_attention_mask: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
image_features = self.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=self.config.vision_config.vision_feature_layer,
vision_feature_select_strategy=self.config.vision_config.vision_feature_select_strategy,
image_sizes=image_sizes,
)
vision_flat = image_features.view(-1, image_features.size(-1))
image_features = self.multi_modal_projector(vision_flat)
return image_features
def get_inputs_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
pixel_values: torch.FloatTensor = None,
pixel_attention_mask=None,
position_ids: Optional[torch.LongTensor] = None,
cu_seqlen_prefill: Optional[torch.Tensor] = None,
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = None,
slots: torch.Tensor = None,
seqlen: Seqlen = None,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
image_sizes: torch.Tensor = None,
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
**lm_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
def _get_padding_mask(input_ids, pad_token_id=0):
return (input_ids != pad_token_id).long()
attention_mask = _get_padding_mask(input_ids)
attention_mask = attention_mask.view(seqlen.input_lengths.shape[0], -1)
image_sizes: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.text_model.model.embed_tokens(input_ids)
vision_feature_layer = (
vision_feature_layer
if vision_feature_layer is not None
else self.config.vision_config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_config.vision_feature_select_strategy
)
if pixel_values is not None:
image_features = self.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
image_sizes=image_sizes,
)
if vision_embeds is not None:
# When we generate, we don't want to replace the potential image_token_id that we generated by images
# that simply don't exist
original_inputs_embeds_shape = inputs_embeds.shape
vision_flat = image_features.view(-1, image_features.size(-1))
projected_vision_flat = self.multi_modal_projector(vision_flat)
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
-1
)
@ -1410,19 +1395,33 @@ class Llama4ForConditionalGeneration(nn.Module):
final_mask_1d = final_mask[..., 0].reshape(-1)
num_tokens_to_fill = final_mask_1d.sum()
if num_tokens_to_fill != projected_vision_flat.size(0):
if num_tokens_to_fill != vision_embeds.size(0):
raise ValueError(
f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
f"but multi_modal_projector returned {projected_vision_flat.size(0)}"
f"but multi_modal_projector returned {vision_embeds.size(0)}"
)
expanded_mask = final_mask_1d.unsqueeze(-1).expand(
-1, inputs_embeds.size(-1)
)
inputs_embeds = inputs_embeds.masked_scatter(
expanded_mask, projected_vision_flat
)
inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, vision_embeds)
inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)
return inputs_embeds
def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
cu_seqlen_prefill: Optional[torch.Tensor] = None,
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = None,
slots: torch.Tensor = None,
seqlen: Seqlen = None,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
lm_head_indices: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
**lm_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
logits, speculative_logits = self.text_model(
inputs_embeds,

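The refactor above splits Llama4 multimodal inference into three steps: `get_vision_embeds`, `get_inputs_embeds`, and a `forward` that only sees the merged embeddings. A hedged sketch of the resulting call order; the driver function and how the batch tensors are produced are assumptions, only the method and keyword names mirror the diff.

```
def multimodal_forward(model, input_ids, pixel_values, image_sizes, position_ids,
                       cu_seqlen_prefill, kv_cache, slots, seqlen, hpu_attention_meta,
                       lm_head_indices):
    # 1) Encode and project the images once per request.
    vision_embeds = None
    if pixel_values is not None:
        vision_embeds = model.get_vision_embeds(
            pixel_values=pixel_values, image_sizes=image_sizes
        )
    # 2) Merge text token embeddings with the projected vision embeddings.
    inputs_embeds = model.get_inputs_embeds(input_ids, vision_embeds=vision_embeds)
    # 3) Run the text decoder on the merged embeddings only.
    return model.forward(
        inputs_embeds,
        position_ids=position_ids,
        cu_seqlen_prefill=cu_seqlen_prefill,
        kv_cache=kv_cache,
        slots=slots,
        seqlen=seqlen,
        hpu_attention_meta=hpu_attention_meta,
        lm_head_indices=lm_head_indices,
    )
```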

@ -35,6 +35,7 @@ from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoE
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@ -143,12 +144,14 @@ class FlashLlamaAttention(torch.nn.Module):
config.num_key_value_heads = getattr(
config, "num_key_value_heads", config.num_attention_heads
)
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
base=config.rope_theta,
device=weights.device,
)
if config.model_type != "llama4_text":
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
base=config.rope_theta,
device=weights.device,
)
# `config.attention_multiplier` is used in Granite
self.softmax_scale = getattr(
@ -547,6 +550,11 @@ class FlashLlamaModel(torch.nn.Module):
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
cross_attention_states=None,
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward


@ -163,9 +163,114 @@ class FlashLlavaNextForConditionalGeneration(nn.Module):
)
return inputs_embeds
def forward(
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
pixel_attention_mask: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
# 1. Extract the input embeddings
# 2. Merge text and images
num_images, num_patches, channels, height, width = pixel_values.shape
pixel_values = pixel_values.view(
num_images * num_patches, channels, height, width
)
image_features = self.vision_tower(pixel_values)
# selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
# Already done within the clip model
selected_image_feature = image_features.last_hidden_state
if self.config.vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif self.config.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise RuntimeError(
f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
)
image_features = self.multi_modal_projector(selected_image_feature)
# split up image_features for each of the individual images
# hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
# if we assume each image has 5 image features (base image + 4 patches)
split_sizes = [num_patches] * num_images
image_features = torch.split(image_features, split_sizes, dim=0)
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
height = width = (
self.config.vision_config.image_size // self.config.vision_config.patch_size
)
new_image_features = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
if height * width != base_image_feature.shape[0]:
raise ValueError(
"The number of patches is not consistent with the image size."
)
# Dimensions are intentionally swapped to be bug-compatible with
# upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
image_feature = image_feature.view(
num_patch_height, num_patch_width, height, width, -1
)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature, image_sizes[image_idx])
image_feature = torch.cat(
(
image_feature,
self.image_newline[:, None, None].expand(
*image_feature.shape[:-1], 1
),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
else:
image_feature = image_feature[0]
image_feature = torch.cat(
(image_feature, self.image_newline[None]), dim=0
)
new_image_features.append(image_feature)
image_features = torch.stack(new_image_features, dim=0)
return image_features.view(-1, image_features.shape[-1])
def get_inputs_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
pixel_values: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if vision_embeds is not None:
# When we generate, we don't want to replace the potential image_token_id that we generated by images
# that simply don't exist
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, vision_embeds
)
return inputs_embeds
def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@ -173,101 +278,9 @@ class FlashLlavaNextForConditionalGeneration(nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
# Unused for this model
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None and len(pixel_values) > 0:
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
# 1. Extract the input embeddings
# 2. Merge text and images
num_images, num_patches, channels, height, width = pixel_values.shape
pixel_values = pixel_values.view(
num_images * num_patches, channels, height, width
)
image_features = self.vision_tower(pixel_values)
# selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
# Already done within the clip model
selected_image_feature = image_features.last_hidden_state
if self.config.vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif self.config.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise RuntimeError(
f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
)
image_features = self.multi_modal_projector(selected_image_feature)
# split up image_features for each of the individual images
# hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
# if we assume each image has 5 image features (base image + 4 patches)
split_sizes = [num_patches] * num_images
image_features = torch.split(image_features, split_sizes, dim=0)
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
height = width = (
self.config.vision_config.image_size
// self.config.vision_config.patch_size
)
new_image_features = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
if height * width != base_image_feature.shape[0]:
raise ValueError(
"The number of patches is not consistent with the image size."
)
# Dimensions are intentionally swapped to be bug-compatible with
# upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
image_feature = image_feature.view(
num_patch_height, num_patch_width, height, width, -1
)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature, image_sizes[image_idx])
image_feature = torch.cat(
(
image_feature,
self.image_newline[:, None, None].expand(
*image_feature.shape[:-1], 1
),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat(
(base_image_feature, image_feature), dim=0
)
else:
image_feature = image_feature[0]
image_feature = torch.cat(
(image_feature, self.image_newline[None]), dim=0
)
new_image_features.append(image_feature)
image_features = torch.stack(new_image_features, dim=0)
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_features
)
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,

View File

@ -30,6 +30,7 @@ from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@ -396,6 +397,10 @@ class MistralModel(torch.nn.Module):
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
adapter_data: Optional[torch.Tensor] = None,
):
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward
# Avoid to index in each layer

View File

@ -37,6 +37,7 @@ from text_generation_server.layers.attention import (
Seqlen,
attention,
paged_attention,
set_block_mapping,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
@ -446,6 +447,10 @@ class MixtralModel(torch.nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
@ -505,7 +510,6 @@ class FlashMixtralForCausalLM(torch.nn.Module):
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.model(
input_ids,
position_ids,

View File

@ -38,6 +38,7 @@ from text_generation_server.models.custom_modeling.flash_llama_modeling import (
)
from habana_frameworks.torch.hpex.kernels import FusedSDPA
from vllm_hpu_extension.utils import ModuleFusedSDPA
import habana_frameworks.torch as htorch
def _prepare_aspect_ratio_attention_mask(
@ -236,10 +237,19 @@ class MllamaVisionSdpaAttention(nn.Module):
key = key.transpose(1, 2)
value = value.transpose(1, 2)
attn_output = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask
fsdpa_op = ModuleFusedSDPA(FusedSDPA)
attn_output = fsdpa_op(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
scale=None,
softmax_mode="None",
recompute_mode=None,
valid_sequence_lengths=None,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_seq_len, -1)
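For reference, the fused HPU call above keeps the semantics of the standard SDPA it replaces: no dropout, non-causal, default scale, boolean mask where True means "attend". A minimal CPU sketch of that reference computation with toy shapes (the HPU-only softmax_mode/recompute knobs have no CPU analogue):

import torch
import torch.nn.functional as F

# Toy shapes standing in for the vision attention tensors.
batch_size, num_heads, q_seq_len, kv_seq_len, head_dim = 1, 2, 4, 4, 8
query = torch.randn(batch_size, num_heads, q_seq_len, head_dim)
key = torch.randn(batch_size, num_heads, kv_seq_len, head_dim)
value = torch.randn(batch_size, num_heads, kv_seq_len, head_dim)
attention_mask = torch.ones(batch_size, 1, q_seq_len, kv_seq_len, dtype=torch.bool)

attn_output = F.scaled_dot_product_attention(
    query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
print(attn_output.shape)  # torch.Size([1, 2, 4, 8])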
@ -320,6 +330,9 @@ class MllamaVisionEncoder(nn.Module):
attention_mask: Optional[torch.Tensor] = None,
):
encoder_states = [hidden_states]
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for encoder_layer in self.layers:
layer_outputs = encoder_layer(
hidden_states,
@ -328,6 +341,8 @@ class MllamaVisionEncoder(nn.Module):
hidden_states = layer_outputs
encoder_states.append(hidden_states)
if lazy_mode:
htorch.core.mark_step()
return hidden_states, encoder_states
@ -699,8 +714,6 @@ class MllamaTextCrossAttention(nn.Module):
# key_states = key_states.repeat(1, self.num_key_value_groups, 1)
# value_states = value_states.repeat(1, self.num_key_value_groups, 1)
causal = False
# logger.info(
# f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}"
# )
@ -715,7 +728,7 @@ class MllamaTextCrossAttention(nn.Module):
value_states,
attn_mask=None,
dropout_p=0.0,
is_causal=causal,
is_causal=False,
scale=None,
softmax_mode="None",
recompute_mode=None,

View File

@ -29,6 +29,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@ -354,6 +355,10 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_in(input_ids)
# Get rotary cos and sin for this forward

View File

@ -62,10 +62,40 @@ class PaliGemmaForConditionalGeneration(nn.Module):
self.pad_token_id = (
config.pad_token_id if config.pad_token_id is not None else -1
)
self.dtype = weights.dtype
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
pixel_attention_mask: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
pixel_values = pixel_values.to(dtype=self.dtype)
image_outputs = self.vision_tower(pixel_values)
last_hidden_state = self.post_vision_tower_layernorm(
image_outputs.last_hidden_state
)
image_features = self.multi_modal_projector(last_hidden_state)
image_features = image_features.view(-1, image_features.shape[-1])
return image_features
def get_inputs_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if vision_embeds is not None:
mask = input_ids == self.config.image_token_index
inputs_embeds[mask] = vision_embeds
return inputs_embeds
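A toy, self-contained sketch of the merge performed by get_inputs_embeds above (the token id and shapes are made up): token-embedding rows at image-token positions are overwritten, one for one, by the projected vision embeddings.

import torch

IMAGE_TOKEN_INDEX = 9  # hypothetical stand-in for config.image_token_index
hidden_size = 4

input_ids = torch.tensor([1, 9, 9, 2])       # two image placeholder tokens
inputs_embeds = torch.zeros(4, hidden_size)  # stand-in for text_model.embed_tokens(input_ids)
vision_embeds = torch.ones(2, hidden_size)   # one projected row per image placeholder

mask = input_ids == IMAGE_TOKEN_INDEX
inputs_embeds[mask] = vision_embeds  # boolean-mask scatter: the i-th row fills the i-th True position
print(inputs_embeds)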
def forward(
self,
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@ -73,32 +103,13 @@ class PaliGemmaForConditionalGeneration(nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
# Unused here
pixel_attention_mask: Optional[torch.BoolTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.text_model.embed_tokens(input_ids)
        # TODO: This is odd, but apparently PaliGemma position ids start at 1.

if cu_seqlen_prefill is not None:
position_ids += 1
if pixel_values is not None:
pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
image_outputs = self.vision_tower(pixel_values)
last_hidden_state = self.post_vision_tower_layernorm(
image_outputs.last_hidden_state
)
image_features = self.multi_modal_projector(last_hidden_state)
# mask where image or padding tokens
mask = input_ids == self.config.image_token_index
# insert image features into input embeddings
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,

View File

@ -9,6 +9,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@ -347,6 +348,10 @@ class FlashPhiModel(torch.nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward

View File

@ -8,6 +8,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@ -288,6 +289,10 @@ class Qwen2Model(torch.nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
@ -359,7 +364,6 @@ class Qwen2ForCausalLM(torch.nn.Module):
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(

View File

@ -18,6 +18,7 @@ import habana_frameworks.torch as htorch
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@ -266,7 +267,10 @@ class Qwen3Model(nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
@ -334,7 +338,6 @@ class Qwen3ForCausalLM(nn.Module):
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
inputs_embeds = self.embed_tokens(input_ids)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
hidden_states = self.model(

View File

@ -18,6 +18,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.attention import (
attention,
paged_attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@ -628,6 +629,10 @@ class FlashRWModel(FlashRWPreTrainedModel):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.word_embeddings(input_ids)
# Get rotary cos and sin for this forward

View File

@ -8,6 +8,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@ -437,6 +438,10 @@ class FlashSantacoderModel(nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
if self.process_group.size() > 1:

View File

@ -29,6 +29,7 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@ -511,6 +512,10 @@ class Starcoder2Model(torch.nn.Module):
adapter_data,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
@ -584,7 +589,6 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.model(
input_ids,
position_ids,

View File

@ -734,9 +734,107 @@ class Idefics2ForConditionalGeneration(nn.Module):
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
return inputs_embeds
def forward(
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
pixel_attention_mask: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
assert pixel_values is not None
batch_size, num_images, num_channels, height, width = pixel_values.shape
all_states = []
all_pixel_values = pixel_values
all_pixel_mask = pixel_attention_mask
for i in range(batch_size):
pixel_values = all_pixel_values.to(dtype=self.dtype) # fp16 compatibility
pixel_values = pixel_values[i : i + 1]
pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
# Remove padding images - padding images are full 0.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(
dim=(-1, -2, -3)
) != nb_values_per_image
pixel_values = pixel_values[real_images_inds].contiguous()
# Handle the vision attention mask
if pixel_attention_mask is None:
pixel_attention_mask = torch.ones(
size=(
pixel_values.size(0),
pixel_values.size(2),
pixel_values.size(3),
),
dtype=torch.bool,
device=pixel_values.device,
)
else:
                # Remove padding images from the mask
pixel_attention_mask = all_pixel_mask[i : i + 1]
pixel_attention_mask = pixel_attention_mask.view(
1 * num_images, *pixel_attention_mask.shape[2:]
)
pixel_attention_mask = pixel_attention_mask[
real_images_inds
].contiguous()
patch_size = self.config.vision_config.patch_size
"""
patches_subgrid = pixel_attention_mask.unfold(
dimension=1, size=patch_size, step=patch_size
)
patches_subgrid = patches_subgrid.unfold(
dimension=2, size=patch_size, step=patch_size
)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
"""
            # HPU does not support unfold
conv_kernel = torch.ones(
[1, 1, patch_size, patch_size],
dtype=pixel_values.dtype,
device=pixel_values.device,
)
patches_subgrid = torch.nn.functional.conv2d(
pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype),
conv_kernel,
stride=patch_size,
).squeeze(1)
patch_attention_mask = torch.gt(patches_subgrid, 0)
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
)
# Modality projection & resampling
image_hidden_states = self.connector(
image_hidden_states,
attention_mask=patch_attention_mask.view(pixel_values.size(0), -1),
)
all_states.append(image_hidden_states)
image_hidden_states = torch.stack(all_states, dim=0)
return image_hidden_states.view(-1, image_hidden_states.shape[-1])
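Since Tensor.unfold is not supported on HPU (hence the commented-out reference block above), the patch attention mask is computed with a ones-kernel conv2d instead. A standalone CPU sketch of the equivalence, with toy shapes; note the diff also switches the criterion from a fully valid tile (torch.eq against patch_size * patch_size) to any valid pixel (torch.gt(..., 0)), matching the unfold-based reference.

import torch
import torch.nn.functional as F

def patch_mask_via_conv(pixel_attention_mask: torch.Tensor, patch_size: int) -> torch.Tensor:
    # Count valid pixels per (patch_size x patch_size) tile with a ones kernel,
    # then keep every tile that contains at least one valid pixel.
    kernel = torch.ones(1, 1, patch_size, patch_size, dtype=torch.float32)
    counts = F.conv2d(
        pixel_attention_mask.unsqueeze(1).to(torch.float32),  # (N, 1, H, W)
        kernel,
        stride=patch_size,
    ).squeeze(1)  # (N, H // patch_size, W // patch_size)
    return counts > 0

def patch_mask_via_unfold(pixel_attention_mask: torch.Tensor, patch_size: int) -> torch.Tensor:
    # The unfold-based formulation from the commented-out reference code.
    subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
    subgrid = subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
    return (subgrid.sum(dim=(-1, -2)) > 0).bool()

mask = torch.zeros(2, 8, 8, dtype=torch.bool)
mask[0, :5, :5] = True  # first image partially valid, second fully padded
assert torch.equal(patch_mask_via_conv(mask, 4), patch_mask_via_unfold(mask, 4))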
def get_inputs_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if vision_embeds is not None:
            # When we generate, we don't want to replace a potential image_token_id that we
            # generated with images that simply don't exist
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, vision_embeds
)
return inputs_embeds
def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@ -744,98 +842,9 @@ class Idefics2ForConditionalGeneration(nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
pixel_attention_mask: Optional[torch.BoolTensor] = None,
# Unused here
image_sizes: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None:
batch_size, num_images, num_channels, height, width = pixel_values.shape
all_states = []
all_pixel_values = pixel_values
all_pixel_mask = pixel_attention_mask
for i in range(batch_size):
pixel_values = all_pixel_values.to(
dtype=self.dtype
) # fp16 compatibility
pixel_values = pixel_values[i : i + 1]
pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
# Remove padding images - padding images are full 0.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(
dim=(-1, -2, -3)
) != nb_values_per_image
pixel_values = pixel_values[real_images_inds].contiguous()
# Handle the vision attention mask
if pixel_attention_mask is None:
pixel_attention_mask = torch.ones(
size=(
pixel_values.size(0),
pixel_values.size(2),
pixel_values.size(3),
),
dtype=torch.bool,
device=pixel_values.device,
)
else:
                    # Remove padding images from the mask
pixel_attention_mask = all_pixel_mask[i : i + 1]
pixel_attention_mask = pixel_attention_mask.view(
1 * num_images, *pixel_attention_mask.shape[2:]
)
pixel_attention_mask = pixel_attention_mask[
real_images_inds
].contiguous()
patch_size = self.config.vision_config.patch_size
"""
patches_subgrid = pixel_attention_mask.unfold(
dimension=1, size=patch_size, step=patch_size
)
patches_subgrid = patches_subgrid.unfold(
dimension=2, size=patch_size, step=patch_size
)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
"""
                # HPU does not support unfold
conv_kernel = torch.ones(
[1, 1, patch_size, patch_size],
dtype=pixel_values.dtype,
device=pixel_values.device,
)
patches_subgrid = torch.nn.functional.conv2d(
pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype),
conv_kernel,
stride=patch_size,
).squeeze(1)
patch_attention_mask = torch.eq(
patches_subgrid, (patch_size * patch_size)
)
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
)
# Modality projection & resampling
image_hidden_states = self.connector(
image_hidden_states,
attention_mask=patch_attention_mask.view(pixel_values.size(0), -1),
)
all_states.append(image_hidden_states)
image_hidden_states = torch.stack(all_states, dim=0)
            # When we generate, we don't want to replace a potential image_token_id that we
            # generated with images that simply don't exist
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_hidden_states
)
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,

View File

@ -477,9 +477,107 @@ class Idefics3ForConditionalGeneration(nn.Module):
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
return inputs_embeds
def forward(
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
pixel_attention_mask: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
batch_size, num_images, num_channels, height, width = pixel_values.shape
all_states = []
all_pixel_values = pixel_values
all_pixel_mask = pixel_attention_mask
for i in range(batch_size):
pixel_values = all_pixel_values.to(dtype=self.dtype) # fp16 compatibility
pixel_values = pixel_values[i : i + 1]
pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
# Remove padding images - padding images are full 0.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(
dim=(-1, -2, -3)
) != nb_values_per_image
pixel_values = pixel_values[real_images_inds].contiguous()
# Handle the vision attention mask
if pixel_attention_mask is None:
pixel_attention_mask = torch.ones(
size=(
pixel_values.size(0),
pixel_values.size(2),
pixel_values.size(3),
),
dtype=torch.bool,
device=pixel_values.device,
)
else:
                # Remove padding images from the mask
pixel_attention_mask = all_pixel_mask[i : i + 1]
pixel_attention_mask = pixel_attention_mask.view(
1 * num_images, *pixel_attention_mask.shape[2:]
)
pixel_attention_mask = pixel_attention_mask[
real_images_inds
].contiguous()
patch_size = self.config.vision_config.patch_size
"""
patches_subgrid = pixel_attention_mask.unfold(
dimension=1, size=patch_size, step=patch_size
)
patches_subgrid = patches_subgrid.unfold(
dimension=2, size=patch_size, step=patch_size
)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
"""
            # HPU does not support unfold
conv_kernel = torch.ones(
[1, 1, patch_size, patch_size],
dtype=pixel_values.dtype,
device=pixel_values.device,
)
patches_subgrid = torch.nn.functional.conv2d(
pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype),
conv_kernel,
stride=patch_size,
).squeeze(1)
patch_attention_mask = torch.gt(patches_subgrid, 0)
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
)
# Modality projection & resampling
image_hidden_states = self.connector(
image_hidden_states,
)
all_states.append(image_hidden_states)
image_hidden_states = torch.stack(all_states, dim=0)
return image_hidden_states.view(-1, image_hidden_states.shape[-1])
def get_inputs_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if vision_embeds is not None:
            # When we generate, we don't want to replace a potential image_token_id that we
            # generated with images that simply don't exist
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, vision_embeds
)
return inputs_embeds
def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@ -487,99 +585,10 @@ class Idefics3ForConditionalGeneration(nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
pixel_attention_mask: Optional[torch.BoolTensor] = None,
# Unused here
image_sizes: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None:
batch_size, num_images, num_channels, height, width = pixel_values.shape
all_states = []
all_pixel_values = pixel_values
all_pixel_mask = pixel_attention_mask
for i in range(batch_size):
pixel_values = all_pixel_values.to(
dtype=self.dtype
) # fp16 compatibility
pixel_values = pixel_values[i : i + 1]
pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
# Remove padding images - padding images are full 0.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(
dim=(-1, -2, -3)
) != nb_values_per_image
pixel_values = pixel_values[real_images_inds].contiguous()
# Handle the vision attention mask
if pixel_attention_mask is None:
pixel_attention_mask = torch.ones(
size=(
pixel_values.size(0),
pixel_values.size(2),
pixel_values.size(3),
),
dtype=torch.bool,
device=pixel_values.device,
)
else:
                    # Remove padding images from the mask
pixel_attention_mask = all_pixel_mask[i : i + 1]
pixel_attention_mask = pixel_attention_mask.view(
1 * num_images, *pixel_attention_mask.shape[2:]
)
pixel_attention_mask = pixel_attention_mask[
real_images_inds
].contiguous()
patch_size = self.config.vision_config.patch_size
"""
patches_subgrid = pixel_attention_mask.unfold(
dimension=1, size=patch_size, step=patch_size
)
patches_subgrid = patches_subgrid.unfold(
dimension=2, size=patch_size, step=patch_size
)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
"""
                # HPU does not support unfold
conv_kernel = torch.ones(
[1, 1, patch_size, patch_size],
dtype=pixel_values.dtype,
device=pixel_values.device,
)
patches_subgrid = torch.nn.functional.conv2d(
pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype),
conv_kernel,
stride=patch_size,
).squeeze(1)
patch_attention_mask = torch.eq(
patches_subgrid, (patch_size * patch_size)
)
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
)
# Modality projection & resampling
image_hidden_states = self.connector(
image_hidden_states,
)
all_states.append(image_hidden_states)
image_hidden_states = torch.stack(all_states, dim=0)
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_hidden_states
)
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,

View File

@ -1,467 +0,0 @@
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Llava-NeXT model."""
from typing import List, Optional, Union
import torch
import torch.utils.checkpoint
import numpy as np
from loguru import logger
from transformers.models.llava_next.modeling_llava_next import (
unpad_image,
)
from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration
from transformers.image_processing_utils import select_best_resolution
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
"""
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
Args:
image_size (`tuple`):
The size of the input image in the format (width, height).
grid_pinpoints (`List`):
A list containing possible resolutions. Each item in the list should be a tuple or list
of the form `(height, width)`.
patch_size (`int`):
The size of each image patch.
Returns:
tuple: The shape of the image patch grid in the format (width, height).
"""
if not isinstance(grid_pinpoints, list):
raise ValueError("grid_pinpoints should be a list of tuples or lists")
height, width = select_best_resolution(image_size, grid_pinpoints)
return height // patch_size, width // patch_size
# Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L79
def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
"""
Calculate the number of patches after the preprocessing for images of any resolution.
Args:
image_size (`torch.LongTensor` or `np.ndarray` or `Tuple[int, int]`):
            The size of the input image in the format (height, width).
grid_pinpoints (`List`):
A list containing possible resolutions. Each item in the list should be a tuple or list
of the form `(height, width)`.
patch_size (`int`):
The size of each image patch.
Returns:
int: the number of patches
"""
if not isinstance(grid_pinpoints, list):
raise TypeError("grid_pinpoints should be a list of tuples or lists")
    # ! VERY IMPORTANT: if image_size is a tensor, it must be converted into a tuple, otherwise the calculation will be wrong
if not isinstance(image_size, (list, tuple)):
if not isinstance(image_size, (torch.Tensor, np.ndarray)):
raise TypeError(
f"image_size invalid type {type(image_size)} with value {image_size}"
)
image_size = image_size.tolist()
best_resolution = select_best_resolution(image_size, grid_pinpoints)
height, width = best_resolution
num_patches = 0
    # consider changing this to ceil(height / patch_size) * ceil(width / patch_size) + 1
for i in range(0, height, patch_size):
for j in range(0, width, patch_size):
num_patches += 1
# add the base patch
num_patches += 1
return num_patches
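A quick check of the closed form suggested in the comment above: len(range(0, n, p)) equals ceil(n / p), so the double loop plus the base patch is the same as ceil(height / patch_size) * ceil(width / patch_size) + 1.

import math

def num_patches_loop(height: int, width: int, patch_size: int) -> int:
    # Mirrors the double loop above, plus the extra base patch.
    num_patches = 0
    for _ in range(0, height, patch_size):
        for _ in range(0, width, patch_size):
            num_patches += 1
    return num_patches + 1

def num_patches_closed_form(height: int, width: int, patch_size: int) -> int:
    return math.ceil(height / patch_size) * math.ceil(width / patch_size) + 1

for h, w in [(336, 672), (1008, 336), (500, 1000)]:
    assert num_patches_loop(h, w, 336) == num_patches_closed_form(h, w, 336)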
class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[int] = None,
vision_feature_select_strategy: Optional[str] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
token_idx: Optional[torch.Tensor] = None,
use_flash_attention: Optional[bool] = True,
flash_attention_recompute: Optional[bool] = True,
):
if token_idx is not None:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
token_idx=token_idx,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
)
logits = outputs[0]
if not return_dict:
output = (logits,) + outputs[1:]
return output
return outputs
# Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L411
def pack_image_features(
self,
image_features,
image_sizes,
vision_feature_select_strategy,
image_newline=None,
):
"""
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
Args:
image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
                List of image feature tensors, each containing all the visual features of all patches.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
                Actual image size of each image (H, W).
vision_feature_select_strategy (`str`)
The feature selection strategy used to select the vision feature from the vision backbone.
image_newline (`torch.Tensor` of shape `(embed_dim)`)
New line embedding vector.
Returns:
image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
feature_lens (`List[int]`)
token length of each image in image_features
"""
new_image_features = []
feature_lens = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = (
self.config.vision_config.image_size
// self.config.vision_config.patch_size
)
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
if (
np.prod(image_feature.shape)
% (num_patch_height * num_patch_width * height * width)
!= 0
and vision_feature_select_strategy == "default"
):
logger.warning_once(
"Image feature shape does not line up with the provided patch size. "
"You may be using the `default` vision_feature_select_strategy with a"
" visual encoder that does not have CLS."
)
image_feature = image_feature.view(
num_patch_height, num_patch_width, height, width, -1
)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature, image_sizes[image_idx])
if image_newline is not None:
image_feature = torch.cat(
(
image_feature,
image_newline[:, None, None]
.expand(*image_feature.shape[:-1], 1)
.to(image_feature.device, image_feature.dtype),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
else:
image_feature = image_feature[0]
if image_newline is not None:
image_feature = torch.cat(
(image_feature, image_newline[None].to(image_feature)), dim=0
)
new_image_features.append(image_feature)
feature_lens.append(image_feature.size(0))
image_features = torch.cat(new_image_features, dim=0)
feature_lens = torch.tensor(
feature_lens, dtype=torch.long, device=image_features.device
)
return image_features, feature_lens
# Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L479
def get_image_features(
self,
pixel_values: torch.FloatTensor,
image_sizes: torch.Tensor,
vision_feature_layer: Union[int, List[int]],
vision_feature_select_strategy: str,
):
"""
Obtains image last hidden states from the vision tower and apply multimodal projection.
Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, channels, height, width)`)
The tensors corresponding to the input images.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
                Actual image size of each image (H, W).
vision_feature_layer (`Union[int, List[int]]`):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`
Returns:
            image_features (List[`torch.Tensor`]): List of image feature tensors, each containing all the visual features of all patches,
                and of shape `(num_patches, image_length, embed_dim)`.
"""
# ! infer image_num_patches from image_sizes
image_num_patches = [
image_size_to_num_patches(
image_size=imsize,
grid_pinpoints=self.config.image_grid_pinpoints,
patch_size=self.config.vision_config.image_size,
)
for imsize in image_sizes
]
if pixel_values.dim() == 5:
# stacked if input is (batch_size, num_patches, num_channels, height, width)
_pixel_values_list = [
pix_val[:num_patch]
for pix_val, num_patch in zip(pixel_values, image_num_patches)
]
pixel_values = torch.cat(_pixel_values_list, dim=0)
elif pixel_values.dim() != 4:
# otherwise has to be stacked from list of (num_patches, num_channels, height, width)
raise ValueError(
f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions"
)
image_features = self.vision_tower(pixel_values, output_hidden_states=True)
# If we have one vision feature layer, return the corresponding hidden states,
# otherwise, select the hidden states of each feature layer and concatenate them
if isinstance(vision_feature_layer, int):
selected_image_feature = image_features.hidden_states[vision_feature_layer]
else:
hs_pool = [
image_features.hidden_states[layer_idx]
for layer_idx in vision_feature_layer
]
selected_image_feature = torch.cat(hs_pool, dim=-1)
if vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
image_features = self.multi_modal_projector(selected_image_feature)
image_features = torch.split(image_features, image_num_patches, dim=0)
return image_features
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
inputs_embeds=None,
pixel_values=None,
image_sizes=None,
attention_mask=None,
**kwargs,
):
"""
Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L635
The only differences are:
- add new args token_idx
- add the process of merging images into inputs_embeds
"""
token_idx = kwargs.get("token_idx", None)
if token_idx is None:
return super().prepare_inputs_for_generation(
input_ids=input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
pixel_values=pixel_values,
image_sizes=image_sizes,
attention_mask=attention_mask,
**kwargs,
)
else:
use_flash_attention = kwargs.get("use_flash_attention", True)
flash_attention_recompute = kwargs.get("flash_attention_recompute", True)
position_ids = kwargs.get("position_ids", None)
labels = kwargs.get("labels", None)
if (
past_key_values is None
and pixel_values is not None
and input_ids.shape[1] != 1
):
vision_feature_select_strategy = kwargs.get(
"vision_feature_select_strategy", None
)
vision_feature_layer = kwargs.get("vision_feature_layer", None)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
vision_feature_layer = (
vision_feature_layer
if vision_feature_layer is not None
else self.config.vision_feature_layer
)
# 1. Extract the input embeddings
inputs_embeds = self.get_input_embeddings()(input_ids)
# 2. Merge text and images
image_features = self.get_image_features(
pixel_values,
image_sizes,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
)
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
image_features, feature_lens = self.pack_image_features(
image_features,
image_sizes,
vision_feature_select_strategy=vision_feature_select_strategy,
image_newline=self.image_newline,
)
special_image_mask = (
input_ids == self.config.image_token_index
).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
inputs_embeds.device
)
if inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_index).sum()
n_image_features = image_features.shape[0]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(
inputs_embeds.device, inputs_embeds.dtype
)
inputs_embeds = inputs_embeds.masked_scatter(
special_image_mask, image_features
)
# In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
# generation with cache
elif past_key_values is not None:
seq_len = input_ids.shape[1]
pad_len = seq_len - token_idx.item()
input_ids = torch.index_select(input_ids, 1, token_idx - 1)
# Retrieve the first layer to inspect the logits and mask out the hidden states
# that are set to 0
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
batch_index, non_attended_tokens = torch.where(
first_layer_past_key_value.float().sum(-2) == 0
)
# Get the target length
past_length = first_layer_past_key_value.shape[-1]
extended_attention_mask = torch.ones(
(attention_mask.shape[0], past_length),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
# Filter out only the tokens that can be un-attended, this can happen
# if one uses Llava + Fused modules where the cache on the
# first iteration is already big enough, or if one passes custom cache
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]
# Zero-out the places where we don't need to attend
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
attention_mask = extended_attention_mask
attention_mask[:, -pad_len:] = 0
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
if token_idx is not None:
position_ids = (
torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
)
else:
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"token_idx": token_idx,
"labels": labels,
"use_flash_attention": use_flash_attention,
"flash_attention_recompute": flash_attention_recompute,
}
)
return model_inputs

View File

@ -1,292 +0,0 @@
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Mllama model."""
from typing import Optional, Tuple, List, Union
import torch
import torch.utils.checkpoint
from optimum.habana.transformers.models import GaudiMllamaForConditionalGeneration
from optimum.habana.transformers.models.mllama.modeling_mllama import (
_prepare_cross_attention_mask,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
class MllamaForConditionalGeneration(GaudiMllamaForConditionalGeneration):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
aspect_ratio_mask: Optional[torch.Tensor] = None,
aspect_ratio_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_mask: Optional[torch.Tensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
token_idx: Optional[torch.Tensor] = None,
use_flash_attention: Optional[bool] = True,
flash_attention_recompute: Optional[bool] = True,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
"""
Copied from MllamaForConditionalGeneration::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2077
The only differences are:
- add token_idx input
- add use_flash_attention and flash_attention_recompute
"""
full_text_row_masked_out_mask = kwargs.get(
"full_text_row_masked_out_mask", None
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
outputs = self.language_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
cross_attention_states=cross_attention_states,
cross_attention_mask=cross_attention_mask,
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
past_key_values=past_key_values,
use_cache=use_cache,
inputs_embeds=inputs_embeds,
labels=labels,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
return_dict=return_dict,
cache_position=cache_position,
num_logits_to_keep=num_logits_to_keep,
token_idx=token_idx,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
)
logits = outputs[0]
if not return_dict:
output = (logits,) + outputs[1:]
return output
return outputs
def prepare_inputs_for_generation(
self,
input_ids=None,
inputs_embeds=None,
attention_mask=None,
position_ids=None,
pixel_values=None,
aspect_ratio_ids=None,
aspect_ratio_mask=None,
cross_attention_mask=None,
past_key_values=None,
use_cache=False,
cache_position=None,
num_logits_to_keep=None,
**kwargs,
):
"""
Copied from MllamaForConditionalGeneration::prepare_inputs_for_generation: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2208
The only differences are:
- add token_idx handling
- add bucket_internal handling
- add use_flash_attention and flash_attention_recompute
"""
token_idx = kwargs.get("token_idx", None)
if token_idx is None:
return super().prepare_inputs_for_generation(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
pixel_values=pixel_values,
aspect_ratio_ids=aspect_ratio_ids,
aspect_ratio_mask=aspect_ratio_mask,
cross_attention_mask=cross_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
else:
use_flash_attention = kwargs.get("use_flash_attention", True)
flash_attention_recompute = kwargs.get("flash_attention_recompute", True)
position_ids = kwargs.get("position_ids", None)
output_attentions = kwargs.get("output_attentions", None)
output_hidden_states = kwargs.get("output_hidden_states", None)
return_dict = kwargs.get("return_dict", None)
labels = kwargs.get("labels", None)
cross_attention_states = kwargs.get("cross_attention_states", None)
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
bucket_internal = kwargs.get("bucket_internal", None)
if past_key_values is not None:
if token_idx is not None:
input_ids = torch.index_select(input_ids, 1, token_idx - 1)
elif inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0] :]
elif (
input_ids.shape[1] != cache_position.shape[0]
): # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
elif bucket_internal and token_idx is not None:
# for the 1st token we can slice the inputs till token idx for the fwd pass.
input_ids = input_ids[:, :token_idx]
attention_mask = attention_mask[:, :token_idx]
if cross_attention_mask is not None:
cross_attention_mask = cross_attention_mask[:, :token_idx, ...]
# TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
if token_idx is not None:
position_ids = torch.index_select(
position_ids, 1, token_idx - 1
)
else:
position_ids = position_ids[:, -input_ids.shape[1] :]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids.clone(
memory_format=torch.contiguous_format
)
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if pixel_values is not None and cross_attention_states is not None:
raise ValueError(
"`pixel_values` and `cross_attention_states` cannot be provided simultaneously"
)
if pixel_values is not None:
if aspect_ratio_ids is None:
raise ValueError(
"`aspect_ratio_ids` must be provided if `pixel_values` is provided"
)
# get vision tokens from vision model
vision_outputs = self.vision_model(
pixel_values=pixel_values,
aspect_ratio_ids=aspect_ratio_ids,
aspect_ratio_mask=aspect_ratio_mask,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
return_dict=return_dict,
use_flash_attention=use_flash_attention,
)
cross_attention_states = vision_outputs[0]
cross_attention_states = self.multi_modal_projector(
cross_attention_states
).reshape(-1, cross_attention_states.shape[-2], self.hidden_size)
if cross_attention_mask is not None:
cross_attention_mask, full_text_row_masked_out_mask = (
_prepare_cross_attention_mask(
cross_attention_mask,
num_vision_tokens=self.vision_model.num_patches,
dtype=self.dtype,
token_idx=token_idx,
)
)
else:
full_text_row_masked_out_mask = None
if cross_attention_mask is not None:
if cache_position is not None:
cross_attention_mask = cross_attention_mask[:, :, cache_position]
full_text_row_masked_out_mask = full_text_row_masked_out_mask[
:, :, cache_position
]
elif past_key_values is not None:
if token_idx is not None:
cross_attention_mask = torch.index_select(
cross_attention_mask, -2, token_idx - 1
)
full_text_row_masked_out_mask = torch.index_select(
full_text_row_masked_out_mask, -2, token_idx - 1
)
else:
cross_attention_mask = cross_attention_mask[:, :, -1:]
full_text_row_masked_out_mask = full_text_row_masked_out_mask[
:, :, -1:
]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
# The clone here is for the same reason as for `position_ids`.
model_inputs = {
"input_ids": input_ids.clone(memory_format=torch.contiguous_format),
"inputs_embeds": None,
}
if num_logits_to_keep is not None:
model_inputs["num_logits_to_keep"] = num_logits_to_keep
# keep cache_position implementation as None for HPU
cache_position = None
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"token_idx": token_idx,
"labels": labels,
"return_dict": kwargs.get("return_dict"),
"full_text_row_masked_out_mask": full_text_row_masked_out_mask,
"use_flash_attention": use_flash_attention,
"cross_attention_mask": cross_attention_mask,
"cross_attention_states": cross_attention_states,
"output_attentions": output_attentions,
"flash_attention_recompute": flash_attention_recompute,
}
)
return model_inputs

View File

@ -45,11 +45,17 @@ from text_generation_server.layers.attention import (
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
Qwen2Model,
)
from habana_frameworks.torch.hpex.kernels import (
RotaryPosEmbeddingMode,
apply_rotary_pos_emb,
)
import habana_frameworks.torch as htorch
# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py
from typing import Union
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput, VideoInput
from transformers.image_utils import ImageInput
from transformers.video_utils import VideoInput
from transformers.processing_utils import (
ProcessingKwargs,
ProcessorMixin,
@ -375,28 +381,6 @@ class Qwen2_5_VLConfig(PretrainedConfig):
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb_vision(
tensor: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
orig_dtype = tensor.dtype
tensor = tensor.float()
cos = freqs.cos()
sin = freqs.sin()
cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
output = (tensor * cos) + (rotate_half(tensor) * sin)
output = output.to(orig_dtype)
return output
class Qwen2_5VLAttention(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
@ -426,7 +410,8 @@ class Qwen2_5VLAttention(nn.Module):
self,
hidden_state: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
max_seqlen: int,
) -> torch.Tensor:
# apply the qkv linear layer to the hidden state
@ -444,29 +429,37 @@ class Qwen2_5VLAttention(nn.Module):
query = query.view(*_shape)
key = key.view(*_shape)
value = value.view(*_shape)
# apply rotary positional embeddings
query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze(
0
)
key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)
rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
rotary_dim = cos.shape[-1]
query_rot = query[..., :rotary_dim]
query_pass = query[..., rotary_dim:]
query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query.shape))
# calc maximum sequence length for any batch
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
causal = False
key_rot = key[..., :rotary_dim]
key_pass = key[..., rotary_dim:]
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape))
# execute sdpa
query = query.unsqueeze(0).transpose(1, 2)
key = key.unsqueeze(0).transpose(1, 2)
value = value.unsqueeze(0).transpose(1, 2)
causal = False
query = query.transpose(0, 1)
key = key.transpose(0, 1)
value = value.transpose(0, 1)
fsdpa_op = ModuleFusedSDPA(FusedSDPA)
attention_mask = torch.zeros(
[1, max_seqlen, max_seqlen], device=query.device, dtype=torch.bool
)
for i in range(1, len(cu_seqlens)):
attention_mask[
:, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]
] = True
attn_output = fsdpa_op(
query,
key,
value,
attn_mask=None,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=causal,
scale=None,
@ -474,7 +467,7 @@ class Qwen2_5VLAttention(nn.Module):
recompute_mode=None,
valid_sequence_lengths=None,
)
attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
attn_output = attn_output.transpose(0, 1)
# reshape output to original dimensions
attn_output = attn_output.reshape(hidden_state.shape[0], -1)
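The loop over cu_seqlens above turns the prefix sums into a block-diagonal boolean mask, so the fused SDPA only attends within each packed image sub-sequence. A standalone sketch of the construction with toy lengths:

import torch

def block_diagonal_mask(cu_seqlens: torch.Tensor, max_seqlen: int) -> torch.Tensor:
    # True only inside the [cu_seqlens[i - 1], cu_seqlens[i]) blocks, i.e. each token
    # attends only to tokens of its own packed sub-sequence.
    mask = torch.zeros(1, max_seqlen, max_seqlen, dtype=torch.bool)
    for i in range(1, len(cu_seqlens)):
        mask[:, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
    return mask

# Two packed sequences of lengths 3 and 2 -> prefix sums [0, 3, 5].
cu_seqlens = torch.tensor([0, 3, 5])
print(block_diagonal_mask(cu_seqlens, max_seqlen=5).int().squeeze(0))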
@ -533,11 +526,9 @@ class Qwen2_5VLVisionBlock(nn.Module):
weights=weights,
)
def forward(
self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
) -> torch.Tensor:
def forward(self, hidden_states, cu_seqlens, cos, sin, max_seqlen) -> torch.Tensor:
norm1_out, _ = self.norm1(hidden_states)
attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen)
attn_out = self.attn(norm1_out, cu_seqlens, cos, sin, max_seqlen)
hidden_states = hidden_states + attn_out
norm2_out, _ = self.norm2(hidden_states)
mlp_out = self.mlp(norm2_out)
@ -608,7 +599,7 @@ class Qwen2_5VisionModel(nn.Module):
config=config,
weights=weights,
)
# import ipdb; ipdb.set_trace()
self.temporal_patch_size = config.temporal_patch_size
self.spatial_patch_size = config.spatial_patch_size
self.in_channels = config.in_channels
@ -736,6 +727,10 @@ class Qwen2_5VisionModel(nn.Module):
)
rotary_pos_emb = rotary_pos_emb.to(device=hidden_states.device)
cos = rotary_pos_emb.cos()
sin = rotary_pos_emb.sin()
cos = torch.cat((cos, cos), dim=-1).unsqueeze(1)
sin = torch.cat((sin, sin), dim=-1).unsqueeze(1)
cu_window_seqlens = torch.tensor(
cu_window_seqlens,
@ -754,6 +749,9 @@ class Qwen2_5VisionModel(nn.Module):
max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
        # iteratively apply the blocks to the hidden states
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for layer_num, block in enumerate(self.blocks):
# NOTE: qwen2_5_vl.py has a concept of full attention blocks
# that are applied at specific layers.
@ -762,9 +760,9 @@ class Qwen2_5VisionModel(nn.Module):
else:
cu_seqlens_now = cu_window_seqlens
hidden_states = block(
hidden_states, cu_seqlens_now, rotary_pos_emb, max_seqlen
)
hidden_states = block(hidden_states, cu_seqlens_now, cos, sin, max_seqlen)
if lazy_mode:
htorch.core.mark_step()
# apply the final patch merger to the hidden states
hidden_states = self.merger(hidden_states)
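The mark_step calls above cut the HPU lazy-execution graph before the block loop and again around each block. A minimal sketch of the same guard pattern; the try/except fallback is added here only so the snippet also runs off-HPU, and the toy Linear layers stand in for the vision blocks.

import torch

try:
    import habana_frameworks.torch as htorch
    lazy_mode = htorch.utils.internal.is_lazy()
except ImportError:  # not on a Gaudi host; the guard simply disables the calls
    htorch, lazy_mode = None, False

hidden_states = torch.zeros(1, 4)
blocks = [torch.nn.Linear(4, 4) for _ in range(3)]  # stand-ins for the transformer blocks

if lazy_mode:
    htorch.core.mark_step()  # flush pending lazy ops before entering the loop
for block in blocks:
    hidden_states = block(hidden_states)
    if lazy_mode:
        htorch.core.mark_step()  # a graph break per block keeps lazy graphs small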
@ -886,9 +884,6 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
full_llm_pos_ids_list = [
item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
]
# import ipdb
# ipdb.set_trace()
max_s = full_llm_pos_ids_list[-1].max() + 1
final_text_len = input_ids_len - vision_ends[-1]
if final_text_len > 0:
@ -900,9 +895,33 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
)
return position_ids
def forward(
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
pixel_attention_mask: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
return image_embeds
def get_inputs_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
):
inputs_embeds = self.embed_tokens(input_ids)
        # merge the precomputed vision embeddings into the token embeddings if provided
if vision_embeds is not None:
mask = torch.where(input_ids == self.image_token_id)
inputs_embeds[mask] = vision_embeds
return inputs_embeds
def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@ -910,26 +929,10 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor],
pixel_values: torch.FloatTensor = None,
image_grid_thw: Optional[torch.LongTensor] = None,
# Unused in this model
video_grid_thw: Optional[torch.LongTensor] = None,
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
):
inputs_embeds = self.embed_tokens(input_ids)
# apply the visual model to the pixel values if they are provided
if pixel_values is not None and len(pixel_values) > 0:
if pixel_values is not None:
image_embeds = self.visual(
pixel_values, grid_thw=image_grid_thw
).squeeze(0)
mask = torch.where(input_ids == self.image_token_id)
inputs_embeds[mask] = image_embeds
hidden_states = self.text_model(
inputs_embeds=inputs_embeds,

View File

@ -44,28 +44,11 @@ from text_generation_server.layers.attention import (
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
Qwen2Model,
)
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb_vision(
tensor: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
orig_dtype = tensor.dtype
tensor = tensor.float()
cos = freqs.cos()
sin = freqs.sin()
cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
output = (tensor * cos) + (rotate_half(tensor) * sin)
output = output.to(orig_dtype)
return output
from habana_frameworks.torch.hpex.kernels import (
RotaryPosEmbeddingMode,
apply_rotary_pos_emb,
)
import habana_frameworks.torch as htorch
class Qwen2VLAttention(nn.Module):
@ -96,7 +79,8 @@ class Qwen2VLAttention(nn.Module):
self,
hidden_state: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
max_seqlen: int,
) -> torch.Tensor:
# apply the qkv linear layer to the hidden state
@ -116,27 +100,36 @@ class Qwen2VLAttention(nn.Module):
value = value.view(*_shape)
# apply rotary positional embeddings
query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze(
0
)
key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)
rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
rotary_dim = cos.shape[-1]
query_rot = query[..., :rotary_dim]
query_pass = query[..., rotary_dim:]
query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query.shape))
# calc maximum sequence length for any batch
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
causal = False
key_rot = key[..., :rotary_dim]
key_pass = key[..., rotary_dim:]
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape))
# execute sdpa
query = query.unsqueeze(0).transpose(1, 2)
key = key.unsqueeze(0).transpose(1, 2)
value = value.unsqueeze(0).transpose(1, 2)
causal = False
query = query.transpose(0, 1)
key = key.transpose(0, 1)
value = value.transpose(0, 1)
fsdpa_op = ModuleFusedSDPA(FusedSDPA)
attention_mask = torch.zeros(
[1, max_seqlen, max_seqlen], device=query.device, dtype=torch.bool
)
for i in range(1, len(cu_seqlens)):
attention_mask[
:, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]
] = True
attn_output = fsdpa_op(
query,
key,
value,
attn_mask=None,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=causal,
scale=None,
@ -144,7 +137,7 @@ class Qwen2VLAttention(nn.Module):
recompute_mode=None,
valid_sequence_lengths=None,
)
attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
attn_output = attn_output.transpose(0, 1)
# reshape output to original dimensions
attn_output = attn_output.reshape(hidden_state.shape[0], -1)
attn_output = self.proj(attn_output)
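A small self-contained example of the block-diagonal mask built by the loop above: with two segments of lengths 3 and 2 (and the mask side taken as the total token count for clarity), tokens can only attend within their own segment:

import torch

cu_seqlens = torch.tensor([0, 3, 5])
total = int(cu_seqlens[-1])
attention_mask = torch.zeros([1, total, total], dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
    s, e = cu_seqlens[i - 1], cu_seqlens[i]
    attention_mask[:, s:e, s:e] = True   # True = allowed to attend
# attention_mask[0] is block-diagonal: a 3x3 block of True followed by a 2x2 block of True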
@ -193,11 +186,9 @@ class Qwen2VLVisionBlock(nn.Module):
weights=weights,
)
def forward(
self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
) -> torch.Tensor:
def forward(self, hidden_states, cu_seqlens, cos, sin, max_seqlen) -> torch.Tensor:
norm1_out, residual = self.norm1(hidden_states)
attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen)
attn_out = self.attn(norm1_out, cu_seqlens, cos, sin, max_seqlen)
hidden_states = attn_out + residual
norm2_out, residual = self.norm2(hidden_states)
hidden_states = hidden_states + self.mlp(norm2_out)
@ -330,6 +321,11 @@ class Qwen2VisionModel(nn.Module):
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, hidden_states.dtype)
cos = rotary_pos_emb.cos()
sin = rotary_pos_emb.sin()
cos = torch.cat((cos, cos), dim=-1).unsqueeze(1)
sin = torch.cat((sin, sin), dim=-1).unsqueeze(1)
# create a cu_seqlens tensor to be used in the attention mask
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
@ -337,8 +333,13 @@ class Qwen2VisionModel(nn.Module):
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
# iteratively apply the blocks to the hidden states
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for block in self.blocks:
hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen)
hidden_states = block(hidden_states, cu_seqlens, cos, sin, max_seqlen)
if lazy_mode:
htorch.core.mark_step()
# apply the final patch merger to the hidden states
hidden_states = self.merger(hidden_states)
@ -474,9 +475,33 @@ class Qwen2VLForConditionalGeneration(nn.Module):
)
return position_ids
def forward(
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
pixel_attention_mask: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
return image_embeds
def get_inputs_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
):
inputs_embeds = self.embed_tokens(input_ids)
# merge the precomputed vision embeddings into the text embeddings if provided
if vision_embeds is not None:
mask = torch.where(input_ids == self.image_token_id)
inputs_embeds[mask] = vision_embeds
return inputs_embeds
def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@ -484,26 +509,10 @@ class Qwen2VLForConditionalGeneration(nn.Module):
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor],
pixel_values: torch.FloatTensor = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
):
inputs_embeds = self.embed_tokens(input_ids)
# apply the visual model to the pixel values if they are provided
if pixel_values is not None and len(pixel_values) > 0:
if pixel_values is not None:
image_embeds = self.visual(
pixel_values, grid_thw=image_grid_thw
).squeeze(0)
mask = torch.where(input_ids == self.image_token_id)
inputs_embeds[mask] = image_embeds
hidden_states = self.text_model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,

View File

@ -153,19 +153,14 @@ def prepare_for_decode(
block_list_device = _async_h2d_tensor_copy(block_list)
block_groups_device = _async_h2d_tensor_copy(block_groups)
block_usage_device = _async_h2d_tensor_copy(block_usage)
block_mapping = torch.nn.functional.one_hot(
block_groups_device, num_classes=batch_size
)
mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
mask = mask >= block_usage_device.unsqueeze(-1)
attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
return trim_attn_metadata(
HPUPagedAttentionMetadata(
block_list=block_list_device,
block_groups=block_groups_device,
block_usage=block_usage_device,
block_mapping=block_mapping.to(dtype),
attn_bias=attn_bias,
block_mapping=None,
attn_bias=None,
)
)
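For reference, the bias that the removed lines above derived from per-block usage can be reproduced standalone as below (the computation is now deferred, hence block_mapping=None and attn_bias=None):

import math
import torch

BLOCK_SIZE = 4
block_usage = torch.tensor([3, 4], dtype=torch.int32)        # valid tokens per KV block
mask = torch.arange(0, BLOCK_SIZE, dtype=torch.int32).unsqueeze(0)
mask = mask >= block_usage.unsqueeze(-1)                     # True where the slot is padding
attn_bias = torch.zeros_like(mask, dtype=torch.bfloat16).masked_fill_(mask, -math.inf)
# attn_bias -> [[0., 0., 0., -inf], [0., 0., 0., 0.]]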
@ -428,10 +423,8 @@ class FlashCausalLMBatch(Batch):
for i, input_ids in enumerate(all_input_ids):
all_input_ids_tensor[i, : len(input_ids)] = input_ids
# Create tensors on device
all_input_ids_tensor = torch.tensor(
all_input_ids_tensor, dtype=torch.int64, device=device
)
# put on cpu temporarily, move to hpu in prepare_for_prefill
all_input_ids_tensor = torch.tensor(all_input_ids_tensor, dtype=torch.int64)
top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64)
@ -701,7 +694,9 @@ class FlashCausalLMBatch(Batch):
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
def concatenate(
cls, batches: List["FlashCausalLMBatch"], padded_total_bs: int = 0
) -> "FlashCausalLMBatch":
# Batch attributes
requests = []
requests_idx_mapping = {}
@ -750,7 +745,10 @@ class FlashCausalLMBatch(Batch):
adapter_meta = None
adapter_segment_builder = None
else:
input_ids = batches[0].input_ids.new_empty(total_batch_size)
if padded_total_bs == batches[0].input_ids.shape[0]:
input_ids = batches[0].input_ids
else:
input_ids = batches[0].input_ids.new_empty(total_batch_size)
if (
batches[0].position_ids is not None
and batches[0].position_ids.dim() == 2
@ -784,9 +782,7 @@ class FlashCausalLMBatch(Batch):
block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
(total_batch_size, max_blocks)
)
all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
(total_batch_size, max_length)
)
all_input_ids_tensor = batches[0].all_input_ids_tensor
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size,
)
@ -829,9 +825,12 @@ class FlashCausalLMBatch(Batch):
index = torch.tensor(list(range(start_index, end_index)), device="cpu")
top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor)
all_input_ids_tensor[
start_index:end_index, : batch.all_input_ids_tensor.shape[1]
] = batch.all_input_ids_tensor[:valid_bsize, :max_length]
if i > 0:
all_input_ids_tensor.index_copy_(
0,
index.to(batch.all_input_ids_tensor.device),
batch.all_input_ids_tensor[:valid_bsize, :],
)
block_tables_tensor[
start_index:end_index, : batch.block_tables_tensor.shape[1]
@ -851,9 +850,10 @@ class FlashCausalLMBatch(Batch):
)
if not prefilling:
input_ids.index_copy_(
0, index.to(input_ids.device), batch.input_ids[:valid_bsize]
)
if padded_total_bs != batches[0].input_ids.shape[0] or i > 0:
input_ids.index_copy_(
0, index.to(input_ids.device), batch.input_ids[:valid_bsize]
)
position_ids.index_copy_(0, index, batch.position_ids[:valid_bsize])
slot_indices.index_copy_(
0, index, batch.slot_indices + cumulative_slots
@ -987,7 +987,6 @@ class FlashCausalLMBatch(Batch):
else:
padded_bs = self.input_ids.shape[0]
slots = self.slots[self.slot_indices]
extra_pad = padded_bs - self.input_ids.shape[0]
self.hpu_attn_meta = prepare_for_decode(
dtype,
@ -998,17 +997,29 @@ class FlashCausalLMBatch(Batch):
padded_bs,
bucketing_ctx,
)
self.input_ids = F.pad(self.input_ids, (0, extra_pad), value=0)
self.position_ids = F.pad(self.position_ids, (0, extra_pad), value=1)
self.input_ids = F.pad(
self.input_ids, (0, padded_bs - self.input_ids.shape[0]), value=0
)
if self.position_ids.dim() == 2:
# Qwen VL case
self.position_ids = F.pad(
self.position_ids,
(0, 0, 0, padded_bs - self.position_ids.shape[0]),
value=1,
)
else:
self.position_ids = F.pad(
self.position_ids, (0, padded_bs - self.position_ids.shape[0]), value=1
)
self.input_lengths_tensor = F.pad(
self.input_lengths_tensor, (0, extra_pad), value=0
self.input_lengths_tensor,
(0, padded_bs - self.input_lengths_tensor.shape[0]),
value=0,
)
self.cache_lengths_tensor = F.pad(
self.cache_lengths_tensor, (0, extra_pad), value=0
)
self.all_input_ids_tensor = F.pad(
self.all_input_ids_tensor,
(0, 0, 0, extra_pad),
self.cache_lengths_tensor,
(0, padded_bs - self.cache_lengths_tensor.shape[0]),
value=0,
)
next_token_chooser_parameters = []
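The 4-element pad tuple in the 2-D branch above follows torch.nn.functional.pad's convention of listing pads from the last dimension backwards, so (0, 0, 0, n) leaves the last dimension untouched and appends n rows along the first (batch) dimension; a quick check:

import torch
import torch.nn.functional as F

position_ids = torch.ones((2, 3), dtype=torch.int64)   # e.g. 2 rows of 3-axis mrope positions
padded = F.pad(position_ids, (0, 0, 0, 2), value=1)    # pad the batch dimension from 2 to 4
assert padded.shape == (4, 3)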
@ -1028,7 +1039,9 @@ class FlashCausalLMBatch(Batch):
fsm_grammar_states,
)
def prepare_for_prefill(self, max_padded_input_len, max_padded_bs):
def prepare_for_prefill(
self, max_padded_input_len, max_padded_bs, max_total_tokens
):
# Prepare values if we need to continue prefilling
# Speculation must be ignored while we prefill even with chunking
# it simplifies everything
@ -1044,7 +1057,7 @@ class FlashCausalLMBatch(Batch):
# need extra pad to match warmup seq
extra_pad = max_padded_input_len - self.max_input_length
extra_pad_bs = max_padded_bs - len(self)
device = self.all_input_ids_tensor.device
device = "hpu"
if isinstance(self.input_ids, list) and len(self) > 1:
input_ids_padded_length = []
input_ids = []
@ -1062,8 +1075,19 @@ class FlashCausalLMBatch(Batch):
input_ids = [0] * extra_pad + input_ids
self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
else:
self.input_ids = F.pad(self.input_ids, (extra_pad, 0), value=0)
input_ids_padded_length.extend([extra_pad] * len(self))
input_ids = self.input_ids.new_zeros(max_padded_input_len * len(self))
src_pos = 0
for i in range(len(self)):
end_pos = (i + 1) * max_padded_input_len
start_pos = end_pos - self.input_lengths[i]
input_ids[start_pos:end_pos] = self.input_ids[
src_pos : src_pos + self.input_lengths[i]
]
input_ids_padded_length.append(
max_padded_input_len - self.input_lengths[i]
)
src_pos += self.input_lengths[i]
self.input_ids = input_ids
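A small worked example of the per-request left-padding above, with max_padded_input_len = 4 and two requests of lengths 2 and 3 (each request ends exactly at its bucket boundary):

import torch

max_padded_input_len = 4
input_lengths = [2, 3]
requests = [torch.tensor([11, 12]), torch.tensor([21, 22, 23])]

input_ids = torch.zeros(max_padded_input_len * len(requests), dtype=torch.int64)
for i, r in enumerate(requests):
    end_pos = (i + 1) * max_padded_input_len
    start_pos = end_pos - input_lengths[i]
    input_ids[start_pos:end_pos] = r
# input_ids -> [0, 0, 11, 12, 0, 21, 22, 23]; per-request left pads are [2, 1]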
self.input_ids = F.pad(
self.input_ids, (0, extra_pad_bs * max_padded_input_len), value=0
@ -1288,12 +1312,17 @@ class FlashCausalLMBatch(Batch):
self.prefill_next_token_indices = (
self.prefill_next_token_indices + input_ids_padded_length_tensor
)
self.all_input_ids_tensor = F.pad(
self.all_input_ids_tensor,
(0, 0, 0, extra_pad_bs),
value=0,
all_input_ids_tensor = torch.zeros(
(max_padded_bs, max(max_total_tokens, self.all_input_ids_tensor.shape[-1])),
dtype=torch.int64,
device="hpu",
)
for i in range(len(self)):
all_input_ids_tensor[i, : self.all_input_ids_tensor.shape[-1]] = (
self.all_input_ids_tensor[i]
)
self.all_input_ids_tensor = all_input_ids_tensor
next_token_chooser_parameters = []
next_token_chooser_parameters.extend([r.parameters for r in self.requests])
pad_next_token_chooser_parameters(next_token_chooser_parameters, max_padded_bs)
@ -1448,7 +1477,7 @@ class FlashCausalLM(Model):
if head_size is None:
# Some models use GQA and different sizes for o_proj
# and q_proj, that allows for that.
if hasattr(config, "head_dim"):
if getattr(config, "head_dim", None) is not None:
self.head_size = config.head_dim
else:
self.head_size = config.hidden_size // config.num_attention_heads
@ -1459,6 +1488,8 @@ class FlashCausalLM(Model):
self.kv_cache = []
self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
self.bucketing_ctx = None
self.max_total_tokens = None
self.max_input_tokens = None
htorch.core.hpu_set_env()
if htorch.utils.internal.is_lazy():
htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=True)
@ -1564,6 +1595,14 @@ class FlashCausalLM(Model):
logger.info,
f"Free memory on device {self.device}: {format_bytes(free_memory)} used_for_graph: {format_bytes(mem_used_from_graph)} ratio {graph_reserved_mem} reserved_for_runtime: {format_bytes(self.mem_reserved)}",
)
if max_total_tokens is None:
max_total_tokens = sum(batch.input_lengths)
if max_input_tokens is None:
max_input_tokens = max_total_tokens - 1
self.max_total_tokens = max_total_tokens
self.max_input_tokens = max_input_tokens
try:
self.init_kv_cache(
batch.num_blocks,
@ -1597,11 +1636,6 @@ class FlashCausalLM(Model):
)
log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}")
if max_total_tokens is None:
max_total_tokens = sum(batch.input_lengths)
if max_input_tokens is None:
max_input_tokens = max_total_tokens - 1
self.kv_cache = []
empty_cache()
@ -2017,7 +2051,9 @@ class FlashCausalLM(Model):
accepted_ids,
speculative_ids,
) = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_current_length],
batch.all_input_ids_tensor[
: batch.next_token_logits.shape[0], : batch.max_current_length
],
batch.next_token_logits,
speculate,
batch.speculative_ids,
@ -2031,14 +2067,29 @@ class FlashCausalLM(Model):
accepted_ids,
)
if batch.valid_indices is not None:
next_token_logprobs = next_token_logprobs.cpu()
accepted_ids = accepted_ids.cpu()
batch.all_input_ids_tensor = batch.all_input_ids_tensor[
batch.valid_indices
]
next_input_ids = next_input_ids[batch.valid_indices]
next_token_logprobs = next_token_logprobs[batch.valid_indices]
accepted_ids = accepted_ids[batch.valid_indices]
# TODO speculative decoding handling missing
index = torch.arange(
0,
len(batch.valid_indices),
device=batch.all_input_ids_tensor.device,
)
batch.all_input_ids_tensor.index_copy_(
0, index, batch.all_input_ids_tensor[batch.valid_indices]
)
padded_total_bs = self.bucketing_ctx.get_padded_decode_batch_size(
len(batch.valid_indices)
)
next_input_ids.index_copy_(
0, index, next_input_ids[batch.valid_indices]
)
next_input_ids = next_input_ids[:padded_total_bs]
next_token_logprobs.index_copy_(
0, index, next_token_logprobs[batch.valid_indices]
)
accepted_ids.index_copy_(
0, index, accepted_ids[batch.valid_indices]
)
if speculative_ids is not None:
speculative_ids = speculative_ids[batch.valid_indices]
batch.top_n_tokens_tensor = batch.top_n_tokens_tensor[
@ -2106,10 +2157,13 @@ class FlashCausalLM(Model):
batch.slot_indices += accepted_ids[: len(batch)]
else:
index = batch.cache_lengths_tensor + batch.input_lengths_tensor
index = F.pad(
index, (0, next_input_ids.shape[0] - index.shape[0]), value=0
)
index = index.to(batch.all_input_ids_tensor.device)
batch_idx = torch.arange(
0,
batch.all_input_ids_tensor.shape[0],
index.shape[0],
dtype=torch.long,
device=batch.all_input_ids_tensor.device,
)
@ -2197,7 +2251,18 @@ class FlashCausalLM(Model):
htorch.core.mark_step()
# Stage 2. Prepare new batch for speculative scheduling
if len(batches) > 1:
batch = self.batch_type.concatenate(batches)
if self.bucketing_ctx is not None:
total_batch_size = 0
for b in batches:
total_batch_size += len(b)
padded_total_bs = self.bucketing_ctx.get_padded_decode_batch_size(
total_batch_size
)
batch = self.batch_type.concatenate(
batches, padded_total_bs=padded_total_bs
)
else:
batch = self.batch_type.concatenate(batches)
else:
batch = batches[0]
prefill = batch.prefilling
@ -2208,13 +2273,18 @@ class FlashCausalLM(Model):
batch.max_input_length
),
self.bucketing_ctx.get_padded_prompt_batch_size(len(batch)),
self.max_total_tokens,
)
else:
batch.prepare_for_prefill(batch.max_input_length, len(batch))
batch.prepare_for_prefill(
batch.max_input_length, len(batch), self.max_total_tokens
)
else:
batch.prepare_for_decode(
self.dtype, self.use_contiguous_pa, self.bucketing_ctx
)
if hasattr(self, "set_inputs_embeds") and callable(self.set_inputs_embeds):
self.set_inputs_embeds(batch)
prefill_logprobs = batch.prefill_next_token_indices is not None
# Update adapter indices for speculative tokens (if present)
adapter_meta = batch.adapter_meta

View File

@ -1,7 +1,7 @@
import torch
from PIL import Image
from io import BytesIO
from dataclasses import dataclass
from opentelemetry import trace
from typing import Iterable, Optional, Tuple, List, Type, Dict
@ -119,17 +119,17 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
return height // patch_size, width // patch_size
def image_text_replacement(processor, image_input, config, image_id: int) -> str:
def image_text_replacement(processor, image_input, config) -> str:
if config.model_type == "idefics2":
image_seq_len = 64
image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}"
if processor.image_processor.do_image_splitting:
image_str *= 5
return image_str
return image_str, IDEFICS2_FAKE_TOKEN
if config.model_type == "idefics3":
# TODO: implement this in a more general way
n_rows = image_input["rows"][0][image_id]
n_cols = image_input["cols"][0][image_id]
n_rows = image_input["rows"][0][0]
n_cols = image_input["cols"][0][0]
image_seq_len = int(
((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
/ (config.scale_factor**2)
@ -142,41 +142,41 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
image_token=IDEFICS3_IMAGE_TOKEN,
global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,
)
return image_str
return image_str, IDEFICS3_FAKE_IMAGE_TOKEN
elif config.model_type == "llava_next":
height, width = image_input["image_sizes"][image_id]
height, width = image_input["image_sizes"][0]
num_features = get_number_of_features(height, width, config)
log_master(
logger.info,
f"Found {num_features} features in image of resolution {height}x{width}",
)
return "<image>" * num_features
return "<image>" * num_features, "<image>"
elif config.model_type == "paligemma":
return "<image>" * config.text_config.num_image_tokens
return "<image>" * config.text_config.num_image_tokens, "<image>"
elif config.model_type == "qwen2_vl":
grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
grid_t, grid_h, grid_w = image_input["image_grid_thw"][0]
num_pads = grid_t * grid_h * grid_w // 4
padding = "<|image_pad|>" * num_pads
return f"<|vision_start|>{padding}<|vision_end|>"
return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
elif config.model_type == "qwen2_5_vl":
grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
grid_t, grid_h, grid_w = image_input["image_grid_thw"][0]
num_pads = grid_t * grid_h * grid_w // 4
padding = "<|image_pad|>" * num_pads
return f"<|vision_start|>{padding}<|vision_end|>"
return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
elif config.model_type == "gemma3":
# TODO: get correct number of features via reviewing the Gemma3 architecture
# and calculating the number of image tokens
num_pads = 256
padding = "<image_soft_token>" * num_pads
return f"\n\n<start_of_image>{padding}<end_of_image>\n\n"
return f"\n\n<start_of_image>{padding}<end_of_image>\n\n", "<start_of_image>"
elif config.model_type == "llama4":
patch_size = config.vision_config.patch_size
pixel_shuffle_ratio = config.vision_config.pixel_shuffle_ratio
downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))
aspect_ratios = image_input["aspect_ratios"][image_id]
image_height, image_width = image_input["pixel_values"][image_id].shape[-2:]
aspect_ratios = image_input["aspect_ratios"][0]
image_height, image_width = image_input["pixel_values"][0].shape[-2:]
num_patches_per_chunk = int(
(image_height // patch_size)
@ -187,7 +187,7 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
aspect_ratios, num_patches_per_chunk
)
return tokens_for_this_image
return tokens_for_this_image, "<|image_start|>"
else:
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
@ -200,6 +200,27 @@ def image_text_replacement_fixup(config, text: str) -> str:
return text
def preprocess_text(config, text: str) -> str:
if config.model_type == "paligemma":
return "<bos>" + text + "\n"
return text
def preprocess_image(config, img):
model_type = config.model_type
if model_type in {"qwen2_vl", "qwen2_5_vl"} and img.width <= 20:
img = img.resize((img.width * 2, img.height * 2))
if model_type == "paligemma":
img = img.convert("RGB")
if model_type not in {"llava_next", "gemma3", "llama4"}:
# TODO: check if this is needed
img = [img]
return img
def get_unpadded_features(
original_height: int,
original_width: int,
@ -254,105 +275,259 @@ def get_number_of_features(height: int, width: int, config) -> int:
return unpadded_features + newline_features + base_features
def scatter_image_embeds(
embeds: torch.Tensor, is_embed: Optional[torch.Tensor]
) -> torch.Tensor:
if is_embed is None:
return embeds
placeholders = embeds.new_full(
(is_embed.shape[0], embeds.shape[-1]),
fill_value=torch.nan,
)
placeholders[is_embed.to(embeds.device)] = embeds
return placeholders
def gather_image_embeds(
embeds: torch.Tensor, is_embed: Optional[torch.Tensor]
) -> Optional[torch.Tensor]:
if is_embed is None:
return embeds
sel = embeds[is_embed.to(embeds.device)]
return sel if sel.numel() else None
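A short usage sketch of the two helpers above: scatter_image_embeds expands encoder output to the full placeholder span (NaN where no embedding belongs) and gather_image_embeds inverts it:

import torch

embeds = torch.randn(3, 8)                           # 3 real image embeddings of dim 8
is_embed = torch.tensor([True, False, True, True])   # 4 placeholder slots, 3 carry embeddings
full = scatter_image_embeds(embeds, is_embed)        # shape (4, 8), row 1 filled with NaN
back = gather_image_embeds(full, is_embed)           # recovers the original 3 rows
assert back.shape == (3, 8)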
@dataclass
class ImagePositions:
offset: int
length: int
id: int
num_placeholder_tokens: int
is_embed: Optional[torch.Tensor] = None
class FlashVlmCausalLMBatch(FlashCausalLMBatch):
image_inputs: Optional[List[List[Dict[str, torch.Tensor]]]]
image_positions: Optional[List[List[ImagePositions]]]
encoder_cache: Optional[List[Dict[int, torch.Tensor]]]
pixel_values: Optional[List[torch.Tensor]]
pixel_attention_mask: Optional[List[torch.Tensor]]
image_sizes: Optional[List[Tuple[int, int]]]
image_grid_thw: Optional[torch.Tensor]
cache_entries_to_free: List[Tuple[int, int]]
has_image_inputs: bool = False
inputs_embeds: Optional[torch.Tensor] = None
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches):
batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches)
def concatenate(cls, batches, padded_total_bs: int = 0):
batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches, padded_total_bs)
batch.image_inputs = []
batch.image_positions = []
batch.encoder_cache = []
for b in batches:
if b.image_inputs is not None:
batch.image_inputs.extend(b.image_inputs)
else:
batch.image_inputs.append(None)
if b.image_positions is not None:
batch.image_positions.extend(b.image_positions)
else:
batch.image_positions.append(None)
if b.encoder_cache is not None:
batch.encoder_cache.extend(b.encoder_cache)
else:
batch.encoder_cache.append(None)
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
batch.image_grid_thw = None
batch.inputs_embeds = None
# To be filled in prepare_for_prefill
batch.has_image_inputs = False
batch.cache_entries_to_free = []
return batch
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]):
if len(request_ids) == 0:
raise ValueError("Batch must have at least one request")
image_inputs = []
image_positions = []
encoder_cache = []
for request_id in request_ids:
idx = self.requests_idx_mapping[request_id]
image_inputs.append(self.image_inputs[idx])
image_positions.append(self.image_positions[idx])
encoder_cache.append(self.encoder_cache[idx])
batch = super().filter(request_ids)
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
batch.image_grid_thw = None
batch.inputs_embeds = None
batch.image_inputs = image_inputs
batch.image_positions = image_positions
batch.encoder_cache = encoder_cache
# To be filled in prepare_for_prefill
batch.has_image_inputs = False
batch.cache_entries_to_free = []
return batch
@classmethod
def batch_tokenized_inputs(
cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
):
# Process images first. We need all of them so that the processor
# can make the image splits the same size. And we need the final
# sizes to insert correct number of image tokens.
images = []
kwargs = {}
if (
hasattr(processor, "image_processor_class")
and processor.image_processor_class == "Idefics3ImageProcessor"
):
kwargs["return_row_col_info"] = True
max_length = 0
vocab = tokenizer.get_vocab()
if not hasattr(config, "image_token_index"):
config.image_token_index = config.image_token_id
batch_tokenized_inputs: List[List[int]] = []
batch_image_inputs: List[Optional[List[dict]]] = []
batch_image_positions: List[Optional[List[ImagePositions]]] = []
for r in requests:
text_parts = []
image_inputs = []
image_texts = []
image_id = 0
for chunk in r.input_chunks.chunks:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
pass
text = preprocess_text(config, chunk.text)
text_parts.append(text)
elif chunk_type == "image":
image = Image.open(BytesIO(chunk.image.data))
# qwen2_vl expects images to be greater than 20 pixels, this is for warmup since the
# default warmup image is 20x20
if config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
if image.width <= 20:
w = image.width * 2
h = image.height * 2
image = image.resize((w, h))
img = Image.open(BytesIO(chunk.image.data))
img = preprocess_image(config, img)
if config.model_type == "llava_next":
images.append(image)
elif config.model_type == "gemma3":
images.append(image)
elif config.model_type == "llama4":
images.append(image)
else:
images.append([image])
image_input = processor.image_processor(
[img], return_tensors="pt", **kwargs
)
image_inputs.append(image_input)
img_text, img_start_token_str = image_text_replacement(
processor, image_input, config
)
text_parts.append(img_text)
image_texts.append([image_id, img_start_token_str, img_text])
image_id += 1
else:
raise RuntimeError(f"Invalid chunk type {chunk_type}")
if images:
kwargs = {}
if (
hasattr(processor, "image_processor_class")
and processor.image_processor_class == "Idefics3ImageProcessor"
):
kwargs["return_row_col_info"] = True
image_inputs = processor.image_processor(
images, return_tensors="pt", **kwargs
)
else:
image_inputs = None
batch_tokenized_inputs = []
max_length = 0
image_id = 0
for r in requests:
full_text = ""
for chunk in r.input_chunks.chunks:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
full_text += chunk.text
elif chunk_type == "image":
full_text += image_text_replacement(
processor, image_inputs, config, image_id
)
image_id += 1
full_text = image_text_replacement_fixup(config, full_text)
full_text = image_text_replacement_fixup(config, "".join(text_parts))
input_ids = tokenizer(
full_text,
truncation=True,
max_length=r.truncate,
add_special_tokens=r.add_special_tokens,
add_special_tokens=(
r.add_special_tokens if config.model_type != "paligemma" else False
),
)["input_ids"]
max_length = max(max_length, len(input_ids))
batch_tokenized_inputs.append(input_ids)
return batch_tokenized_inputs, image_inputs
if len(image_inputs) > 0:
img_start_token = vocab[image_texts[0][1]]
image_positions = cls.get_image_positions(
input_ids, image_texts, img_start_token, config, tokenizer
)
else:
image_inputs = None
image_positions = None
batch_tokenized_inputs.append(input_ids)
batch_image_inputs.append(image_inputs)
batch_image_positions.append(image_positions)
return batch_tokenized_inputs, batch_image_inputs, batch_image_positions
@classmethod
def get_image_positions(
cls,
input_ids: List[int],
image_texts: List[Tuple[int, str, str]],
img_start_token: int,
config,
tokenizer: PreTrainedTokenizerBase,
) -> List[ImagePositions]:
image_positions = []
num_images = len(image_texts)
input_ids_t = torch.as_tensor(input_ids)
img_start_token_pos = torch.where(input_ids_t.eq(img_start_token))[0]
num_tokens = input_ids_t.numel()
last_pos = 0
for i in range(num_images):
image_id, img_start_token_str, img_text = image_texts[i]
img_text = image_text_replacement_fixup(config, img_text)
if config.model_type == "gemma3":
img_text = img_text.replace("\n\n", "")
tokens = tokenizer(img_text, add_special_tokens=False, return_tensors="pt")[
"input_ids"
][0]
length = tokens.numel()
assert (
length <= num_tokens
), f"{length} > {num_tokens} Image is truncated, try increasing --max-batch-prefill-tokens"
pos = torch.searchsorted(img_start_token_pos, last_pos, right=False)
index = img_start_token_pos[pos]
assert torch.equal(
input_ids_t[index : index + length], tokens
), "Image tokens not found in input_ids"
is_embed = tokens == config.image_token_index
num_placeholder_tokens = int(is_embed.sum())
if num_placeholder_tokens == length:
is_embed = None
pos = ImagePositions(
offset=index,
length=length,
id=image_id,
num_placeholder_tokens=num_placeholder_tokens,
is_embed=is_embed,
)
image_positions.append(pos)
last_pos = index + length
if (
config.model_type == "idefics2"
and i + 1 != num_images
and input_ids[last_pos] == config.image_token_index
):
fake_token = last_pos - 1
fake_token_index = torch.searchsorted(
img_start_token_pos, fake_token, right=False
)
img_start_token_pos[fake_token_index] = last_pos
image_texts[i + 1][2] = image_texts[i + 1][2][
len(img_start_token_str) :
]
return image_positions
@classmethod
def from_pb_processor(
@ -364,33 +539,164 @@ class FlashVlmCausalLMBatch(FlashCausalLMBatch):
dtype: torch.dtype,
device: torch.device,
) -> "FlashVlmCausalLMBatch":
batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
pb.requests, tokenizer, processor, config
batch_tokenized_inputs, image_inputs, image_positions = (
cls.batch_tokenized_inputs(pb.requests, tokenizer, processor, config)
)
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
if image_inputs is not None:
batch.pixel_values = image_inputs["pixel_values"].to(device=device)
if "pixel_attention_mask" in image_inputs:
batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
device=device
)
else:
batch.pixel_attention_mask = None
if "image_sizes" in image_inputs:
batch.image_sizes = image_inputs["image_sizes"].to(device=device)
else:
batch.image_sizes = None
if "image_grid_thw" in image_inputs:
batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device)
else:
batch.image_grid_thw = None
else:
batch.image_inputs = image_inputs
batch.image_positions = image_positions
batch.encoder_cache = [{} for _ in range(len(pb.requests))]
if len(image_inputs):
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
batch.image_grid_thw = None
return batch
def prepare_for_prefill(
self, max_padded_input_len, max_padded_bs, max_total_tokens
):
super().prepare_for_prefill(
max_padded_input_len, max_padded_bs, max_total_tokens
)
self.has_image_inputs = False
self.cache_entries_to_free = []
self.pixel_values = []
assert (
len(self.cache_lengths)
== len(self.input_lengths)
== len(self.prefilling_mask)
), "Mismatch in lengths of cache_lengths, input_lengths, and prefilling_mask"
for i, (
cache_length,
input_length,
request_prefilling,
) in enumerate(
zip(
self.cache_lengths,
self.input_lengths,
self.prefilling_mask,
)
):
if not request_prefilling or self.image_positions[i] is None:
continue
for image_position in self.image_positions[i]:
if image_position is None:
continue
start_pos = image_position.offset
length = image_position.length
if start_pos >= cache_length + input_length:
# No encoder input required at this step
break
if start_pos + length <= cache_length:
# The encoder input has already been processed
continue
self.has_image_inputs = True
if image_position.id not in self.encoder_cache[i]:
image_inputs = self.image_inputs[i][image_position.id]
self.pixel_values.append((i, image_position.id, image_inputs))
# Remove the image from the image_inputs
self.image_inputs[i][image_position.id] = None
if not self.has_image_inputs:
self.pixel_values = None
self.pixel_attention_mask = None
self.image_sizes = None
self.image_grid_thw = None
else:
image_grid_thw_list = [
x[2]["image_grid_thw"]
for x in self.pixel_values
if "image_grid_thw" in x[2]
]
if image_grid_thw_list:
self.image_grid_thw = torch.cat(image_grid_thw_list, dim=0)
else:
self.image_grid_thw = None
def update_encoder_cache(self, encoder_outputs, request_id, img_pos):
self.encoder_cache[request_id][img_pos.id] = scatter_image_embeds(
encoder_outputs, img_pos.is_embed
)
def gather_vision_embeds(self):
device = self.input_ids.device
chunks = []
for (
i,
cache_length,
input_length,
request_prefilling,
) in zip(
range(len(self.requests)),
self.cache_lengths,
self.input_lengths,
self.prefilling_mask,
):
if not request_prefilling or self.image_positions[i] is None:
continue
for image_position in self.image_positions[i]:
if image_position is None:
continue
start_pos = image_position.offset
length = image_position.length
if start_pos >= cache_length + input_length:
# No encoder input required at this step
break
if start_pos + length <= cache_length:
# The encoder input has already been processed
continue
start_idx = max(cache_length - start_pos, 0)
end_idx = min(cache_length - start_pos + input_length, length)
assert (
image_position.id in self.encoder_cache[i]
), f"image_id {image_position.id} not in encoder_cache {self.encoder_cache[i]}"
encoder_output = self.encoder_cache[i][image_position.id]
is_embed = image_position.is_embed
if is_embed is not None:
is_embed = is_embed[start_idx:end_idx]
embeds = gather_image_embeds(
encoder_output[start_idx:end_idx],
is_embed=is_embed,
)
if embeds is not None:
chunks.append(embeds)
if end_idx == length:
self.cache_entries_to_free.append((i, image_position.id))
self.image_positions[i][image_position.id] = None
if len(chunks) == 0:
return None
return torch.cat(chunks, dim=0).to(device)
def free_encoder_cache(self):
for i, image_id in self.cache_entries_to_free:
self.encoder_cache[i].pop(image_id, None)
self.cache_entries_to_free = []
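The start_idx / end_idx slicing in gather_vision_embeds above implements chunked prefill over an image span; a small worked example, assuming an image placeholder run of length 10 that starts at prompt offset 6:

start_pos, length = 6, 10            # image placeholder span inside the prompt
cache_length, input_length = 8, 5    # 8 tokens already prefilled, 5 tokens in this chunk
start_idx = max(cache_length - start_pos, 0)                     # 2: rows already consumed
end_idx = min(cache_length - start_pos + input_length, length)   # 7: end of this chunk
# rows 2..6 of the cached encoder output are gathered for this forward pass;
# the cache entry is freed only once end_idx == length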
class FlashVlmCausalLM(FlashCausalLM):
def __init__(
@ -402,6 +708,7 @@ class FlashVlmCausalLM(FlashCausalLM):
batch_class=FlashVlmCausalLMBatch,
revision,
trust_remote_code: bool,
support_chunking: bool = False,
**kwargs,
):
if PREFIX_CACHING:
@ -419,8 +726,7 @@ class FlashVlmCausalLM(FlashCausalLM):
model_id=model_id,
revision=revision,
trust_remote_code=trust_remote_code,
# FIXME: VLM do not work with context chunking yet
support_chunking=False,
support_chunking=support_chunking,
**kwargs,
)
@ -471,9 +777,12 @@ class FlashVlmCausalLM(FlashCausalLM):
bucketing_ctx=None,
)
slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
inputs_embeds = self.get_inputs_embeds(
input_ids=input_ids.to(self.device),
)
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward(
input_ids=_async_h2d_tensor_copy(input_ids),
inputs_embeds=inputs_embeds,
position_ids=_async_h2d_tensor_copy(position_ids),
cu_seqlen_prefill=None,
kv_cache=self.kv_cache,
@ -481,10 +790,7 @@ class FlashVlmCausalLM(FlashCausalLM):
seqlen=trim_seqlen_metadata(seqlen),
hpu_attention_meta=hpu_attention_meta,
lm_head_indices=None,
pixel_values=None,
pixel_attention_mask=None,
image_sizes=None,
image_grid_thw=None,
attention_mask=None,
)
def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch):
@ -546,6 +852,84 @@ class FlashVlmCausalLM(FlashCausalLM):
f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
)
def get_vision_embeds(
self,
pixel_values: torch.Tensor,
pixel_attention_mask: torch.Tensor,
image_sizes: torch.Tensor,
image_grid_thw: torch.Tensor,
):
embeds = self.model.get_vision_embeds(
pixel_values=pixel_values,
pixel_attention_mask=pixel_attention_mask,
image_sizes=image_sizes,
image_grid_thw=image_grid_thw,
)
return embeds
def get_inputs_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: Optional[torch.Tensor] = None,
):
return self.model.get_inputs_embeds(
input_ids=input_ids,
vision_embeds=vision_embeds,
)
def encode_images(self, batch):
if batch.pixel_values is not None:
device = batch.input_ids.device
for request_id, image_id, image_input in batch.pixel_values:
pixel_values = image_input["pixel_values"].to(device)
if "pixel_attention_mask" in image_input:
pixel_attention_mask = image_input["pixel_attention_mask"].to(
device
)
else:
pixel_attention_mask = None
if "image_sizes" in image_input:
image_sizes = image_input["image_sizes"].to(device)
else:
image_sizes = None
if "image_grid_thw" in image_input:
image_grid_thw = image_input["image_grid_thw"]
else:
image_grid_thw = None
encoder_outputs = self.get_vision_embeds(
pixel_values=pixel_values,
pixel_attention_mask=pixel_attention_mask,
image_sizes=image_sizes,
image_grid_thw=image_grid_thw,
)
batch.update_encoder_cache(
encoder_outputs,
request_id,
batch.image_positions[request_id][image_id],
)
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
def set_inputs_embeds(self, batch):
if batch.has_image_inputs:
self.encode_images(batch)
vision_embeds = batch.gather_vision_embeds()
batch.has_image_inputs = False
else:
vision_embeds = None
inputs_embeds = self.get_inputs_embeds(
batch.input_ids, vision_embeds=vision_embeds
)
batch.inputs_embeds = inputs_embeds
def forward(
self,
batch: FlashVlmCausalLMBatch,
@ -593,6 +977,7 @@ class FlashVlmCausalLM(FlashCausalLM):
position_ids = new_position_ids
else:
input_ids = batch.input_ids
inputs_embeds = batch.inputs_embeds
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = self.kv_cache
@ -605,10 +990,25 @@ class FlashVlmCausalLM(FlashCausalLM):
if self.model.config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
if position_ids.dim() == 1 and batch.prefilling:
position_ids = self.model.get_position_ids(
input_ids, batch.image_grid_thw
input_ids.cpu(), batch.image_grid_thw
)
batch.position_ids = position_ids
attention_mask = None
attention_mask_forward = None
if self.model.config.model_type == "gemma3" and cu_seqlen_prefill is not None:
attention_mask = self.model.get_attention_mask(
input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True
)
min_dtype = torch.finfo(self.dtype).min
attention_mask_forward = torch.where(attention_mask, 0, min_dtype).to(
input_ids.device
)
attention_mask = attention_mask.reshape(-1)
if self.model.config.model_type == "llama4":
attention_mask = (input_ids != 0).long()
attention_mask_forward = attention_mask.view(input_lengths.shape[0], -1)
if cu_seqlen_prefill is None and self.max_past() is not None:
# In decode, not prefill, we're actually overwriting the KV-cache
# in a circular buffer mode.
@ -639,7 +1039,7 @@ class FlashVlmCausalLM(FlashCausalLM):
input_lengths=_async_h2d_tensor_copy(input_lengths),
)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
position_ids=_async_h2d_tensor_copy(position_ids),
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
kv_cache=kv_cache,
@ -647,18 +1047,11 @@ class FlashVlmCausalLM(FlashCausalLM):
seqlen=trim_seqlen_metadata(seqlen),
hpu_attention_meta=batch.hpu_attn_meta,
lm_head_indices=_async_h2d_tensor_copy(lm_head_indices),
pixel_values=batch.pixel_values,
pixel_attention_mask=batch.pixel_attention_mask,
image_sizes=batch.image_sizes,
image_grid_thw=batch.image_grid_thw,
attention_mask=attention_mask_forward,
**kwargs,
)
if batch.pixel_values is not None:
batch.pixel_values = None
if batch.pixel_attention_mask is not None:
batch.pixel_attention_mask = None
if batch.image_sizes is not None:
batch.image_sizes = None
if batch.image_grid_thw is not None:
batch.image_grid_thw = None
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
batch.image_grid_thw = None
batch.free_encoder_cache()
return logits, speculative_logits

View File

@ -1,156 +0,0 @@
import re
import torch
import torch.distributed
from transformers import (
PreTrainedTokenizerBase,
)
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
NextTokenChooser,
StoppingCriteria,
)
from text_generation_server.utils.chunks import concat_text_chunks
# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py
# we split individual characters inside special tokens like [START_DNA]
CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])")
# token added to implement a custom sequence tokenization. This token is added at
# corpus cleaning step and removed in pretokenization. The digits are added to increase the chance
# that they do not occur in the corpus. The digits are escaped so that the token does not appear
# literally in the source code in case we ever include it in the training data.
SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E"
def _insert_split_marker(m: re.Match):
"""
Applies split marker based on a regex match of special tokens such as
[START_DNA].
Parameters
----------
n : str
Input text to split
Returns
----------
str - the text with the split token added
"""
start_token, _, sequence, end_token = m.groups()
sequence = re.sub(r"(.)", rf"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}"
def escape_custom_split_sequence(text):
"""
Applies custom splitting to the text for GALILEO's tokenization
Parameters
----------
text : str
Input text to split
Returns
----------
str - the text with the split token added
"""
return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)
# END CREDIT
class GalacticaCausalLMBatch(CausalLMBatch):
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "GalacticaCausalLMBatch":
inputs = []
next_token_choosers = []
stopping_criterias = []
prefix_offsets = []
top_n_tokens = []
read_offsets = []
requests_idx_mapping = {}
# Parse batch
max_truncation = 0
padding_right_offset = 0
max_decode_tokens = 0
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
# Add escape_custom_split_sequence to the CausalLMBatch logic
inputs.append(
escape_custom_split_sequence(concat_text_chunks(r.input_chunks.chunks))
)
next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
max_truncation = max(max_truncation, r.truncate)
max_decode_tokens += stopping_criteria.max_new_tokens
padding_right_offset = max(
padding_right_offset, stopping_criteria.max_new_tokens
)
tokenized_inputs = tokenizer(
inputs,
return_tensors="pt",
padding=True,
return_token_type_ids=False,
truncation=True,
max_length=max_truncation,
).to(device)
for _ in pb.requests:
input_len = tokenized_inputs["input_ids"].shape[1]
prefix_offsets.append(0)
read_offsets.append(input_len)
input_lengths = tokenized_inputs["attention_mask"].sum(1)
max_input_length = input_lengths.max()
input_ids = tokenized_inputs["input_ids"]
# Allocate maximum attention_mask
attention_mask = input_ids.new_zeros(
(pb.size, max_input_length + padding_right_offset)
)
# Copy tokenizer attention_mask into fully allocated attention_mask
attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
max_tokens = len(inputs) * max_input_length + max_decode_tokens
return cls(
batch_id=pb.id,
requests=pb.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=None,
all_input_ids=list(all_input_ids),
input_lengths=input_lengths.tolist(),
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset,
max_tokens=max_tokens,
)

View File

@ -4,14 +4,14 @@ from loguru import logger
from text_generation_server.utils.log import log_master
REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
ATTENTION = os.getenv("ATTENTION", "default")
ATTENTION = os.getenv("ATTENTION", "paged")
# default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
PREFIX_CACHING = os.getenv("PREFIX_CACHING", "0").lower() in {
"1",
"true",
}
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
_expected = {"paged", "default"}
_expected = {"paged"}
assert (
ATTENTION in _expected
), f"Attention is not valid {ATTENTION}, expected {_expected}"

View File

@ -1,882 +0,0 @@
from io import BytesIO
from PIL import Image
import torch
import time
from dataclasses import dataclass
from opentelemetry import trace
from transformers import (
AutoConfig,
AutoProcessor,
AutoTokenizer,
PreTrainedTokenizerBase,
ProcessorMixin,
)
from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.models import Model
from text_generation_server.models.types import (
Batch,
Tokens,
Generation,
GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
import torch.distributed
from text_generation_server.models.custom_modeling.idefics_modeling import (
IdeficsForVisionText2Text,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.quantization import get_loader
tracer = trace.get_tracer(__name__)
@dataclass
class IdeficsCausalLMBatch(Batch):
batch_id: int
requests: List[generate_pb2.Request]
requests_idx_mapping: Dict[int, int]
# Decoder values
input_ids: torch.Tensor
attention_mask: torch.Tensor
position_ids: torch.Tensor
pixel_values: Optional[torch.Tensor]
image_hidden_states: Optional[torch.Tensor]
image_attention_mask: Optional[torch.Tensor]
past_key_values: Optional[List[Tuple]]
# All tokens
all_input_ids: List[torch.Tensor]
# Lengths of all generations present in the batch
input_lengths: List[int]
prefix_offsets: List[int]
read_offsets: List[int]
# Generation helpers
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
# Metadata used for padding
max_input_length: int
padding_right_offset: int
# Maximum number of tokens this batch will grow to
max_tokens: int
# Past metadata
keys_head_dim_last: bool = True
def to_pb(self) -> generate_pb2.CachedBatch:
return generate_pb2.CachedBatch(
id=self.batch_id,
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=self.max_tokens,
)
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "IdeficsCausalLMBatch":
raise NotImplementedError
@classmethod
def from_pb_processor(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
processor: ProcessorMixin, # Hack
config,
dtype: torch.dtype,
device: torch.device,
) -> "IdeficsCausalLMBatch":
inputs = []
next_token_choosers = []
stopping_criterias = []
prefix_offsets = []
read_offsets = []
requests_idx_mapping = {}
# Parse batch
max_truncation = 0
padding_right_offset = 0
max_decode_tokens = 0
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(r.input_chunks.chunks)
next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
max_truncation = max(max_truncation, r.truncate)
max_decode_tokens += stopping_criteria.max_new_tokens
padding_right_offset = max(
padding_right_offset, stopping_criteria.max_new_tokens
)
# TODO Check impact on idefics
prompts = []
for inp in inputs:
# Each input is encoded into a list, where each element of this input list is either a string or a URL
prompt = []
for chunk in inp:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
prompt.append(chunk.text)
elif chunk_type == "image":
image = Image.open(BytesIO(chunk.image.data))
prompt.append(image)
else:
raise RuntimeError(f"Invalid chunk type {chunk_type}")
prompts.append(prompt)
# The processor replaces the call to tokenizer, and
# a/ takes care of fetching images from the URL
# b/ generate the correct input_ids, attention_mask, pixel_values, image_attention_mask to feed to the model
tokenized_inputs = processor(
prompts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=max_truncation,
# TODO Check impact on idefics
# add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token
).to(device)
for _ in pb.requests:
input_len = tokenized_inputs["input_ids"].shape[1]
prefix_offsets.append(
input_len - 5
) # To decode without potential fallbacks errors
read_offsets.append(
input_len
) # To decode without potential fallbacks errors
input_lengths = tokenized_inputs["attention_mask"].sum(1)
max_input_length = input_lengths.max()
input_ids = tokenized_inputs["input_ids"]
pixel_values = tokenized_inputs.get("pixel_values", None)
image_hidden_states = None
# Allocate maximum attention_mask
attention_mask = input_ids.new_zeros(
(pb.size, max_input_length + padding_right_offset)
)
# Copy tokenizer attention_mask into fully allocated attention_mask
attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
# Do the same for image_attention_mask
if pixel_values is None:
image_attention_mask = None
else:
image_attention_mask = input_ids.new_zeros(
(
pb.size,
max_input_length + padding_right_offset,
pixel_values.size(1),
)
)
image_attention_mask[:, :max_input_length, :] = tokenized_inputs[
"image_attention_mask"
]
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].T.split(
1, dim=1
) # It's input_ids but splitted into a tuple of tensors where each tensor is (seq_len, 1) size. It is then transformed into a list
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
return cls(
batch_id=pb.id,
requests=pb.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
pixel_values=pixel_values,
image_hidden_states=image_hidden_states,
image_attention_mask=image_attention_mask,
past_key_values=None,
all_input_ids=list(all_input_ids),
input_lengths=input_lengths.tolist(),
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset,
max_tokens=max_tokens,
)
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]) -> Optional["IdeficsCausalLMBatch"]:
# It deletes requests from the batch. For instance when client lost connection
if len(request_ids) == 0:
raise ValueError("Batch must have at least one request")
if len(request_ids) == len(self):
return self
keep_indices = []
# New values after filtering
requests_idx_mapping = {}
requests = []
input_lengths = []
prefix_offsets = []
read_offsets = []
all_input_ids = []
max_input_length = 0
next_token_choosers = []
stopping_criterias = []
total_remaining_decode_tokens = 0
new_padding_right_offset = 0
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
requests_idx_mapping[request_id] = i
keep_indices.append(idx)
requests.append(self.requests[idx])
prefix_offsets.append(self.prefix_offsets[idx])
read_offsets.append(self.read_offsets[idx])
all_input_ids.append(self.all_input_ids[idx])
request_input_length = self.input_lengths[idx]
input_lengths.append(request_input_length)
max_input_length = max(max_input_length, request_input_length)
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
remaining_decode_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
total_remaining_decode_tokens += remaining_decode_tokens
new_padding_right_offset = max(
new_padding_right_offset, remaining_decode_tokens
)
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
input_ids = self.input_ids[keep_indices]
position_ids = self.position_ids[keep_indices]
self.attention_mask = self.attention_mask[
keep_indices,
-(self.padding_right_offset + max_input_length) : (
self.attention_mask.shape[1] - self.padding_right_offset
)
+ new_padding_right_offset,
]
# Do the same for pixel_values and image_attention_mask
pixel_values = self.pixel_values[keep_indices]
self.image_attention_mask = self.image_attention_mask[
keep_indices,
-(self.padding_right_offset + max_input_length) : (
self.image_attention_mask.shape[1] - self.padding_right_offset
)
+ new_padding_right_offset,
:,
]
if self.image_hidden_states is None:
image_hidden_states = None
else:
image_hidden_states = self.image_hidden_states[keep_indices]
# Ensure that past_key_values tensors can be updated in-place
if type(self.past_key_values[0]) is tuple:
self.past_key_values = [list(layer) for layer in self.past_key_values]
# Update tensors in-place to allow incremental garbage collection
past_kv_length = max_input_length - 1
for layer in self.past_key_values:
past_keys, past_values = layer
if len(past_keys.shape) == 3:
# Force past to be of dim [self_size, num_heads, ...] for easy indexing
past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
if self.keys_head_dim_last:
layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]
else:
layer[0] = past_keys[keep_indices, :, :, -past_kv_length:]
del past_keys
layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
del past_values
max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
self.requests = requests
self.requests_idx_mapping = requests_idx_mapping
self.input_ids = input_ids
self.pixel_values = pixel_values
self.image_hidden_states = image_hidden_states
self.position_ids = position_ids
self.all_input_ids = all_input_ids
self.input_lengths = input_lengths
self.prefix_offsets = prefix_offsets
self.read_offsets = read_offsets
self.next_token_choosers = next_token_choosers
self.stopping_criterias = stopping_criterias
self.max_input_length = max_input_length
self.padding_right_offset = new_padding_right_offset
self.max_tokens = max_tokens
return self
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(
cls, batches: List["IdeficsCausalLMBatch"]
) -> "IdeficsCausalLMBatch":
# Merge multiple batches together into a single batch
# Used for padding
total_batch_size = 0
max_input_length = 0
max_num_images = 0
padding_right_offset = 0
for batch in batches:
total_batch_size += len(batch)
max_input_length = max(max_input_length, batch.max_input_length)
max_num_images = max(max_num_images, batch.pixel_values.size(1))
padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
# Batch attributes
requests = []
requests_idx_mapping = {}
input_lengths = []
prefix_offsets = []
read_offsets = []
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
max_tokens = 0
# Batch tensors
input_ids = None
attention_mask = None
position_ids = None
pixel_values = None
image_hidden_states = None
image_attention_mask = None
past_key_values = []
# Used for slicing correctly inside the tensors
# Equivalent to a cumsum on batch sizes
start_index = 0
for i, batch in enumerate(batches):
requests.extend(batch.requests)
input_lengths.extend(batch.input_lengths)
prefix_offsets.extend(batch.prefix_offsets)
read_offsets.extend(batch.read_offsets)
all_input_ids.extend(batch.all_input_ids)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
else:
# We need to offset the mapping for each batch by the cumulative batch size
for k, v in batch.requests_idx_mapping.items():
requests_idx_mapping[k] = v + start_index
# Slicing end index for this batch
end_index = start_index + len(batch)
# We only concatenate batches that did at least one step
if batch.past_key_values is None:
raise ValueError("only concatenate prefilled batches")
# Create empty tensor
# input_ids is always of shape [batch_size, 1]
# We do not need to pad it
if input_ids is None:
input_ids = batch.input_ids.new_empty((total_batch_size, 1))
# Copy to correct indices
input_ids[start_index:end_index] = batch.input_ids
# Create padded tensor
if attention_mask is None:
attention_mask = batch.attention_mask.new_zeros(
(total_batch_size, max_input_length + padding_right_offset),
)
curr_batch_max_num_images = batch.pixel_values.size(1)
if pixel_values is None:
pixel_values = batch.pixel_values.new_zeros(
(total_batch_size, max_num_images, 3, 224, 224)
)
pixel_values[start_index:end_index, :curr_batch_max_num_images] = (
batch.pixel_values
)
if image_attention_mask is None:
image_attention_mask = batch.image_attention_mask.new_zeros(
(
total_batch_size,
max_input_length + padding_right_offset,
max_num_images,
)
)
# We need to slice the attention mask to remove padding from previous steps
# and to remove unused allocated space
left_offset = max_input_length - batch.max_input_length
batch_left_offset = (
batch.attention_mask.shape[1]
- batch.max_input_length
- batch.padding_right_offset
)
attention_mask[
start_index:end_index,
left_offset:-padding_right_offset,
] = batch.attention_mask[
:,
batch_left_offset : -batch.padding_right_offset,
]
image_attention_mask[
start_index:end_index,
left_offset:-padding_right_offset,
:curr_batch_max_num_images,
] = batch.image_attention_mask[
:, batch_left_offset : -batch.padding_right_offset, :
]
# Create empty tensor
# position_ids is always of shape [batch_size, 1]
if position_ids is None:
position_ids = batch.position_ids.new_empty((total_batch_size, 1))
position_ids[start_index:end_index] = batch.position_ids
# Shenanigans to get dimensions because BLOOM outputs a past with a different shape
# BLOOM Keys: [batch_size * num_heads, head_dim, seq_length]
# BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
# And ensure that we can update tensors in-place
if isinstance(batch.past_key_values[0], tuple):
batch.past_key_values = [
[t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
for layer in batch.past_key_values
]
elif len(batch.past_key_values[0][0].shape) == 3:
for layer in batch.past_key_values:
for k, t in enumerate(layer):
layer[k] = t.view(len(batch), -1, *t.shape[-2:])
# Account for any padding tokens that were added while concatenating
max_tokens += batch.max_tokens + (
max_input_length - batch.max_input_length
) * len(batch)
start_index = end_index
first_past_kvs = batches[0].past_key_values
_, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape
padded_past_values_shape = (
total_batch_size,
num_heads,
max_input_length - 1,
head_dim,
)
if batches[0].keys_head_dim_last:
padded_past_keys_shape = padded_past_values_shape
else:
# seq_length is last for BLOOM
padded_past_keys_shape = (
total_batch_size,
num_heads,
head_dim,
max_input_length - 1,
)
# Iterate over attention layers
# Concatenate past key values layer by layer to allow incremental garbage collection
for j in range(len(first_past_kvs)):
padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape)
start_index = 0
for batch in batches:
past_keys = batch.past_key_values[j][0]
# Clear reference to the original tensor
batch.past_key_values[j][0] = None
# Slicing end index for this batch
end_index = start_index + len(batch)
# We slice the keys to remove the padding from previous batches
past_seq_len = batch.max_input_length - 1
if batch.keys_head_dim_last:
padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
past_keys[:, :, -past_seq_len:, :]
)
else:
# BLOOM case
padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
past_keys[:, :, :, -past_seq_len:]
)
del past_keys
start_index = end_index
padded_past_values = first_past_kvs[j][1].new_zeros(
padded_past_values_shape
)
start_index = 0
for batch in batches:
past_values = batch.past_key_values[j][1]
# Clear reference to the original tensor
batch.past_key_values[j][1] = None
# Slicing end index for this batch
end_index = start_index + len(batch)
# We slice the past values to remove the padding from previous batches
past_seq_len = batch.max_input_length - 1
padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
past_values[:, :, -past_seq_len:, :]
)
del past_values
# Update values
start_index = end_index
past_key_values.append([padded_past_keys, padded_past_values])
return cls(
batch_id=batches[0].batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
pixel_values=pixel_values,
image_hidden_states=image_hidden_states,
image_attention_mask=image_attention_mask,
past_key_values=past_key_values,
all_input_ids=all_input_ids,
input_lengths=input_lengths,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
max_input_length=max_input_length,
padding_right_offset=padding_right_offset,
keys_head_dim_last=batches[0].keys_head_dim_last,
max_tokens=max_tokens,
)
def __len__(self):
return len(self.requests)
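# A minimal illustrative sketch (not from the original sources) of the
# attention-mask window slicing that IdeficsCausalLMBatch.filter performs above.
# Inputs are left-padded and decode tokens grow to the right, so filtering keeps
# the window covering the longest remaining prompt plus the remaining right
# padding. The toy mask below stands in for the real batch field.
import torch

old_max_len, old_right_pad = 5, 4
mask = torch.zeros(3, old_max_len + old_right_pad, dtype=torch.int64)
for row, length in enumerate([5, 3, 2]):  # prompt lengths of the three requests
    mask[row, old_max_len - length : old_max_len] = 1

# Keep requests 1 and 2: their longest prompt is 3 tokens and they have at most
# 2 decode tokens left, so the new right padding is 2.
keep_indices, new_max_len, new_right_pad = [1, 2], 3, 2
filtered = mask[
    keep_indices,
    -(old_right_pad + new_max_len) : mask.shape[1] - old_right_pad + new_right_pad,
]
assert filtered.shape == (2, new_max_len + new_right_pad)
# (end of sketch)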
class IdeficsCausalLM(Model):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.quantize = quantize
self.process_group, rank, world_size = initialize_torch_distributed()
device = torch.device("hpu")
dtype = torch.bfloat16 if dtype is None else dtype
self.device, self.dtype = device, dtype
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
config.speculator = speculator
config.vision_config.quantize = quantize
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
self.processor = AutoProcessor.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
weights_loader = get_loader(
quantize=quantize, model_id=model_id, revision=revision
)
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames,
device=device,
dtype=dtype,
process_group=self.process_group,
weights_loader=weights_loader,
)
model = IdeficsForVisionText2Text(config, weights)
self.config = config
torch.distributed.barrier(group=self.process_group)
super().__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@property
def batch_type(self) -> Type[IdeficsCausalLMBatch]:
return IdeficsCausalLMBatch
def forward(
self,
input_ids,
attention_mask,
position_ids,
pixel_values,
image_hidden_states,
image_attention_mask,
past_key_values: Optional = None,
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
# Model Forward
kwargs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"pixel_values": pixel_values,
"image_hidden_states": image_hidden_states,
"image_attention_mask": image_attention_mask,
"past_key_values": past_key_values,
"use_cache": True,
"return_dict": True,
}
if self.has_position_ids:
kwargs["position_ids"] = position_ids
outputs, speculative_logits = self.model.forward(**kwargs)
return (
outputs.logits,
speculative_logits,
outputs.past_key_values,
outputs.image_hidden_states,
)
@tracer.start_as_current_span("generate_token")
def generate_token(
self, batch: IdeficsCausalLMBatch
) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch], Tuple[int, int]]:
start = time.time_ns()
# slice the attention mask to the correct shape
attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
if batch.image_attention_mask is None:
image_attention_mask = None
else:
if batch.input_ids.size(1) == 1:
# THIS is a hack: when calling idefics.generate for the first time, we need the whole image_attention_mask (size bs x max_seq_len x max_num_images),
# but on subsequent calls we only need the last attention mask along the `max_seq_len` dimension.
# This is due to the nature of IDEFICS: it is an encoder-decoder model, so when decoding, only the currently generated
# token needs to attend to the encoder hidden states (i.e. the vision encoder).
# Also see seq2seq_lm.Seq2SeqLM.generate_token, which has roughly the same logic.
image_attention_mask = batch.image_attention_mask[
:, -(batch.padding_right_offset + 1)
].unsqueeze(1)
else:
image_attention_mask = batch.image_attention_mask[
:, : -batch.padding_right_offset
]
logits, speculative_logits, past, image_hidden_states = self.forward(
input_ids=batch.input_ids,
attention_mask=attention_mask,
position_ids=batch.position_ids,
pixel_values=batch.pixel_values,
image_hidden_states=batch.image_hidden_states,
image_attention_mask=image_attention_mask,
past_key_values=batch.past_key_values,
)
# Hard-coded suppression of the image token logits
logits[:, 32000:32001] = torch.finfo(logits.dtype).min
start_decode = time.time_ns()
# Results
generations: List[Generation] = []
stopped = True
# Zipped iterator
iterator = zip(
batch.requests,
batch.input_lengths,
batch.prefix_offsets,
batch.read_offsets,
logits,
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
)
# For each member of the batch
for i, (
request,
input_length,
prefix_offset,
read_offset,
logits,
next_token_chooser,
stopping_criteria,
all_input_ids,
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(
all_input_ids.view(1, -1), logits[-1:, :]
)
# Append next token to all tokens
all_input_ids = torch.cat([all_input_ids, next_token_id])
new_input_length = input_length + 1
# Generated token
next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze()
next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids[:, 0], prefix_offset, read_offset
)
# Evaluate stopping criteria
stop, reason = stopping_criteria(
next_token_id_squeezed,
next_token_text,
)
if not stop:
stopped = False
# Shard generations
# All generations will be appended in the rust sharded client
if i % self.world_size == self.rank:
if stop:
# Decode generated tokens
output_text, _, _ = self.decode_token(
all_input_ids[:, 0],
prefix_offset=len(all_input_ids)
- stopping_criteria.current_tokens
- 1,
read_offset=len(all_input_ids)
- stopping_criteria.current_tokens,
skip_special_tokens=True,
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, reason, seed
)
else:
generated_text = None
# Prefill
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
# Remove the generated token to keep only prefill tokens, and add NaN for the first prompt token
prefill_logprobs = [float("nan")] + torch.log_softmax(
logits, -1
).gather(1, all_input_ids[1:]).squeeze(1)[
-new_input_length:-1
].tolist()
prefill_token_ids = all_input_ids[-new_input_length:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = Tokens(
prefill_token_ids,
prefill_logprobs,
prefill_texts,
is_special=[],
)
else:
prefill_tokens = None
top_tokens = None
generation = Generation(
request.id,
prefill_tokens,
Tokens(
[next_token_id_squeezed],
[next_token_logprob],
[next_token_text],
[next_token_id_squeezed.item() in self.all_special_ids],
),
generated_text,
top_tokens,
)
generations.append(generation)
# Update values
batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
next_token_id_squeezed.item()
)
batch.input_ids[i, 0] = next_token_id
batch.all_input_ids[i] = all_input_ids
batch.input_lengths[i] = new_input_length
batch.prefix_offsets[i] = prefix_offset
batch.read_offsets[i] = read_offset
batch.max_input_length = max(batch.max_input_length, new_input_length)
# We finished all generations in the batch; there is no next batch
if stopped:
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, None, (forward_ns, decode_ns)
# Slice unused values from prefill
batch.input_ids = batch.input_ids[:, :1]
# Update attention_mask as we added a new token to input_ids
batch.attention_mask[:, -batch.padding_right_offset] = 1
batch.image_attention_mask[:, -batch.padding_right_offset, :] = (
batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :]
)
# Decrease right offset
batch.padding_right_offset -= 1
# Update position_ids
batch.position_ids = batch.position_ids[:, -1:] + 1
# Update past key values
batch.past_key_values = past
batch.image_hidden_states = image_hidden_states
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, batch, (forward_ns, decode_ns)
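# A minimal illustrative sketch (not from the original sources) of the per-step
# bookkeeping done at the end of IdeficsCausalLM.generate_token above: each
# generated token turns one column of right padding in the attention mask into a
# real position, and padding_right_offset shrinks accordingly.
import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0]])  # 3 prompt tokens, 3 right pads
padding_right_offset = 3
for _ in range(3):  # three decode steps
    attention_mask[:, -padding_right_offset] = 1
    padding_right_offset -= 1
assert attention_mask.tolist() == [[1, 1, 1, 1, 1, 1]]
assert padding_right_offset == 0
# (end of sketch)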

View File

@ -1,814 +0,0 @@
import torch
import torch.distributed
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from typing import Optional
from text_generation_server.models.custom_modeling.mamba_modeling import (
MambaConfig,
)
from loguru import logger
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.models.globals import CUDA_GRAPHS, MEM_POOL
import time
from text_generation_server.models.custom_modeling.mamba_modeling import (
MambaModel,
InferenceParams,
)
from text_generation_server.models import Model
from typing import Any, List, Tuple, Type, Dict
from text_generation_server.models.types import (
Batch,
Tokens,
Generation,
GeneratedText,
)
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.tokens import batch_top_tokens, Sampling
from dataclasses import dataclass
from text_generation_server.utils import NextTokenChooser, StoppingCriteria
def new_inference_params(
n_blocks: int,
batch_size: int,
d_inner: int,
d_conv: int,
d_state: int,
seqlen_offset: int,
dtype: torch.dtype,
device: torch.device,
):
max_seqlen = 0
conv_states = torch.zeros(
(
n_blocks,
batch_size,
d_inner,
d_conv,
),
device=device,
dtype=dtype,
)
ssm_states = torch.zeros(
(
n_blocks,
batch_size,
d_inner,
d_state,
),
device=device,
dtype=dtype,
)
inference_params = InferenceParams(
max_seqlen=max_seqlen,
max_batch_size=batch_size,
seqlen_offset=seqlen_offset,
conv_states=conv_states,
ssm_states=ssm_states,
)
return inference_params
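# A minimal shape sketch (not from the original sources) of the per-layer state
# buffers that new_inference_params allocates above. The numbers are made up;
# the real d_inner/d_conv/d_state come from MambaConfig.
import torch

n_blocks, batch_size, d_inner, d_conv, d_state = 2, 4, 16, 4, 8
conv_states = torch.zeros(n_blocks, batch_size, d_inner, d_conv)
ssm_states = torch.zeros(n_blocks, batch_size, d_inner, d_state)
# MambaBatch.concatenate below copies each source batch into a slice along
# dim 1, e.g. conv_states[:, start_index:end_index] = source_batch_conv_states.
assert conv_states.shape == (2, 4, 16, 4)
assert ssm_states.shape == (2, 4, 16, 8)
# (end of sketch)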
@dataclass
class MambaBatch(Batch):
batch_id: int
requests: List[generate_pb2.Request]
requests_idx_mapping: Dict[int, int]
# Decoder values
input_ids: torch.Tensor
# All tokens
all_input_ids: List[torch.Tensor]
# Lengths of all generations present in the batch
input_lengths: List[int]
prefix_offsets: List[int]
read_offsets: List[int]
# Generation helpers
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
top_n_tokens: List[int]
top_n_tokens_tensor: torch.Tensor
# Metadata used for padding
max_input_length: int
padding_right_offset: int
# Maximum number of tokens this batch will grow to
max_tokens: int
# Past metadata
keys_head_dim_last: bool = True
# Inference params
inference_params: Optional[Dict[str, Any]] = None
def to_pb(self) -> generate_pb2.CachedBatch:
return generate_pb2.CachedBatch(
id=self.batch_id,
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=self.max_tokens,
)
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "MambaBatch":
inputs = []
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
prefix_offsets = []
read_offsets = []
requests_idx_mapping = {}
# Parse batch
max_truncation = 0
padding_right_offset = 0
max_decode_tokens = 0
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(concat_text_chunks(r.input_chunks.chunks))
next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
max_truncation = max(max_truncation, r.truncate)
max_decode_tokens += stopping_criteria.max_new_tokens
padding_right_offset = max(
padding_right_offset, stopping_criteria.max_new_tokens
)
tokenized_inputs = tokenizer(
inputs,
return_tensors="pt",
padding=True,
return_token_type_ids=False,
truncation=True,
max_length=max_truncation,
).to(device)
for _ in pb.requests:
input_len = tokenized_inputs["input_ids"].shape[1]
prefix_offsets.append(input_len - 5)
read_offsets.append(input_len)
input_lengths = tokenized_inputs["attention_mask"].sum(1)
max_input_length = input_lengths.max()
input_ids = tokenized_inputs["input_ids"]
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
return cls(
batch_id=pb.id,
requests=pb.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
# past_input_ids=None,
all_input_ids=list(all_input_ids),
input_lengths=input_lengths.tolist(),
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset,
max_tokens=max_tokens,
)
def filter(self, request_ids: List[int]) -> Optional["MambaBatch"]:
if len(request_ids) == 0:
raise ValueError("Batch must have at least one request")
if len(request_ids) == len(self):
return self
keep_indices = []
# New values after filtering
requests_idx_mapping = {}
requests = []
input_lengths = []
prefix_offsets = []
read_offsets = []
all_input_ids = []
max_input_length = 0
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
total_remaining_decode_tokens = 0
new_padding_right_offset = 0
indices = []
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
requests_idx_mapping[request_id] = i
keep_indices.append(idx)
requests.append(self.requests[idx])
prefix_offsets.append(self.prefix_offsets[idx])
read_offsets.append(self.read_offsets[idx])
all_input_ids.append(self.all_input_ids[idx])
request_input_length = self.input_lengths[idx]
input_lengths.append(request_input_length)
max_input_length = max(max_input_length, request_input_length)
indices.append(idx)
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(self.top_n_tokens[idx])
remaining_decode_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
total_remaining_decode_tokens += remaining_decode_tokens
new_padding_right_offset = max(
new_padding_right_offset, remaining_decode_tokens
)
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
input_ids = self.input_ids[keep_indices]
top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
self.requests = requests
self.requests_idx_mapping = requests_idx_mapping
self.input_ids = input_ids
self.all_input_ids = all_input_ids
self.input_lengths = input_lengths
self.prefix_offsets = prefix_offsets
self.read_offsets = read_offsets
self.next_token_choosers = next_token_choosers
self.stopping_criterias = stopping_criterias
self.top_n_tokens = top_n_tokens
self.top_n_tokens_tensor = top_n_tokens_tensor
self.max_input_length = max_input_length
self.padding_right_offset = new_padding_right_offset
self.max_tokens = max_tokens
# TODO: kept it simple by only updating the state; updating the other CPU-side values may also be necessary.
self.inference_params.conv_states = self.inference_params.conv_states[
:, indices
]
self.inference_params.ssm_states = self.inference_params.ssm_states[:, indices]
return self
@classmethod
def concatenate(cls, batches: List["MambaBatch"]) -> "MambaBatch":
# Used for padding
total_batch_size = 0
max_input_length = 0
padding_right_offset = 0
for batch in batches:
total_batch_size += len(batch)
max_input_length = max(max_input_length, batch.max_input_length)
padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
# Batch attributes
requests = []
requests_idx_mapping = {}
input_lengths = []
prefix_offsets = []
read_offsets = []
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
max_tokens = 0
seqlen_offset = 0
(n_blocks, _, d_inner, d_conv) = batches[0].inference_params.conv_states.shape
(_, _, _, d_state) = batches[0].inference_params.ssm_states.shape
dtype = batches[0].inference_params.conv_states.dtype
device = batches[0].inference_params.conv_states.device
inference_params = new_inference_params(
n_blocks=n_blocks,
batch_size=total_batch_size,
d_state=d_state,
d_conv=d_conv,
d_inner=d_inner,
seqlen_offset=seqlen_offset,
device=device,
dtype=dtype,
)
# Batch tensors
input_ids = None
top_n_tokens_tensor = None
# Used for slicing correctly inside the tensors
# Equivalent to a cumsum on batch sizes
start_index = 0
for i, batch in enumerate(batches):
requests.extend(batch.requests)
input_lengths.extend(batch.input_lengths)
prefix_offsets.extend(batch.prefix_offsets)
read_offsets.extend(batch.read_offsets)
all_input_ids.extend(batch.all_input_ids)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
top_n_tokens.extend(batch.top_n_tokens)
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
else:
# We need to offset the mapping for each batch by the cumulative batch size
for k, v in batch.requests_idx_mapping.items():
requests_idx_mapping[k] = v + start_index
# Slicing end index for this batch
end_index = start_index + len(batch)
# Create empty tensor
# input_ids is always of shape [batch_size, 1]
# We do not need to pad it
if input_ids is None:
input_ids = batch.input_ids.new_empty((total_batch_size, 1))
# Copy to correct indices
input_ids[start_index:end_index] = batch.input_ids
if top_n_tokens_tensor is None:
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size,
)
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
# Account for any padding tokens that were added while concatenating
max_tokens += batch.max_tokens + (
max_input_length - batch.max_input_length
) * len(batch)
inference_params.max_seqlen = max(
inference_params.max_seqlen, batch.inference_params.max_seqlen
)
assert batch.inference_params.seqlen_offset != 0, "Invalid seqlen offset"
inference_params.seqlen_offset = max(
inference_params.seqlen_offset, batch.inference_params.seqlen_offset
)
inference_params.conv_states[:, start_index:end_index] = (
batch.inference_params.conv_states
)
inference_params.ssm_states[:, start_index:end_index] = (
batch.inference_params.ssm_states
)
start_index = end_index
return cls(
batch_id=batches[0].batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
all_input_ids=all_input_ids,
input_lengths=input_lengths,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length,
padding_right_offset=padding_right_offset,
keys_head_dim_last=batches[0].keys_head_dim_last,
max_tokens=max_tokens,
inference_params=inference_params,
)
def __len__(self):
return len(self.requests)
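# A minimal illustrative sketch (not from the original sources) of how
# MambaBatch.filter narrows the recurrent state above: conv/ssm states are
# indexed along the batch dimension (dim 1) with the positions of the requests
# that remain in the batch.
import torch

ssm_states = torch.randn(2, 4, 16, 8)  # [n_blocks, batch, d_inner, d_state]
kept = [0, 3]  # requests that survive the filter
assert ssm_states[:, kept].shape == (2, 2, 16, 8)
# (end of sketch)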
class Mamba(Model):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.quantize = quantize
self.process_group, _rank, world_size = initialize_torch_distributed()
if world_size > 1:
raise RuntimeError("Mamba does not support Tensor Parallelism (TP)")
self.cuda_graphs = {}
if torch.cuda.is_available():
device = torch.device("cuda")
# bf16 is important: with f16, accumulations in the matmul cause
# differences while the server is under load.
# This is detectable by the integration load test.
dtype = torch.bfloat16 if dtype is None else dtype
else:
if quantize:
raise ValueError("quantization is not available on CPU")
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = AutoTokenizer.from_pretrained(
"EleutherAI/gpt-neox-20b",
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = MambaConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
tokenizer.bos_token_id = config.bos_token_id
tokenizer.eos_token_id = config.eos_token_id
tokenizer.pad_token = tokenizer.eos_token
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
weights_loader = get_loader(
quantize=quantize, model_id=model_id, revision=revision
)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames,
device,
dtype,
process_group=self.process_group,
weights_loader=weights_loader,
)
model = MambaModel(config, weights)
torch.distributed.barrier(group=self.process_group)
super(Mamba, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
)
@property
def batch_type(self) -> Type[MambaBatch]:
return MambaBatch
def warmup(self, batch) -> Optional[int]:
# TODO: implement warmup for Mamba if needed
if CUDA_GRAPHS:
if self.speculate is None or self.speculate == 0:
try:
logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
# Warmup cuda graphs
for bs in CUDA_GRAPHS:
self.cuda_graph_warmup(bs)
except Exception:
logger.exception("Decode cuda graph warmup failed")
else:
logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
return None
def cuda_graph_warmup(self, batch_size: int):
input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device)
n_blocks = len(self.model.blocks)
d_state = self.model.config.d_state
d_conv = self.model.config.d_conv
# d_inner already includes the expand multiplier
d_inner = self.model.config.d_inner
# A non-zero seqlen_offset is needed to go through the state update mechanism
seqlen_offset = 1
inference_params = new_inference_params(
n_blocks=n_blocks,
batch_size=batch_size,
d_state=d_state,
d_conv=d_conv,
d_inner=d_inner,
seqlen_offset=seqlen_offset,
device=self.device,
dtype=self.dtype,
)
graph = torch.cuda.CUDAGraph()
torch.cuda.synchronize()
# Run once outside to warmup
self.model.forward(input_ids=input_ids, inference_params=inference_params)
torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL):
logits, speculative_logits = self.model.forward(
input_ids=input_ids, inference_params=inference_params
)
torch.cuda.synchronize()
graph_dict = {
"input_ids": input_ids,
"inference_params": inference_params,
"graph": graph,
"logits": logits,
"speculative_logits": speculative_logits,
}
self.cuda_graphs[batch_size] = graph_dict
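# A minimal illustrative sketch (not from the original sources) of the
# capture/replay pattern used by cuda_graph_warmup above and Mamba.forward
# below: run once eagerly, capture a graph over static tensors, then copy fresh
# data into those tensors and replay. Requires a CUDA device, hence the guard.
import torch

if torch.cuda.is_available():
    static_x = torch.zeros(8, 16, device="cuda")
    weight = torch.randn(16, 16, device="cuda")

    def step(x):
        return x @ weight

    step(static_x)  # eager warmup run
    torch.cuda.synchronize()
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out = step(static_x)
    static_x.copy_(torch.randn(8, 16, device="cuda"))  # new inputs go into the static tensor
    graph.replay()
    torch.cuda.synchronize()
    assert static_out.shape == (8, 16)
# (end of sketch)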
def tunableop_warmup(self, batch_size: int, seqlen: int):
input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device)
n_blocks = len(self.model.blocks)
d_state = self.model.config.d_state
d_conv = self.model.config.d_conv
# d_inner already includes the expand multiplier
d_inner = self.model.config.d_inner
# A non-zero seqlen_offset is needed to go through the state update mechanism
seqlen_offset = 1
inference_params = new_inference_params(
n_blocks=n_blocks,
batch_size=seqlen,
d_state=d_state,
d_conv=d_conv,
d_inner=d_inner,
seqlen_offset=seqlen_offset,
device=self.device,
dtype=self.dtype,
)
self.model.forward(input_ids=input_ids, inference_params=inference_params)
def forward(
self, input_ids: torch.Tensor, inference_params: Any
) -> Tuple[torch.Tensor, torch.Tensor]:
bs = input_ids.shape[0]
padded_bs = bs
if bs == 3:
padded_bs = 4
elif 3 < bs <= 8:
padded_bs = 8
elif bs > 8:
padded_bs = (bs + 7) // 8 * 8
# Try to find an associated cuda graph
cuda_graph = self.cuda_graphs.get(padded_bs, None)
is_prefill = inference_params is None or inference_params.seqlen_offset == 0
if is_prefill or cuda_graph is None:
return self.model(
input_ids,
inference_params=inference_params,
)
# Copy inputs to the static inputs of the cuda graph
# Static inputs are potentially padded
cuda_graph["input_ids"][:bs] = input_ids
cuda_graph["inference_params"].conv_states[
:, :bs
] = inference_params.conv_states
cuda_graph["inference_params"].ssm_states[:, :bs] = inference_params.ssm_states
# Replay the graph
cuda_graph["graph"].replay()
inference_params.conv_states.copy_(
cuda_graph["inference_params"].conv_states[:, :bs]
)
inference_params.ssm_states.copy_(
cuda_graph["inference_params"].ssm_states[:, :bs]
)
# Slice output to the correct shape
speculative_logits = (
cuda_graph["speculative_logits"][:bs]
if cuda_graph["speculative_logits"] is not None
else None
)
logits = cuda_graph["logits"][:bs]
return logits, speculative_logits
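# The padding rule from Mamba.forward above, pulled out as a standalone helper
# purely for illustration (not from the original sources): batch sizes are
# rounded up so that a pre-captured cuda graph of matching size can be reused.
def padded_batch_size(bs: int) -> int:
    if bs == 3:
        return 4
    if 3 < bs <= 8:
        return 8
    if bs > 8:
        return (bs + 7) // 8 * 8
    return bs  # bs in {1, 2} is used as-is

assert [padded_batch_size(b) for b in (1, 2, 3, 5, 8, 9, 17)] == [1, 2, 4, 8, 8, 16, 24]
# (end of sketch)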
def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
start = time.time_ns()
input_ids = (
batch.input_ids
) # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids
batch_size, max_seqlen = input_ids.shape
# Inference params
if batch.inference_params is None:
# 0 is important here
seqlen_offset = 0
n_blocks = len(self.model.blocks)
d_state = self.model.config.d_state
d_conv = self.model.config.d_conv
d_inner = self.model.config.d_inner
inference_params = new_inference_params(
n_blocks=n_blocks,
batch_size=batch_size,
d_state=d_state,
d_conv=d_conv,
d_inner=d_inner,
seqlen_offset=seqlen_offset,
device=self.device,
dtype=self.dtype,
)
batch.inference_params = inference_params
# Forward pass
logits, speculative_logits = self.forward(
input_ids, inference_params=batch.inference_params
)
# batch.inference_params = new_inference_params
# Results
generations: List[Generation] = []
stopped = True
# Speculation is not active for causal models
accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens,
batch.top_n_tokens_tensor,
torch.log_softmax(logits[:, -1], -1),
accepted_ids,
)
start_decode = time.time_ns()
# Zipped iterator
iterator = zip(
batch.requests,
batch.input_lengths,
batch.prefix_offsets,
batch.read_offsets,
logits,
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
batch.top_n_tokens,
batch_top_token_ids,
batch_top_token_logprobs,
)
# For each member of the batch
for i, (
request,
input_length,
prefix_offset,
read_offset,
logits,
next_token_chooser,
stopping_criteria,
all_input_ids,
top_n_tokens,
top_token_ids,
top_token_logprobs,
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(
all_input_ids.view(1, -1), logits[-1:, :]
)
# Append next token to all tokens
all_input_ids = torch.cat([all_input_ids, next_token_id])
new_input_length = input_length + 1
# Generated token
next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze()
next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids[:, 0], prefix_offset, read_offset
)
# Evaluate stopping criteria
stop, reason = stopping_criteria(
next_token_id_squeezed,
next_token_text,
)
if not stop:
stopped = False
# Shard generations
# All generations will be appended in the rust sharded client
if i % self.world_size == self.rank:
if stop:
# Decode generated tokens
output_text, _, _ = self.decode_token(
all_input_ids[:, 0],
prefix_offset=len(all_input_ids)
- stopping_criteria.current_tokens
- 1,
read_offset=len(all_input_ids)
- stopping_criteria.current_tokens,
skip_special_tokens=True,
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, reason, seed
)
else:
generated_text = None
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
# Remove the generated token to keep only prefill tokens, and add NaN for the first prompt token
prefill_logprobs = [float("nan")] + torch.log_softmax(
logits, -1
).gather(1, all_input_ids[1:]).squeeze(1)[
-new_input_length:-1
].tolist()
prefill_token_ids = all_input_ids[-new_input_length:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = Tokens(
prefill_token_ids,
prefill_logprobs,
prefill_texts,
is_special=[],
)
else:
prefill_tokens = None
if top_n_tokens > 0:
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = Tokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
special_toptokens,
)
else:
top_tokens = None
generation = Generation(
request.id,
prefill_tokens,
Tokens(
[next_token_id_squeezed],
[next_token_logprob],
[next_token_text],
[next_token_id_squeezed.item() in self.all_special_ids],
),
generated_text,
top_tokens,
)
generations.append(generation)
# Update values
batch.next_token_choosers[i] = batch.next_token_choosers[
i
].advance_grammar(next_token_id_squeezed.item())
batch.input_ids[i, 0] = next_token_id
batch.all_input_ids[i] = all_input_ids
batch.input_lengths[i] = new_input_length
batch.prefix_offsets[i] = prefix_offset
batch.read_offsets[i] = read_offset
batch.max_input_length = max(batch.max_input_length, new_input_length)
# We finished all generations in the batch; there is no next batch
if stopped:
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, None, (forward_ns, decode_ns)
# Slice unused values from prefill
batch.input_ids = batch.input_ids[:, :1]
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, batch, (forward_ns, decode_ns)
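# A minimal illustrative sketch (not from the original sources) of the
# "Slice unused values from prefill" step in generate_token above: after the
# first forward pass, the freshly generated token is written into column 0 and
# only that column is kept, so subsequent decode steps feed one token per request.
import torch

input_ids = torch.tensor([[5, 6, 7], [8, 9, 10]])  # prompt tokens right after prefill
next_tokens = torch.tensor([11, 12])
input_ids[:, 0] = next_tokens  # generate_token writes the new token into column 0
input_ids = input_ids[:, :1]   # slice unused values from prefill
assert input_ids.tolist() == [[11], [12]]
# (end of sketch)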

View File

@ -46,10 +46,17 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
aspect_ratio_mask: Optional[torch.Tensor] = None
cross_attention_states: Optional[torch.Tensor] = None
def prepare_for_prefill(
self, max_padded_input_len, max_padded_bs, max_total_tokens
):
super(FlashVlmCausalLMBatch, self).prepare_for_prefill(
max_padded_input_len, max_padded_bs, max_total_tokens
)
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches):
batch = super().concatenate(batches)
def concatenate(cls, batches, padded_total_bs: int = 0):
batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches, padded_total_bs)
batch.pixel_values = None
batch.pixel_attention_mask = None
@ -73,7 +80,7 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]):
assert self.image_indices is not None
batch = super().filter(request_ids)
batch = super(FlashVlmCausalLMBatch, self).filter(request_ids)
assert self.image_indices is not None
indices = []
for i, request_id in enumerate(request_ids):
@ -99,6 +106,7 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
]
else:
batch.cross_attention_states = None
batch.pixel_values = None
return batch
@classmethod
@ -228,6 +236,10 @@ def generate_cross_attention_states(
class FlashMllamaCausalLM(FlashVlmCausalLM):
def set_inputs_embeds(self, batch):
# Set the input embeddings to None, as we are using the input_ids for the model
batch.inputs_embeds = None
def warmup_decode(
self, batch_size: int, block_num: int, batch: FlashMllamaCausalLMBatch
):

View File

@ -1,71 +0,0 @@
from io import BytesIO
from PIL import Image
import torch
import torch.distributed
from opentelemetry import trace
from typing import Iterable
from text_generation_server.models.flash_vlm_causal_lm import (
FlashVlmCausalLMBatch,
image_text_replacement,
)
from text_generation_server.pb.generate_pb2 import Request
tracer = trace.get_tracer(__name__)
class PaliGemmaBatch(FlashVlmCausalLMBatch):
@classmethod
def batch_tokenized_inputs(
cls, requests: Iterable[Request], tokenizer, processor, config
):
batch_inputs = []
image_inputs = []
max_truncation = 0
for r in requests:
full_text = ""
image_id = 0
for chunk in r.input_chunks.chunks:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
full_text += "<bos>" + chunk.text + "\n"
elif chunk_type == "image":
image = Image.open(BytesIO(chunk.image.data))
# TODO: should do_convert_RGB be on by default?
image = image.convert("RGB")
image_input = processor.image_processor(image, return_tensors="pt")
full_text += image_text_replacement(
processor, image_input, config, image_id
)
image_inputs.append(image_input)
else:
raise RuntimeError(f"Invalid chunk type {chunk_type}")
batch_inputs.append(full_text)
max_truncation = max(max_truncation, r.truncate)
batch_tokenized_inputs = tokenizer(
batch_inputs,
truncation=True,
max_length=max_truncation,
add_special_tokens=False,
)["input_ids"]
if image_inputs:
image_input = image_inputs[0]
new_image_inputs = {
"pixel_values": torch.cat(
[img["pixel_values"] for img in image_inputs], dim=0
),
}
if "pixel_attention_mask" in image_input:
new_image_inputs["pixel_attention_mask"] = torch.cat(
[img["pixel_attention_mask"] for img in image_inputs], dim=0
)
if "image_sizes" in image_input:
new_image_inputs["image_sizes"] = torch.cat(
[img["image_sizes"] for img in image_inputs], dim=0
)
image_inputs = new_image_inputs
else:
image_inputs = None
return batch_tokenized_inputs, image_inputs
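# A minimal illustrative sketch (not from the original sources) of the merge
# performed above: per-image processor outputs are concatenated along the batch
# dimension into a single image_inputs dict. The shapes are made up.
import torch

image_inputs = [
    {"pixel_values": torch.zeros(1, 3, 224, 224)},
    {"pixel_values": torch.zeros(1, 3, 224, 224)},
]
merged = {
    "pixel_values": torch.cat([img["pixel_values"] for img in image_inputs], dim=0)
}
assert merged["pixel_values"].shape == (2, 3, 224, 224)
# (end of sketch)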

View File

@ -1,47 +0,0 @@
import torch
from dataclasses import dataclass
from typing import List, Optional, Type
from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
@dataclass
class StarCoderCausalLMBatch(CausalLMBatch):
past_key_values: Optional[List[torch.Tensor]]
def detach_kv_cache(self):
past_keys = []
past_values = []
last_dim = int(self.past_key_values[0].size(dim=-1) / 2)
for key_value in self.past_key_values:
past_keys.append(key_value.split((last_dim, last_dim), dim=-1)[0])
past_values.append(key_value.split((last_dim, last_dim), dim=-1)[1])
del self.past_key_values
return past_keys, past_values
def attach_kv_cache(self, past_keys, past_values):
self.past_key_values = [
torch.cat((key, value), dim=-1)
for key, value in zip(past_keys, past_values)
]
class StarCoder(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
):
super(StarCoder, self).__init__(
model_id=model_id,
revision=revision,
dtype=dtype,
)
@property
def batch_type(self) -> Type[CausalLMBatch]:
return StarCoderCausalLMBatch
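# A minimal illustrative sketch (not from the original sources) of the fused
# KV-cache layout handled by StarCoderCausalLMBatch above: keys and values share
# one tensor along the last dimension, so detaching splits it in half and
# attaching concatenates it back.
import torch

head_dim = 4
fused = torch.randn(2, 10, 2 * head_dim)  # [batch, seq, key half | value half]
keys, values = fused.split((head_dim, head_dim), dim=-1)
assert torch.equal(torch.cat((keys, values), dim=-1), fused)
# (end of sketch)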

View File

@ -7,13 +7,5 @@ if [[ "$*" == *"--sharded true"* ]]; then
echo 'setting PT_HPU_ENABLE_LAZY_COLLECTIVES=1 for sharding'
export PT_HPU_ENABLE_LAZY_COLLECTIVES=1
fi
# Check if ATTENTION environment variable is set to paged
if [[ "$ATTENTION" == "paged" ]]; then
# Check if Llama-4 is in the command line arguments
if [[ "$*" == *"Llama-4"* || "$*" == *"Qwen3"* ]]; then
echo 'ATTENTION=paged and Llama-4 or Qwen3 detected'
pip install git+https://github.com/huggingface/transformers.git@29338949
fi
fi
text-generation-launcher $@

View File

@ -7,7 +7,8 @@ from typing import List, Optional, Tuple
import torch
from loguru import logger
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from optimum.neuron.configuration_utils import NeuronConfig
from transformers.generation import GenerationConfig
from optimum.neuron import NeuronModelForCausalLM
@ -175,6 +176,12 @@ class Slot:
self._generation_config.top_p = request.parameters.top_p
if request.parameters.typical_p != 0:
self._generation_config.typical_p = request.parameters.typical_p
else:
# Set the sampling parameters to emulate greedy decoding when using on-device sampling
self._generation_config.temperature = 1.0
self._generation_config.top_k = 1
self._generation_config.top_p = 1.0
self._generation_config.typical_p = 1.0
if request.parameters.repetition_penalty != 0:
self._generation_config.repetition_penalty = (
request.parameters.repetition_penalty
@ -211,19 +218,11 @@ class Slot:
self._mask = attention_mask.clone()
self._selector = selector
def pause(self, reset_on_pause: bool):
def pause(self):
"""Mark the current slot as paused for generation.
Note that the KV cache for this slot will still be filled.
"""
if reset_on_pause:
# Drop the last token as it will be added back when resuming the slot
self._generated_tokens -= 1
# Since generated tokens are now part of the prefill, we need to reevaluate
# max_new_tokens for the next generation
self._generation_config.max_new_tokens = (
self._max_new_tokens - self._generated_tokens
)
self._state = Slot.State.PAUSE
def resume(self):
@ -340,16 +339,27 @@ class NeuronGenerator(Generator):
tokenizer: PreTrainedTokenizerBase,
):
self.model = model
self.rebuild_cache_on_prefill = not self.model.continuous_batching
if not isinstance(self.model, NeuronModelForCausalLM):
raise ValueError("The model must be a NeuronModelForCausalLM.")
if not model.neuron_config.continuous_batching:
raise ValueError(
"The neuron model must be compiled with continuous_batching=True."
)
# Specify padding and truncation options for decoder-only architecture
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"
tokenizer.truncation_side = "left"
self.tokenizer = tokenizer
self.special_tokens = self.tokenizer.all_special_ids
self.slots = [Slot(i, tokenizer) for i in range(self.model.batch_size)]
self.slots = [
Slot(i, tokenizer) for i in range(self.model.neuron_config.batch_size)
]
self.batch_id = 0
@property
def on_device_sampling(self) -> bool:
return getattr(self.model.neuron_config, "on_device_sampling", False)
@property
def info(self) -> InfoResponse:
"""Returns the expected InfoResponse."""
@ -371,14 +381,22 @@ class NeuronGenerator(Generator):
The maximum number of tokens the model supports.
"""
# Just check that the warmup request parameters match the model capacity
batch_size = self.model.batch_size
batch_size = self.model.neuron_config.batch_size
if len(batch.requests) > batch_size:
raise ValueError(
f"Inconsistent batch_size configuration: Please make sure the batch_size in the compiled model (currently {batch_size}) matches the batch_size passed to TGI. The compiled model batch_size is usually in the neuron section of the model config.json file. You may also have passed it into optimum-cli during the compilation process. The batch size for TGI is usually set in the environment as MAX_BATCH_SIZE."
f"Inconsistent batch_size configuration: Please make sure the batch_size in the compiled model (currently {batch_size}) matches the batch_size passed to TGI. The compiled model.neuron_config.batch_size is usually in the neuron section of the model config.json file. You may also have passed it into optimum-cli during the compilation process. The batch size for TGI is usually set in the environment as MAX_BATCH_SIZE."
)
self.prefill(batch)
self.clear()
return self.model.batch_size * self.model.max_length
return (
self.model.neuron_config.batch_size
* self.model.neuron_config.sequence_length
)
def max_prefill_length(self) -> int:
if hasattr(self.model.neuron_config, "max_context_length"):
return self.model.neuron_config.max_context_length
return self.model.neuron_config.sequence_length
def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
"""Prefill new requests.
@ -398,7 +416,7 @@ class NeuronGenerator(Generator):
if len(empty_slots) < len(batch.requests):
raise ValueError(
f"Cannot prefill {len(batch.requests)} new request(s) with only {len(empty_slots)} empty slots."
f" Please align max_batch_size with the static batch size: {self.model.batch_size}."
f" Please align max_batch_size with the static batch size: {self.model.neuron_config.batch_size}."
)
# Assign each request to an empty slot
logger.debug(
@ -412,14 +430,8 @@ class NeuronGenerator(Generator):
logger.debug(
f"Request {slot.request_id} assigned to slot {slot.id} with and max_new_tokens {slot.max_new_tokens}"
)
if self.rebuild_cache_on_prefill:
# We will clear pending slots and prefill all slots
prefill_slots = self.slots
seq_ids = None
else:
# We only need to pass inputs for the new requests
prefill_slots = new_slots
seq_ids = torch.tensor([slot.id for slot in prefill_slots])
prefill_slots = new_slots
seq_ids = torch.tensor([slot.id for slot in prefill_slots])
# Reconstruct the full inputs (without padding) as seen by the model.
# This comprises:
# - the inputs for new requests,
@ -431,8 +443,10 @@ class NeuronGenerator(Generator):
inputs.append(slot.cached_text)
# Apply truncation, making sure we fit into static dimensions
if slot.truncate == 0:
max_length = self.model.max_length
elif slot.truncate > max_length and slot.truncate < self.model.max_length:
max_length = self.max_prefill_length()
elif (
slot.truncate > max_length and slot.truncate < self.max_prefill_length()
):
max_length = slot.truncate
# Tokenize with padding and truncation
padded_inputs = self.tokenizer(
@ -444,13 +458,12 @@ class NeuronGenerator(Generator):
)
input_ids = padded_inputs.input_ids
attention_mask = padded_inputs.attention_mask
sampling_params = (
torch.zeros(input_ids.shape[0], 3) if self.on_device_sampling else None
)
# Pause previously active slots during generation
next_tokens = []
for slot in active_slots:
slot.pause(reset_on_pause=self.rebuild_cache_on_prefill)
if self.rebuild_cache_on_prefill:
# The slot will be reset, so we need to store its next token
next_tokens.append(slot.next_token)
slot.pause()
# Each slot must be reset with the padded inputs and masks
for i, slot in enumerate(prefill_slots):
if slot.state != slot.state.EMPTY:
@ -464,29 +477,33 @@ class NeuronGenerator(Generator):
slot_input_ids,
slot.generation_config,
self.model,
self.model.max_length,
self.model.neuron_config.sequence_length,
tokenizer=self.tokenizer,
seed=slot.seed,
)
slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64)
slot_attention_mask = attention_mask[i]
slot.reset(slot_input_ids, slot_attention_mask, selector)
if sampling_params is not None:
sampling_params[i, 0] = slot.generation_config.top_k
sampling_params[i, 1] = slot.generation_config.top_p
sampling_params[i, 2] = slot.generation_config.temperature
# Note: when rebuilding cache on prefill, the new tokens on paused slots will be ignored,
# as they have already been generated and sent back in the last decode.
model_inputs = self.model.prepare_inputs_for_prefill(
input_ids, attention_mask, seq_ids
input_ids,
attention_mask=attention_mask,
seq_ids=seq_ids,
sampling_params=sampling_params,
)
logits = self.model(**model_inputs)[0]
tokens_or_logits = self.model(**model_inputs)[0]
generation, next_batch = self._generate_token(
prefill_slots, self.batch_id, logits, input_ids
prefill_slots, self.batch_id, tokens_or_logits, input_ids
)
self.batch_id += 1
# Reactivate previously active slots for the next decode
for i, slot in enumerate(active_slots):
slot.resume()
if self.rebuild_cache_on_prefill:
# Append back the next token
slot.append(next_tokens[i])
logger.debug("Model ready for decoding")
if next_batch is not None:
logger.debug(
@ -530,12 +547,8 @@ class NeuronGenerator(Generator):
raise ValueError(
"Unable to decode tokens for non-prefilled batches (probably due to a previous failure)"
)
if self.model.continuous_batching:
decode_slots = active_slots
seq_ids = torch.tensor([slot.id for slot in decode_slots])
else:
decode_slots = self.slots
seq_ids = None
decode_slots = active_slots
seq_ids = torch.tensor([slot.id for slot in decode_slots])
# Reconstruct input_ids and attention_mask from decode slots
n_slots = len(decode_slots)
input_ids = torch.full(
@ -545,22 +558,32 @@ class NeuronGenerator(Generator):
for slot in decode_slots:
max_length = max(max_length, slot.attention_mask.size(-1))
attention_mask = torch.zeros([n_slots, max_length], dtype=torch.int64)
sampling_params = torch.zeros(n_slots, 3) if self.on_device_sampling else None
for i, slot in enumerate(decode_slots):
if slot.state != Slot.State.EMPTY:
# input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached)
input_ids[i, 0] = slot.next_token
attention_mask[i, : slot.attention_mask.size(-1)] = slot.attention_mask
if sampling_params is not None:
sampling_params[i, 0] = slot.generation_config.top_k
sampling_params[i, 1] = slot.generation_config.top_p
sampling_params[i, 2] = slot.generation_config.temperature
model_inputs = self.model.prepare_inputs_for_decode(
input_ids, attention_mask, seq_ids
input_ids,
attention_mask=attention_mask,
seq_ids=seq_ids,
sampling_params=sampling_params,
)
tokens_or_logits = self.model(**model_inputs)[0]
return self._generate_token(
decode_slots, next_batch_id, tokens_or_logits, input_ids
)
logits = self.model(**model_inputs)[0]
return self._generate_token(decode_slots, next_batch_id, logits, input_ids)
def _generate_token(
self,
slots: List[Slot],
next_batch_id: int,
logits: torch.Tensor,
tokens_or_logits: torch.Tensor,
input_ids: torch.LongTensor,
) -> Tuple[List[Generation], CachedBatch]:
generations = []
@ -569,9 +592,12 @@ class NeuronGenerator(Generator):
if slot.state != Slot.State.READY:
continue
request_id = slot.request_id
next_token_logits = logits[i : i + 1, -1, :]
slot_input_ids = input_ids[i : i + 1, :]
next_token = slot.select(slot_input_ids, next_token_logits)
if self.on_device_sampling:
next_token = tokens_or_logits[i]
else:
next_token_logits = tokens_or_logits[i : i + 1, -1, :]
next_token = slot.select(slot_input_ids, next_token_logits)
next_token_text = slot.append(next_token)
generated_text = None
finish_reason = None
@ -622,7 +648,7 @@ class NeuronGenerator(Generator):
def _cached_batch(self, batch_id: int, request_ids: List):
size = len(request_ids)
max_tokens = size * self.model.max_length
max_tokens = size * self.model.neuron_config.sequence_length
return CachedBatch(
id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens
)
@ -671,8 +697,16 @@ class NeuronGenerator(Generator):
Returns:
A NeuronGenerator.
"""
config = AutoConfig.from_pretrained(model_id)
neuron_config = getattr(config, "neuron", None)
try:
neuron_config = NeuronConfig.from_pretrained(model_id, revision=revision)
except Exception as e:
logger.debug(
"NeuronConfig.from_pretrained failed for model %s, revision %s: %s",
model_id,
revision,
e,
)
neuron_config = None
start = time.time()
if neuron_config is None:
export_kwargs = get_export_kwargs_from_env()

View File

@ -6,10 +6,12 @@ from typing import Optional
from huggingface_hub import snapshot_download
from huggingface_hub.constants import HF_HUB_CACHE
from loguru import logger
from transformers import AutoConfig
from optimum.neuron import NeuronModelForCausalLM
from optimum.neuron.utils import get_hub_cached_entries
from optimum.neuron.cache import get_hub_cached_entries
from optimum.neuron.configuration_utils import NeuronConfig
from .tgi_env import check_env_and_neuron_config_compatibility
def get_export_kwargs_from_env():
@ -24,7 +26,6 @@ def get_export_kwargs_from_env():
num_cores = int(num_cores)
auto_cast_type = os.environ.get("HF_AUTO_CAST_TYPE", None)
return {
"task": "text-generation",
"batch_size": batch_size,
"sequence_length": sequence_length,
"num_cores": num_cores,
@ -32,20 +33,15 @@ def get_export_kwargs_from_env():
}
def is_cached(model_id, neuron_config):
def is_cached(model_id):
# Look for cached entries for the specified model
in_cache = False
entries = get_hub_cached_entries(model_id, "inference")
entries = get_hub_cached_entries(model_id)
# Look for compatible entries
for entry in entries:
compatible = True
for key, value in neuron_config.items():
# Only weights can be different
if key in ["checkpoint_id", "checkpoint_revision"]:
continue
if entry[key] != value:
compatible = False
if compatible:
if check_env_and_neuron_config_compatibility(
entry, check_compiler_version=True
):
in_cache = True
break
return in_cache
@ -87,8 +83,16 @@ def fetch_model(
revision = None
# Download the model from the Hub (HUGGING_FACE_HUB_TOKEN must be set for a private or gated model)
# Note that the model may already be present in the cache.
config = AutoConfig.from_pretrained(model_id, revision=revision)
neuron_config = getattr(config, "neuron", None)
try:
neuron_config = NeuronConfig.from_pretrained(model_id, revision=revision)
except Exception as e:
logger.debug(
"NeuronConfig.from_pretrained failed for model %s, revision %s: %s",
model_id,
revision,
e,
)
neuron_config = None
if neuron_config is not None:
if os.path.isdir(model_id):
return model_id
@ -99,16 +103,11 @@ def fetch_model(
log_cache_size()
return snapshot_download(model_id, revision=revision, ignore_patterns="*.bin")
# Model needs to be exported: look for compatible cached entries on the hub
export_kwargs = get_export_kwargs_from_env()
export_config = NeuronModelForCausalLM.get_export_config(
model_id, config, revision=revision, **export_kwargs
)
neuron_config = export_config.neuron
if not is_cached(model_id, neuron_config):
if not is_cached(model_id):
hub_cache_url = "https://huggingface.co/aws-neuron/optimum-neuron-cache"
neuron_export_url = "https://huggingface.co/docs/optimum-neuron/main/en/guides/export_model#exporting-neuron-models-using-neuronx-tgi"
error_msg = (
f"No cached version found for {model_id} with {neuron_config}."
f"No cached version found for {model_id} with {get_export_kwargs_from_env()}."
f"You can start a discussion to request it on {hub_cache_url}"
f"Alternatively, you can export your own neuron model as explained in {neuron_export_url}"
)
@ -121,8 +120,10 @@ def fetch_model(
# Prefetch weights, tokenizer and generation config so that they are in cache
log_cache_size()
start = time.time()
snapshot_download(model_id, revision=revision, ignore_patterns="*.bin")
snapshot_path = snapshot_download(
model_id, revision=revision, ignore_patterns="*.bin"
)
end = time.time()
logger.info(f"Model weights fetched in {end - start:.2f} s.")
log_cache_size()
return model_id
return snapshot_path
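As a rough usage sketch of the reworked helpers above, assuming only the APIs visible in this diff (`is_cached`, `fetch_model`) and the TGI environment variables they now read; the model id is just an example taken from the test configurations further down:

```python
import os

from text_generation_server.model import fetch_model, is_cached

# Export parameters are now taken from the environment rather than from an
# explicit neuron_config dict (see get_export_kwargs_from_env above).
os.environ.setdefault("MAX_BATCH_SIZE", "4")
os.environ.setdefault("MAX_TOTAL_TOKENS", "4096")
os.environ.setdefault("HF_NUM_CORES", "2")
os.environ.setdefault("HF_AUTO_CAST_TYPE", "bf16")

model_id = "unsloth/Llama-3.2-1B-Instruct"  # example model, also used in the tests below
if is_cached(model_id):
    # Returns either a local directory or the downloaded hub snapshot path.
    model_path = fetch_model(model_id)
    print(f"Neuron model available at {model_path}")
else:
    print("No compatible cached artifacts found; the model would have to be exported first.")
```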

View File

@ -6,12 +6,11 @@ import os
import sys
from typing import Any, Dict, List, Optional
from huggingface_hub import constants
from transformers import AutoConfig
from optimum.neuron.modeling_decoder import get_available_cores
from optimum.neuron.utils import get_hub_cached_entries
from optimum.neuron.cache import get_hub_cached_entries
from optimum.neuron.configuration_utils import NeuronConfig
from optimum.neuron.utils.version_utils import get_neuronxcc_version
from optimum.neuron.utils import map_torch_dtype
logger = logging.getLogger(__name__)
@ -24,15 +23,9 @@ tgi_router_env_vars = [
]
tgi_server_env_vars = ["HF_NUM_CORES", "HF_AUTO_CAST_TYPE"]
env_config_peering = [
("MAX_BATCH_SIZE", "batch_size"),
("MAX_TOTAL_TOKENS", "sequence_length"),
("HF_AUTO_CAST_TYPE", "auto_cast_type"),
("HF_NUM_CORES", "num_cores"),
]
# By the end of this script all env var should be specified properly
env_vars = tgi_server_env_vars + tgi_router_env_vars
tgi_env_vars = tgi_server_env_vars + tgi_router_env_vars
available_cores = get_available_cores()
neuronxcc_version = get_neuronxcc_version()
@ -93,9 +86,17 @@ def parse_cmdline_and_set_env(argv: List[str] = None) -> argparse.Namespace:
def neuron_config_to_env(neuron_config):
if isinstance(neuron_config, NeuronConfig):
neuron_config = neuron_config.to_dict()
with open(os.environ["ENV_FILEPATH"], "w") as f:
for env_var, config_key in env_config_peering:
f.write("export {}={}\n".format(env_var, neuron_config[config_key]))
f.write("export MAX_BATCH_SIZE={}\n".format(neuron_config["batch_size"]))
f.write("export MAX_TOTAL_TOKENS={}\n".format(neuron_config["sequence_length"]))
f.write("export HF_NUM_CORES={}\n".format(neuron_config["tp_degree"]))
config_key = (
"auto_cast_type" if "auto_cast_type" in neuron_config else "torch_dtype"
)
auto_cast_type = neuron_config[config_key]
f.write("export HF_AUTO_CAST_TYPE={}\n".format(auto_cast_type))
max_input_tokens = os.getenv("MAX_INPUT_TOKENS")
if not max_input_tokens:
max_input_tokens = int(neuron_config["sequence_length"]) // 2
@ -111,7 +112,7 @@ def neuron_config_to_env(neuron_config):
def sort_neuron_configs(dictionary):
return -dictionary["num_cores"], -dictionary["batch_size"]
return -dictionary["tp_degree"], -dictionary["batch_size"]
def lookup_compatible_cached_model(
@ -119,7 +120,7 @@ def lookup_compatible_cached_model(
) -> Optional[Dict[str, Any]]:
# Reuse the same mechanic as the one in use to configure the tgi server part
# The only difference here is that we stay as flexible as possible on the compatibility part
entries = get_hub_cached_entries(model_id, "inference")
entries = get_hub_cached_entries(model_id)
logger.debug(
"Found %d cached entries for model %s, revision %s",
@ -155,15 +156,15 @@ def lookup_compatible_cached_model(
def check_env_and_neuron_config_compatibility(
neuron_config: Dict[str, Any], check_compiler_version: bool
neuron_config_dict: Dict[str, Any], check_compiler_version: bool
) -> bool:
logger.debug(
"Checking the provided neuron config %s is compatible with the local setup and provided environment",
neuron_config,
neuron_config_dict,
)
# Local setup compat checks
if neuron_config["num_cores"] > available_cores:
if neuron_config_dict["tp_degree"] > available_cores:
logger.debug(
"Not enough neuron cores available to run the provided neuron config"
)
@ -171,33 +172,65 @@ def check_env_and_neuron_config_compatibility(
if (
check_compiler_version
and neuron_config["compiler_version"] != neuronxcc_version
and neuron_config_dict["neuronxcc_version"] != neuronxcc_version
):
logger.debug(
"Compiler version conflict, the local one (%s) differs from the one used to compile the model (%s)",
neuronxcc_version,
neuron_config["compiler_version"],
neuron_config_dict["neuronxcc_version"],
)
return False
for env_var, config_key in env_config_peering:
neuron_config_value = str(neuron_config[config_key])
env_value = os.getenv(env_var, str(neuron_config_value))
batch_size = os.getenv("MAX_BATCH_SIZE", None)
if batch_size is not None and neuron_config_dict["batch_size"] < int(batch_size):
logger.debug(
"The provided MAX_BATCH_SIZE (%s) is higher than the neuron config batch size (%s)",
os.getenv("MAX_BATCH_SIZE"),
neuron_config_dict["batch_size"],
)
return False
max_total_tokens = os.getenv("MAX_TOTAL_TOKENS", None)
if max_total_tokens is not None and neuron_config_dict["sequence_length"] < int(
max_total_tokens
):
logger.debug(
"The provided MAX_TOTAL_TOKENS (%s) is higher than the neuron config sequence length (%s)",
max_total_tokens,
neuron_config_dict["sequence_length"],
)
return False
num_cores = os.getenv("HF_NUM_CORES", None)
if num_cores is not None and neuron_config_dict["tp_degree"] < int(num_cores):
logger.debug(
"The provided HF_NUM_CORES (%s) is higher than the neuron config tp degree (%s)",
num_cores,
neuron_config_dict["tp_degree"],
)
return False
auto_cast_type = os.getenv("HF_AUTO_CAST_TYPE", None)
if auto_cast_type is not None:
config_key = (
"auto_cast_type"
if "auto_cast_type" in neuron_config_dict
else "torch_dtype"
)
neuron_config_value = map_torch_dtype(str(neuron_config_dict[config_key]))
env_value = map_torch_dtype(auto_cast_type)
if env_value != neuron_config_value:
logger.debug(
"The provided env var '%s' and the neuron config '%s' param differ (%s != %s)",
env_var,
config_key,
"The provided auto cast type and the neuron config param differ (%s != %s)",
env_value,
neuron_config_value,
)
return False
max_input_tokens = int(
os.getenv("MAX_INPUT_TOKENS", os.getenv("MAX_INPUT_LENGTH", 0))
)
if max_input_tokens > 0:
sequence_length = neuron_config["sequence_length"]
if "max_context_length" in neuron_config_dict:
sequence_length = neuron_config_dict["max_context_length"]
else:
sequence_length = neuron_config_dict["sequence_length"]
if max_input_tokens >= sequence_length:
logger.debug(
"Specified max input tokens is not compatible with config sequence length ( %s >= %s)",
@ -211,38 +244,29 @@ def check_env_and_neuron_config_compatibility(
def get_env_dict() -> Dict[str, str]:
d = {}
for k in env_vars:
for k in tgi_env_vars:
d[k] = os.getenv(k)
return d
def main():
"""
This script determines proper default TGI env variables for the neuron precompiled models to
work properly
:return:
"""
args = parse_cmdline_and_set_env()
for env_var in env_vars:
if not os.getenv(env_var):
break
else:
logger.info(
"All env vars %s already set, skipping, user know what they are doing",
env_vars,
def get_neuron_config_for_model(
model_name_or_path: str, revision: Optional[str] = None
) -> NeuronConfig:
try:
neuron_config = NeuronConfig.from_pretrained(
model_name_or_path, revision=revision
)
sys.exit(0)
cache_dir = constants.HF_HUB_CACHE
logger.info("Cache dir %s, model %s", cache_dir, args.model_id)
config = AutoConfig.from_pretrained(args.model_id, revision=args.revision)
neuron_config = getattr(config, "neuron", None)
except Exception as e:
logger.debug(
"NeuronConfig.from_pretrained failed for model %s, revision %s: %s",
model_name_or_path,
revision,
e,
)
neuron_config = None
if neuron_config is not None:
compatible = check_env_and_neuron_config_compatibility(
neuron_config, check_compiler_version=False
neuron_config.to_dict(), check_compiler_version=False
)
if not compatible:
env_dict = get_env_dict()
@ -252,17 +276,6 @@ def main():
logger.error(msg)
raise Exception(msg)
else:
neuron_config = lookup_compatible_cached_model(args.model_id, args.revision)
neuron_config = lookup_compatible_cached_model(model_name_or_path, revision)
if not neuron_config:
msg = (
"No compatible neuron config found. Provided env {}, available cores {}, neuronxcc version {}"
).format(get_env_dict(), available_cores, neuronxcc_version)
logger.error(msg)
raise Exception(msg)
neuron_config_to_env(neuron_config)
if __name__ == "__main__":
main()
return neuron_config
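To make the new env-file mapping above concrete, here is a self-contained sketch of the exports now emitted per config key (a hypothetical `write_env_file` helper written for illustration, not the repository's `neuron_config_to_env`; it assumes the NxD-style keys used above):

```python
from typing import Any, Dict


def write_env_file(neuron_config: Dict[str, Any], path: str) -> None:
    """Hypothetical helper mirroring the exports shown in neuron_config_to_env above."""
    # NxD configs expose "torch_dtype", older ones "auto_cast_type".
    dtype_key = "auto_cast_type" if "auto_cast_type" in neuron_config else "torch_dtype"
    with open(path, "w") as f:
        f.write(f"export MAX_BATCH_SIZE={neuron_config['batch_size']}\n")
        f.write(f"export MAX_TOTAL_TOKENS={neuron_config['sequence_length']}\n")
        f.write(f"export HF_NUM_CORES={neuron_config['tp_degree']}\n")
        f.write(f"export HF_AUTO_CAST_TYPE={neuron_config[dtype_key]}\n")
    # The real function also derives MAX_INPUT_TOKENS, defaulting to half the
    # sequence length when the variable is unset (the rest lies outside this hunk).


write_env_file(
    {"batch_size": 4, "sequence_length": 4096, "tp_degree": 2, "torch_dtype": "bfloat16"},
    "/tmp/tgi_env.sh",
)
```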

View File

@ -4,14 +4,12 @@ import subprocess
import sys
from tempfile import TemporaryDirectory
import huggingface_hub
import os
import pytest
from transformers import AutoTokenizer
from optimum.neuron import NeuronModelForCausalLM
from optimum.neuron.utils import synchronize_hub_cache
from optimum.neuron.version import __sdk_version__ as sdk_version
from optimum.neuron.version import __version__ as version
from optimum.neuron.cache import synchronize_hub_cache
logging.basicConfig(
@ -21,30 +19,14 @@ logging.basicConfig(
)
logger = logging.getLogger(__file__)
OPTIMUM_CACHE_REPO_ID = "optimum-internal-testing/neuron-testing-cache"
# All model configurations below will be added to the neuron_model_config fixture
MODEL_CONFIGURATIONS = {
"gpt2": {
"model_id": "gpt2",
"export_kwargs": {
"batch_size": 4,
"sequence_length": 1024,
"num_cores": 2,
"auto_cast_type": "fp16",
},
},
"llama": {
"model_id": "NousResearch/Hermes-2-Theta-Llama-3-8B",
"export_kwargs": {
"batch_size": 4,
"sequence_length": 2048,
"num_cores": 2,
"auto_cast_type": "fp16",
},
},
"mistral": {
"model_id": "optimum/mistral-1.1b-testing",
"model_id": "unsloth/Llama-3.2-1B-Instruct",
"export_kwargs": {
"batch_size": 4,
"sequence_length": 4096,
@ -58,7 +40,7 @@ MODEL_CONFIGURATIONS = {
"batch_size": 4,
"sequence_length": 4096,
"num_cores": 2,
"auto_cast_type": "fp16",
"auto_cast_type": "bf16",
},
},
"granite": {
@ -73,12 +55,6 @@ MODEL_CONFIGURATIONS = {
}
def get_hub_neuron_model_id(config_name: str):
return (
f"optimum-internal-testing/neuron-testing-{version}-{sdk_version}-{config_name}"
)
def export_model(model_id, export_kwargs, neuron_model_path):
export_command = [
"optimum-cli",
@ -104,57 +80,35 @@ def export_model(model_id, export_kwargs, neuron_model_path):
def neuron_model_config(request):
"""Expose a pre-trained neuron model
The fixture first makes sure the following model artifacts are present on the hub:
- exported neuron model under optimum-internal-testing/neuron-testing-<version>-<name>,
- cached artifacts under optimum-internal-testing/neuron-testing-cache.
If not, it will export the model and push it to the hub.
It then fetches the model locally and return a dictionary containing:
The fixture exports a model locally and returns a dictionary containing:
- a configuration name,
- the original model id,
- the export parameters,
- the neuron model id,
- the neuron model local path.
For each exposed model, the local directory is maintained for the duration of the
test session and cleaned up afterwards.
The hub model artifacts are never cleaned up and persist accross sessions.
They must be cleaned up manually when the optimum-neuron version changes.
"""
config_name = request.param
model_config = copy.deepcopy(MODEL_CONFIGURATIONS[request.param])
model_id = model_config["model_id"]
export_kwargs = model_config["export_kwargs"]
neuron_model_id = get_hub_neuron_model_id(config_name)
with TemporaryDirectory() as neuron_model_path:
hub = huggingface_hub.HfApi()
if hub.repo_exists(neuron_model_id):
logger.info(f"Fetching {neuron_model_id} from the HuggingFace hub")
hub.snapshot_download(neuron_model_id, local_dir=neuron_model_path)
else:
export_model(model_id, export_kwargs, neuron_model_path)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.save_pretrained(neuron_model_path)
del tokenizer
# Create the test model on the hub
hub.create_repo(neuron_model_id, private=True)
hub.upload_folder(
folder_path=neuron_model_path,
repo_id=neuron_model_id,
ignore_patterns=[NeuronModelForCausalLM.CHECKPOINT_DIR + "/*"],
)
# Make sure it is cached
synchronize_hub_cache(cache_repo_id=OPTIMUM_CACHE_REPO_ID)
export_model(model_id, export_kwargs, neuron_model_path)
synchronize_hub_cache(cache_repo_id=OPTIMUM_CACHE_REPO_ID)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.save_pretrained(neuron_model_path)
del tokenizer
# Add dynamic parameters to the model configuration
model_config["neuron_model_path"] = neuron_model_path
model_config["neuron_model_id"] = neuron_model_id
# Also add model configuration name to allow tests to adapt their expectations
model_config["name"] = config_name
# Yield instead of returning to keep a reference to the temporary directory.
# It will go out of scope and be released only once all tests needing the fixture
# have been completed.
logger.info(f"{config_name} ready for testing ...")
os.environ["CUSTOM_CACHE_REPO"] = OPTIMUM_CACHE_REPO_ID
yield model_config
logger.info(f"Done with {config_name}")

View File

@ -0,0 +1,42 @@
import os
import pytest
from text_generation_server.generator import NeuronGenerator
from text_generation_server.model import fetch_model, is_cached
@pytest.fixture(scope="module")
def cached_model_id(neuron_model_config) -> str:
"""
Fixture to provide a cached model ID for testing.
This assumes the model is already cached in the local environment.
"""
export_kwargs = neuron_model_config["export_kwargs"]
os.environ["MAX_BATCH_SIZE"] = str(export_kwargs["batch_size"])
os.environ["MAX_TOTAL_TOKENS"] = str(export_kwargs["sequence_length"])
os.environ["HF_AUTO_CAST_TYPE"] = export_kwargs["auto_cast_type"]
os.environ["HF_NUM_CORES"] = str(export_kwargs["num_cores"])
yield neuron_model_config["model_id"]
os.environ.pop("MAX_BATCH_SIZE", None)
os.environ.pop("MAX_TOTAL_TOKENS", None)
os.environ.pop("HF_AUTO_CAST_TYPE", None)
os.environ.pop("HF_NUM_CORES", None)
def test_model_is_cached(cached_model_id):
assert is_cached(cached_model_id), f"Model {cached_model_id} is not cached"
def test_fetch_cached_model(cached_model_id: str):
model_path = fetch_model(cached_model_id)
assert os.path.exists(
model_path
), f"Model {cached_model_id} was not fetched successfully"
assert os.path.isdir(model_path), f"Model {cached_model_id} is not a directory"
def test_generator_from_cached_model(cached_model_id: str):
generator = NeuronGenerator.from_pretrained(model_id=cached_model_id)
assert generator is not None, "Generator could not be created from cached model"
assert generator.model is not None, "Generator model is not initialized"
assert generator.tokenizer is not None, "Generator tokenizer is not initialized"

View File

@ -9,13 +9,13 @@ def test_continuous_batching_two_requests(neuron_model_config):
"""
neuron_model_path = neuron_model_config["neuron_model_path"]
generator = NeuronGenerator.from_pretrained(neuron_model_path)
assert generator.model.batch_size > 1
assert generator.model.neuron_config.batch_size > 1
input_text = "Once upon a time"
max_new_tokens = 20
# Prefill a single request, remembering the generated token
tokens = {0: [], 1: []}
request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens)
max_length = generator.model.max_length
max_length = generator.model.neuron_config.sequence_length
batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
generations, next_batch = generator.prefill(batch)
assert next_batch.size == 1

View File

@ -23,7 +23,7 @@ def _test_decode(config_name, generator, do_sample):
request = create_request(
id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=do_sample
)
max_length = generator.model.max_length
max_length = generator.model.neuron_config.sequence_length
batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
generations, next_batch = generator.prefill(batch)
# We already generated one token: call decode max_new_tokens - 1 times
@ -40,19 +40,15 @@ def _test_decode(config_name, generator, do_sample):
assert output.finish_reason == 0
if do_sample:
expected_text = {
"gpt2": " The sun was set",
"llama": "George Orwell, 1984",
"mistral": "The sky was",
"qwen2": " A young woman with",
"llama": " I sat alone in the café",
"qwen2": " The air was so still",
"granite": "1984, George Orwell",
}[config_name]
assert expected_text in output.text
else:
print(output.text)
expected_text = {
"gpt2": '\n\n"I\'m going to go to bed," I said.\n\n"I\'m going',
"llama": " George Orwells classic dystopian novel, 1984, begins with this ominous sentence. The story",
"mistral": "\nThe clocks were striking thirteen.\nThe clocks were striking thirteen.",
"llama": " The world was holding its breath as the world's top scientists and engineers gathered at the secret underground facility",
"qwen2": " I was sitting in my room, staring at the ceiling, when the door opened and in came a",
"granite": "\n\nThis opening line from George Orwell's dystopian novel \"198",
}[config_name]

View File

@ -9,7 +9,7 @@ def test_prefill(neuron_model_config):
neuron_model_path = neuron_model_config["neuron_model_path"]
generator = NeuronGenerator.from_pretrained(neuron_model_path)
max_batch_size = 4
assert generator.model.batch_size >= max_batch_size
assert generator.model.neuron_config.batch_size >= max_batch_size
for num_requests in [1, max_batch_size]:
for do_sample in [True, False]:
mode = "sample" if do_sample else "greedy"
@ -34,7 +34,7 @@ def _test_prefill(config_name, generator, batch_size, do_sample):
)
)
# Let's be pessimistic when estimating max_tokens
max_length = generator.model.max_length
max_length = generator.max_prefill_length()
batch = Batch(
id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length
)
@ -46,17 +46,13 @@ def _test_prefill(config_name, generator, batch_size, do_sample):
assert len(generations) == batch_size
if do_sample:
expectations = {
"gpt2": [383, " The"],
"llama": [10058, " George"],
"mistral": [450, " The"],
"qwen2": [362, " A"],
"llama": [358, " I"],
"qwen2": [576, " The"],
"granite": [308, " ("],
}[config_name]
else:
expectations = {
"gpt2": [198, "\n"],
"llama": [10058, " George"],
"mistral": [13, "\n"],
"llama": [578, " The"],
"qwen2": [358, " I"],
"granite": [203, "\n"],
}[config_name]
@ -70,7 +66,7 @@ def test_prefill_truncate(neuron_model_config):
config_name = neuron_model_config["name"]
neuron_model_path = neuron_model_config["neuron_model_path"]
generator = NeuronGenerator.from_pretrained(neuron_model_path)
batch_size = generator.model.batch_size
batch_size = generator.model.neuron_config.batch_size
# We apply truncation to all requests but the first one
truncate = [
None,
@ -83,7 +79,7 @@ def test_prefill_truncate(neuron_model_config):
requests = []
for i in range(batch_size):
requests.append(create_request(id=i, inputs=input_text, truncate=truncate[i]))
max_length = generator.model.max_length
max_length = generator.max_prefill_length()
batch = Batch(
id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length
)
@ -91,12 +87,12 @@ def test_prefill_truncate(neuron_model_config):
# Even if the input text is identical for all requests, the first generated token might
# be different because of the truncation
expectations = {
"gpt2": [" He", " He", "\n", " He"],
"llama": ["", " The", " He", " He"],
"mistral": [" He", "\n", " He", " He"],
"llama": [" He", "iens", "\x08", " He"],
"qwen2": [" He", " The", " He", " He"],
"granite": ["\n", "\n", " I", " He"],
}[config_name]
for i, g in enumerate(generations):
tokens = g.tokens
assert tokens.texts[0] == expectations[i]
assert (
tokens.texts[0] == expectations[i]
), f"Request {i} expected [{expectations[i]}], got [{tokens.texts[0]}]"

View File

@ -0,0 +1,63 @@
import os
import pytest
from tempfile import TemporaryDirectory
from optimum.neuron.models.inference.nxd.backend.config import NxDNeuronConfig
from optimum.neuron.utils import map_torch_dtype
from text_generation_server.tgi_env import (
get_neuron_config_for_model,
lookup_compatible_cached_model,
neuron_config_to_env,
)
def test_get_neuron_config_for_model(neuron_model_config):
neuron_model_path = neuron_model_config["neuron_model_path"]
export_kwargs = neuron_model_config["export_kwargs"]
os.environ["MAX_BATCH_SIZE"] = str(export_kwargs["batch_size"])
os.environ["MAX_TOTAL_TOKENS"] = str(export_kwargs["sequence_length"])
os.environ["HF_AUTO_CAST_TYPE"] = export_kwargs["auto_cast_type"]
os.environ["HF_NUM_CORES"] = str(export_kwargs["num_cores"])
neuron_config = get_neuron_config_for_model(neuron_model_path)
assert neuron_config is not None
assert neuron_config.batch_size == export_kwargs["batch_size"]
assert neuron_config.sequence_length == export_kwargs["sequence_length"]
assert neuron_config.tp_degree == export_kwargs["num_cores"]
if isinstance(neuron_config, NxDNeuronConfig):
assert map_torch_dtype(neuron_config.torch_dtype) == map_torch_dtype(
export_kwargs["auto_cast_type"]
)
else:
assert map_torch_dtype(neuron_config.auto_cast_type) == map_torch_dtype(
export_kwargs["auto_cast_type"]
)
@pytest.mark.parametrize("model_id", ["unsloth/Llama-3.2-1B-Instruct"])
def test_lookup_compatible_cached_model(model_id: str):
neuron_config = lookup_compatible_cached_model(model_id, None)
assert neuron_config is not None
def test_neuron_config_to_env(neuron_model_config) -> None:
neuron_model_path = neuron_model_config["neuron_model_path"]
neuron_config = get_neuron_config_for_model(neuron_model_path)
with TemporaryDirectory() as temp_dir:
os.environ["ENV_FILEPATH"] = os.path.join(temp_dir, "env.sh")
neuron_config_to_env(neuron_config)
with open(os.environ["ENV_FILEPATH"], "r") as env_file:
env_content = env_file.read()
assert f"export MAX_BATCH_SIZE={neuron_config.batch_size}" in env_content
assert (
f"export MAX_TOTAL_TOKENS={neuron_config.sequence_length}"
in env_content
)
assert f"export HF_NUM_CORES={neuron_config.tp_degree}" in env_content
if hasattr(neuron_config, "torch_dtype"):
auto_cast_type = str(map_torch_dtype(neuron_config.torch_dtype)).split(
"."
)[-1]
else:
auto_cast_type = neuron_config.auto_cast_type
assert f"export HF_AUTO_CAST_TYPE={auto_cast_type}" in env_content

View File

@ -9,7 +9,7 @@ touch $ENV_FILEPATH
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
${SCRIPT_DIR}/tgi_env.py $@
${SCRIPT_DIR}/tgi_entry_point.py $@
source $ENV_FILEPATH

View File

@ -0,0 +1,53 @@
#!/usr/bin/env python
import logging
import os
import sys
from text_generation_server.tgi_env import (
available_cores,
get_env_dict,
get_neuron_config_for_model,
neuron_config_to_env,
neuronxcc_version,
parse_cmdline_and_set_env,
tgi_env_vars,
)
logger = logging.getLogger(__name__)
def main():
"""
This script determines proper default TGI env variables for the neuron precompiled models to
work properly
:return:
"""
args = parse_cmdline_and_set_env()
for env_var in tgi_env_vars:
if not os.getenv(env_var):
break
else:
logger.info(
"All env vars %s already set, skipping, user know what they are doing",
tgi_env_vars,
)
sys.exit(0)
neuron_config = get_neuron_config_for_model(args.model_id, args.revision)
if not neuron_config:
msg = (
"No compatible neuron config found. Provided env {}, available cores {}, neuronxcc version {}"
).format(get_env_dict(), available_cores, neuronxcc_version)
logger.error(msg)
raise Exception(msg)
neuron_config_to_env(neuron_config)
if __name__ == "__main__":
main()

View File

@ -10,7 +10,7 @@
"name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
},
"version": "3.3.1-dev0"
"version": "3.3.2-dev0"
},
"paths": {
"/": {

View File

@ -20,7 +20,7 @@ hf_token=YOUR_HF_ACCESS_TOKEN
docker run --runtime=habana --cap-add=sys_nice --ipc=host \
-p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
ghcr.io/huggingface/text-generation-inference:3.3.1-gaudi \
ghcr.io/huggingface/text-generation-inference:3.3.2-gaudi \
--model-id $model
```
@ -52,7 +52,7 @@ hf_token=YOUR_ACCESS_TOKEN
docker run --runtime=habana --cap-add=sys_nice --ipc=host \
-p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
ghcr.io/huggingface/text-generation-inference:3.3.1-gaudi \
ghcr.io/huggingface/text-generation-inference:3.3.2-gaudi \
--model-id $model
<text-generation-inference-launcher-arguments>
```
@ -115,7 +115,7 @@ docker run -p 8080:80 \
-e BATCH_BUCKET_SIZE=256 \
-e PREFILL_BATCH_BUCKET_SIZE=4 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
ghcr.io/huggingface/text-generation-inference:3.3.1-gaudi \
ghcr.io/huggingface/text-generation-inference:3.3.2-gaudi \
--model-id $model \
--sharded true --num-shard 8 \
--max-input-tokens 1024 --max-total-tokens 2048 \
@ -141,7 +141,7 @@ docker run -p 8080:80 \
-v $volume:/data \
-e PREFILL_BATCH_BUCKET_SIZE=1 \
-e BATCH_BUCKET_SIZE=1 \
ghcr.io/huggingface/text-generation-inference:3.3.1-gaudi \
ghcr.io/huggingface/text-generation-inference:3.3.2-gaudi \
--model-id $model \
--max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
--max-total-tokens 8192 --max-batch-size 4
@ -208,7 +208,7 @@ docker run --runtime=habana --ipc=host --cap-add=sys_nice \
-e PROF_PATH=/tmp/hpu_profile \
-e PROF_RANKS=0 \
-e PROF_RECORD_SHAPES=True \
ghcr.io/huggingface/text-generation-inference:3.3.1-gaudi \
ghcr.io/huggingface/text-generation-inference:3.3.2-gaudi \
--model-id $model
```

View File

@ -31,7 +31,7 @@ deployment instructions in the model card:
The service is launched simply by running the text-generation-inference container with two sets of parameters:
```
docker run <system_parameters> ghcr.io/huggingface/text-generation-inference:3.3.1-neuron <service_parameters>
docker run <system_parameters> ghcr.io/huggingface/text-generation-inference:3.3.2-neuron <service_parameters>
```
- system parameters are used to map ports, volumes and devices between the host and the service,

View File

@ -19,6 +19,6 @@ docker run --gpus all \
--shm-size 1g \
-e HF_TOKEN=$token \
-p 8080:80 \
-v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.1 \
-v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.2 \
--model-id $model
```

View File

@ -19,7 +19,7 @@ bitsandbytes is a library used to apply 8-bit and 4-bit quantization to models.
In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below 👇
```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.1 --model-id $model --quantize bitsandbytes
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.2 --model-id $model --quantize bitsandbytes
```
4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load.
@ -27,7 +27,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingf
In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below 👇
```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.1 --model-id $model --quantize bitsandbytes-nf4
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.2 --model-id $model --quantize bitsandbytes-nf4
```
You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes).
@ -48,7 +48,7 @@ $$({\hat{W}_{l}}^{*} = argmin_{\hat{W_{l}}} ||W_{l}X-\hat{W}_{l}X||^{2}_{2})$$
TGI allows you to both run an already GPTQ quantized model (see available models [here](https://huggingface.co/models?search=gptq)) or quantize a model of your choice using quantization script. You can run a quantized model by simply passing --quantize like below 👇
```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.1 --model-id $model --quantize gptq
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.2 --model-id $model --quantize gptq
```
Note that TGI's GPTQ implementation doesn't use [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) under the hood. However, models quantized using AutoGPTQ or Optimum can still be served by TGI.

View File

@ -11,7 +11,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
--device=/dev/kfd --device=/dev/dri --group-add video \
--ipc=host --shm-size 256g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.3.1-rocm \
ghcr.io/huggingface/text-generation-inference:3.3.2-rocm \
--model-id $model
```

View File

@ -12,7 +12,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm --privileged --cap-add=sys_nice \
--device=/dev/dri \
--ipc=host --shm-size 1g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.3.1-intel-xpu \
ghcr.io/huggingface/text-generation-inference:3.3.2-intel-xpu \
--model-id $model --cuda-graphs 0
```
@ -29,7 +29,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm --privileged --cap-add=sys_nice \
--device=/dev/dri \
--ipc=host --shm-size 1g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.3.1-intel-cpu \
ghcr.io/huggingface/text-generation-inference:3.3.2-intel-cpu \
--model-id $model --cuda-graphs 0
```

View File

@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.3.1 \
ghcr.io/huggingface/text-generation-inference:3.3.2 \
--model-id $model
```

View File

@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.3.1 \
ghcr.io/huggingface/text-generation-inference:3.3.2 \
--model-id $model
```
@ -96,7 +96,7 @@ curl 127.0.0.1:8080/generate \
To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.
```bash
docker run ghcr.io/huggingface/text-generation-inference:3.3.1 --help
docker run ghcr.io/huggingface/text-generation-inference:3.3.2 --help
```
</Tip>

View File

@ -163,7 +163,7 @@ hub = {
# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
image_uri=get_huggingface_llm_image_uri("huggingface",version="3.3.1"),
image_uri=get_huggingface_llm_image_uri("huggingface",version="3.3.2"),
env=hub,
role=role,
)

View File

@ -28,15 +28,6 @@ logger = logging.getLogger(__file__)
# All model configurations below will be added to the neuron_model_config fixture
MODEL_CONFIGURATIONS = {
"gpt2": {
"model_id": "gpt2",
"export_kwargs": {
"batch_size": 4,
"sequence_length": 1024,
"num_cores": 2,
"auto_cast_type": "fp16",
},
},
"llama": {
"model_id": "unsloth/Llama-3.2-1B-Instruct",
"export_kwargs": {
@ -46,15 +37,6 @@ MODEL_CONFIGURATIONS = {
"auto_cast_type": "fp16",
},
},
"mistral": {
"model_id": "optimum/mistral-1.1b-testing",
"export_kwargs": {
"batch_size": 4,
"sequence_length": 4096,
"num_cores": 2,
"auto_cast_type": "bf16",
},
},
"qwen2": {
"model_id": "Qwen/Qwen2.5-0.5B",
"export_kwargs": {

View File

@ -17,7 +17,7 @@
"id": "",
"model": "google/gemma-3-4b-it",
"object": "chat.completion",
"system_fingerprint": "3.3.1-dev0-native",
"system_fingerprint": "3.3.2-dev0-native",
"usage": {
"completion_tokens": 42,
"prompt_tokens": 277,

View File

@ -17,7 +17,7 @@
"id": "",
"model": "google/gemma-3-4b-it",
"object": "chat.completion",
"system_fingerprint": "3.3.1-dev0-native",
"system_fingerprint": "3.3.2-dev0-native",
"usage": {
"completion_tokens": 62,
"prompt_tokens": 277,

View File

@ -17,7 +17,7 @@
"id": "",
"model": "google/gemma-3-4b-it",
"object": "chat.completion",
"system_fingerprint": "3.3.1-dev0-native",
"system_fingerprint": "3.3.2-dev0-native",
"usage": {
"completion_tokens": 67,
"prompt_tokens": 277,

View File

@ -17,7 +17,7 @@
"id": "",
"model": "google/gemma-3-4b-it",
"object": "chat.completion",
"system_fingerprint": "3.3.1-dev0-native",
"system_fingerprint": "3.3.2-dev0-native",
"usage": {
"completion_tokens": 72,
"prompt_tokens": 275,

View File

@ -17,7 +17,7 @@
"id": "",
"model": "google/gemma-3-4b-it",
"object": "chat.completion",
"system_fingerprint": "3.3.1-dev0-native",
"system_fingerprint": "3.3.2-dev0-native",
"usage": {
"completion_tokens": 80,
"prompt_tokens": 279,

View File

@ -14,7 +14,7 @@
"id": "",
"model": "google/gemma-3-4b-it",
"object": "chat.completion",
"system_fingerprint": "3.3.1-dev0-native",
"system_fingerprint": "3.3.2-dev0-native",
"usage": {
"completion_tokens": 35,
"prompt_tokens": 32,

View File

@ -14,7 +14,7 @@
"id": "",
"model": "google/gemma-3-4b-it",
"object": "chat.completion",
"system_fingerprint": "3.3.1-dev0-native",
"system_fingerprint": "3.3.2-dev0-native",
"usage": {
"completion_tokens": 44,
"prompt_tokens": 37,

View File

@ -18,7 +18,7 @@
"id": "",
"model": "unsloth/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion",
"system_fingerprint": "3.3.1-dev0-native",
"system_fingerprint": "3.3.2-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 45,
@ -44,7 +44,7 @@
"id": "",
"model": "unsloth/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion",
"system_fingerprint": "3.3.1-dev0-native",
"system_fingerprint": "3.3.2-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 45,

View File

@ -17,7 +17,7 @@
"id": "",
"model": "unsloth/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion",
"system_fingerprint": "3.3.1-dev0-native",
"system_fingerprint": "3.3.2-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 45,

View File

@ -20,9 +20,7 @@ async def test_model_single_request(tgi_service):
)
assert response.details.generated_tokens == 17
greedy_expectations = {
"gpt2": "\n\nDeep learning is a new field of research that has been around for a while",
"llama": " and How Does it Work?\nDeep learning is a subset of machine learning that uses artificial",
"mistral": "\nWhat is Deep Learning?\nDeep Learning is a type of machine learning that",
"llama": " and how does it work?\nDeep learning is a subset of machine learning that uses artificial",
"qwen2": " - Part 1\n\nDeep Learning is a subset of Machine Learning that is based on",
"granite": "\n\nDeep Learning is a subset of Machine Learning, which is a branch of Art",
}
@ -79,9 +77,7 @@ async def test_model_multiple_requests(tgi_service, neuron_generate_load):
assert len(responses) == 4
expectations = {
"gpt2": "Deep learning is a new field of research that has been around for a while",
"llama": "Deep learning is a subset of machine learning that uses artificial",
"mistral": "Deep Learning is a type of machine learning that",
"qwen2": "Deep Learning is a subset of Machine Learning that is based on",
"granite": "Deep Learning is a subset of Machine Learning, which is a branch of Art",
}

View File

@ -27,10 +27,6 @@ impl Env {
docker_label: option_env!("DOCKER_LABEL").unwrap_or("N/A"),
}
}
pub fn should_start_a_single_hpu_shard(&self) -> bool {
self.hpu_env != "N/A" && std::env::var("ATTENTION").as_deref() != Ok("paged")
}
}
impl fmt::Display for Env {

View File

@ -1590,11 +1590,6 @@ fn spawn_shards(
) -> Result<(), LauncherError> {
// Start shard processes
for rank in 0..num_shard {
if rank != 0 && env_runtime::Env::new().should_start_a_single_hpu_shard() {
tracing::info!("Running on HPU, the launcher will not do any sharding as actual sharding is done in the server");
break;
}
let model_id = args.model_id.clone();
let revision = args.revision.clone();
let uds_path = args.shard_uds_path.clone();
@ -1670,10 +1665,6 @@ fn spawn_shards(
if shard_ready == num_shard {
break;
}
if env_runtime::Env::new().should_start_a_single_hpu_shard() {
tracing::info!("HPU detected, shard is ready");
break;
}
}
Err(TryRecvError::Empty) => {
sleep(Duration::from_millis(100));