Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 03:14:53 +00:00)

Merge branch 'huggingface:main' into qwen3_moe
This commit is contained in: commit 1791c855f0

Cargo.lock (generated): 16 changes
@@ -4650,7 +4650,7 @@ dependencies = [

 [[package]]
 name = "text-generation-backends-trtllm"
-version = "3.3.1-dev0"
+version = "3.3.2-dev0"
 dependencies = [
  "async-trait",
  "clap 4.5.32",
@@ -4671,7 +4671,7 @@ dependencies = [

 [[package]]
 name = "text-generation-benchmark"
-version = "3.3.1-dev0"
+version = "3.3.2-dev0"
 dependencies = [
  "average",
  "clap 4.5.32",
@@ -4691,7 +4691,7 @@ dependencies = [

 [[package]]
 name = "text-generation-client"
-version = "3.3.1-dev0"
+version = "3.3.2-dev0"
 dependencies = [
  "async-trait",
  "base64 0.22.1",
@@ -4709,7 +4709,7 @@ dependencies = [

 [[package]]
 name = "text-generation-launcher"
-version = "3.3.1-dev0"
+version = "3.3.2-dev0"
 dependencies = [
  "clap 4.5.32",
  "ctrlc",
@@ -4730,7 +4730,7 @@ dependencies = [

 [[package]]
 name = "text-generation-router"
-version = "3.3.1-dev0"
+version = "3.3.2-dev0"
 dependencies = [
  "anyhow",
  "async-stream",
@@ -4782,7 +4782,7 @@ dependencies = [

 [[package]]
 name = "text-generation-router-llamacpp"
-version = "3.3.1-dev0"
+version = "3.3.2-dev0"
 dependencies = [
  "async-trait",
  "bindgen 0.71.1",
@@ -4800,7 +4800,7 @@ dependencies = [

 [[package]]
 name = "text-generation-router-v2"
-version = "3.3.1-dev0"
+version = "3.3.2-dev0"
 dependencies = [
  "async-stream",
  "async-trait",
@@ -4849,7 +4849,7 @@ dependencies = [

 [[package]]
 name = "text-generation-router-v3"
-version = "3.3.1-dev0"
+version = "3.3.2-dev0"
 dependencies = [
  "async-stream",
  "async-trait",
@@ -21,7 +21,7 @@ default-members = [
 resolver = "2"

 [workspace.package]
-version = "3.3.1-dev0"
+version = "3.3.2-dev0"
 edition = "2021"
 authors = ["Olivier Dehaene"]
 homepage = "https://github.com/huggingface/text-generation-inference"
@@ -5,7 +5,7 @@ RUN mkdir -p /tgi
 # Fetch the optimum-neuron sources directly to avoid relying on pypi deployments
 FROM alpine AS optimum-neuron
 RUN mkdir -p /optimum-neuron
-ADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.1.0.tar.gz /optimum-neuron/sources.tar.gz
+ADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.2.0.tar.gz /optimum-neuron/sources.tar.gz
 RUN tar -C /optimum-neuron -xf /optimum-neuron/sources.tar.gz --strip-components=1

 # Build cargo components (adapted from TGI original Dockerfile)
@@ -108,10 +108,10 @@ RUN wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEU
 # Install neuronx packages
 RUN apt-get update -y \
  && apt-get install -y --no-install-recommends \
-    aws-neuronx-dkms=2.19.64.0 \
-    aws-neuronx-collectives=2.23.135.0-3e70920f2 \
-    aws-neuronx-runtime-lib=2.23.112.0-9b5179492 \
-    aws-neuronx-tools=2.20.204.0 \
+    aws-neuronx-dkms=2.20.28.0 \
+    aws-neuronx-collectives=2.24.59.0-838c7fc8b \
+    aws-neuronx-runtime-lib=2.24.53.0-f239092cc \
+    aws-neuronx-tools=2.22.61.0 \
     libxml2 \
  && rm -rf /var/lib/apt/lists/* \
  && apt-get clean
@@ -125,11 +125,10 @@ RUN pip3 install \
     --index-url https://download.pytorch.org/whl/cpu

 RUN pip3 install \
-    neuronx-cc==2.16.372.0 \
-    torch-neuronx==2.5.1.2.4.0 \
-    transformers-neuronx==0.13.322 \
-    neuronx-distributed==0.10.1 \
-    libneuronxla==2.1.681.0 \
+    neuronx-cc==2.17.194.0 \
+    torch-neuronx==2.5.1.2.6.0 \
+    neuronx-distributed==0.11.0 \
+    libneuronxla==2.2.1630.0 \
     --extra-index-url=https://pip.repos.neuron.amazonaws.com

 # Install HuggingFace packages
@@ -160,7 +159,7 @@ RUN pip install dist/text_generation_server*.tar.gz
 # Final image
 FROM neuron

-COPY backends/neuron/tgi_env.py /tgi_env.py
+COPY backends/neuron/tgi_entry_point.py /tgi_entry_point.py
 COPY backends/neuron/tgi-entrypoint.sh /tgi-entrypoint.sh
 RUN chmod +x /tgi-entrypoint.sh

@@ -57,7 +57,7 @@ ARG PYTORCH_VERSION

 FROM vault.habana.ai/gaudi-docker/${HABANA_VERSION}/ubuntu22.04/habanalabs/pytorch-installer-${PYTORCH_VERSION}:latest AS base

-ENV ATTENTION=default
+ENV ATTENTION=paged
 ENV PREFIX_CACHING=0
 ENV PREFILL_CHUNKING=0
 ENV PT_HPU_LAZY_MODE=1
@@ -84,7 +84,7 @@ model=HuggingFaceH4/zephyr-7b-beta
 volume=$PWD/data

 docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
-    ghcr.io/huggingface/text-generation-inference:3.3.1 --model-id $model
+    ghcr.io/huggingface/text-generation-inference:3.3.2 --model-id $model
 ```

 And then you can make requests like
@@ -121,7 +121,7 @@ curl localhost:8080/v1/chat/completions \

 **Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.

-**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.1-rocm --model-id $model` instead of the command above.
+**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.2-rocm --model-id $model` instead of the command above.

 To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
 ```
@@ -152,7 +152,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
 token=<your cli READ token>

 docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data \
-    ghcr.io/huggingface/text-generation-inference:3.3.1 --model-id $model
+    ghcr.io/huggingface/text-generation-inference:3.3.2 --model-id $model
 ```

 ### A note on Shared Memory (shm)
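For context on the request flow the README hunks above refer to (the `curl localhost:8080/v1/chat/completions` example is unchanged apart from the image tag), here is a minimal sketch of the same call made from Python. It assumes a TGI container is already listening on port 8080 with its OpenAI-compatible chat-completions route; the prompt and the `"model"` value are illustrative placeholders.

```python
# Minimal sketch: query a locally running TGI container through its
# OpenAI-compatible /v1/chat/completions endpoint (assumes port 8080).
import json
import urllib.request

payload = {
    "model": "tgi",  # placeholder; the server routes to the single model it serves
    "messages": [{"role": "user", "content": "What is deep learning?"}],
    "max_tokens": 64,
    "stream": False,
}

req = urllib.request.Request(
    "http://localhost:8080/v1/chat/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)

with urllib.request.urlopen(req) as resp:
    body = json.loads(resp.read())
    # Response follows the standard chat-completions schema.
    print(body["choices"][0]["message"]["content"])
```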
250  backends/gaudi/server/poetry.lock (generated)
@@ -1058,199 +1058,6 @@ files = [
     {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"},
 ]
-
-[[package]]
-name = "nvidia-cublas-cu12"
-version = "12.4.5.8"
-description = "CUBLAS native runtime libraries"
-optional = false
-python-versions = ">=3"
-groups = ["main"]
-markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-files = [
-    {file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0f8aa1706812e00b9f19dfe0cdb3999b092ccb8ca168c0db5b8ea712456fd9b3"},
-    {file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b"},
-    {file = "nvidia_cublas_cu12-12.4.5.8-py3-none-win_amd64.whl", hash = "sha256:5a796786da89203a0657eda402bcdcec6180254a8ac22d72213abc42069522dc"},
-]
-
-[[package]]
-name = "nvidia-cuda-cupti-cu12"
-version = "12.4.127"
-description = "CUDA profiling tools runtime libs."
-optional = false
-python-versions = ">=3"
-groups = ["main"]
-markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-files = [
-    {file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:79279b35cf6f91da114182a5ce1864997fd52294a87a16179ce275773799458a"},
-    {file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb"},
-    {file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:5688d203301ab051449a2b1cb6690fbe90d2b372f411521c86018b950f3d7922"},
-]
-
-[[package]]
-name = "nvidia-cuda-nvrtc-cu12"
-version = "12.4.127"
-description = "NVRTC native runtime libraries"
-optional = false
-python-versions = ">=3"
-groups = ["main"]
-markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-files = [
-    {file = "nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0eedf14185e04b76aa05b1fea04133e59f465b6f960c0cbf4e37c3cb6b0ea198"},
-    {file = "nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a178759ebb095827bd30ef56598ec182b85547f1508941a3d560eb7ea1fbf338"},
-    {file = "nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:a961b2f1d5f17b14867c619ceb99ef6fcec12e46612711bcec78eb05068a60ec"},
-]
-
-[[package]]
-name = "nvidia-cuda-runtime-cu12"
-version = "12.4.127"
-description = "CUDA Runtime native Libraries"
-optional = false
-python-versions = ">=3"
-groups = ["main"]
-markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-files = [
-    {file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:961fe0e2e716a2a1d967aab7caee97512f71767f852f67432d572e36cb3a11f3"},
-    {file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5"},
-    {file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:09c2e35f48359752dfa822c09918211844a3d93c100a715d79b59591130c5e1e"},
-]
-
-[[package]]
-name = "nvidia-cudnn-cu12"
-version = "9.1.0.70"
-description = "cuDNN runtime libraries"
-optional = false
-python-versions = ">=3"
-groups = ["main"]
-markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-files = [
-    {file = "nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f"},
-    {file = "nvidia_cudnn_cu12-9.1.0.70-py3-none-win_amd64.whl", hash = "sha256:6278562929433d68365a07a4a1546c237ba2849852c0d4b2262a486e805b977a"},
-]
-
-[package.dependencies]
-nvidia-cublas-cu12 = "*"
-
-[[package]]
-name = "nvidia-cufft-cu12"
-version = "11.2.1.3"
-description = "CUFFT native runtime libraries"
-optional = false
-python-versions = ">=3"
-groups = ["main"]
-markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-files = [
-    {file = "nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5dad8008fc7f92f5ddfa2101430917ce2ffacd86824914c82e28990ad7f00399"},
-    {file = "nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9"},
-    {file = "nvidia_cufft_cu12-11.2.1.3-py3-none-win_amd64.whl", hash = "sha256:d802f4954291101186078ccbe22fc285a902136f974d369540fd4a5333d1440b"},
-]
-
-[package.dependencies]
-nvidia-nvjitlink-cu12 = "*"
-
-[[package]]
-name = "nvidia-curand-cu12"
-version = "10.3.5.147"
-description = "CURAND native runtime libraries"
-optional = false
-python-versions = ">=3"
-groups = ["main"]
-markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-files = [
-    {file = "nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1f173f09e3e3c76ab084aba0de819c49e56614feae5c12f69883f4ae9bb5fad9"},
-    {file = "nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a88f583d4e0bb643c49743469964103aa59f7f708d862c3ddb0fc07f851e3b8b"},
-    {file = "nvidia_curand_cu12-10.3.5.147-py3-none-win_amd64.whl", hash = "sha256:f307cc191f96efe9e8f05a87096abc20d08845a841889ef78cb06924437f6771"},
-]
-
-[[package]]
-name = "nvidia-cusolver-cu12"
-version = "11.6.1.9"
-description = "CUDA solver native runtime libraries"
-optional = false
-python-versions = ">=3"
-groups = ["main"]
-markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-files = [
-    {file = "nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d338f155f174f90724bbde3758b7ac375a70ce8e706d70b018dd3375545fc84e"},
-    {file = "nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260"},
-    {file = "nvidia_cusolver_cu12-11.6.1.9-py3-none-win_amd64.whl", hash = "sha256:e77314c9d7b694fcebc84f58989f3aa4fb4cb442f12ca1a9bde50f5e8f6d1b9c"},
-]
-
-[package.dependencies]
-nvidia-cublas-cu12 = "*"
-nvidia-cusparse-cu12 = "*"
-nvidia-nvjitlink-cu12 = "*"
-
-[[package]]
-name = "nvidia-cusparse-cu12"
-version = "12.3.1.170"
-description = "CUSPARSE native runtime libraries"
-optional = false
-python-versions = ">=3"
-groups = ["main"]
-markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-files = [
-    {file = "nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9d32f62896231ebe0480efd8a7f702e143c98cfaa0e8a76df3386c1ba2b54df3"},
-    {file = "nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1"},
-    {file = "nvidia_cusparse_cu12-12.3.1.170-py3-none-win_amd64.whl", hash = "sha256:9bc90fb087bc7b4c15641521f31c0371e9a612fc2ba12c338d3ae032e6b6797f"},
-]
-
-[package.dependencies]
-nvidia-nvjitlink-cu12 = "*"
-
-[[package]]
-name = "nvidia-cusparselt-cu12"
-version = "0.6.2"
-description = "NVIDIA cuSPARSELt"
-optional = false
-python-versions = "*"
-groups = ["main"]
-markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-files = [
-    {file = "nvidia_cusparselt_cu12-0.6.2-py3-none-manylinux2014_aarch64.whl", hash = "sha256:067a7f6d03ea0d4841c85f0c6f1991c5dda98211f6302cb83a4ab234ee95bef8"},
-    {file = "nvidia_cusparselt_cu12-0.6.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:df2c24502fd76ebafe7457dbc4716b2fec071aabaed4fb7691a201cde03704d9"},
-    {file = "nvidia_cusparselt_cu12-0.6.2-py3-none-win_amd64.whl", hash = "sha256:0057c91d230703924c0422feabe4ce768841f9b4b44d28586b6f6d2eb86fbe70"},
-]
-
-[[package]]
-name = "nvidia-nccl-cu12"
-version = "2.21.5"
-description = "NVIDIA Collective Communication Library (NCCL) Runtime"
-optional = false
-python-versions = ">=3"
-groups = ["main"]
-markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-files = [
-    {file = "nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0"},
-]
-
-[[package]]
-name = "nvidia-nvjitlink-cu12"
-version = "12.4.127"
-description = "Nvidia JIT LTO Library"
-optional = false
-python-versions = ">=3"
-groups = ["main"]
-markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-files = [
-    {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83"},
-    {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"},
-    {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1"},
-]
-
-[[package]]
-name = "nvidia-nvtx-cu12"
-version = "12.4.127"
-description = "NVIDIA Tools Extension"
-optional = false
-python-versions = ">=3"
-groups = ["main"]
-markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-files = [
-    {file = "nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7959ad635db13edf4fc65c06a6e9f9e55fc2f92596db928d169c0bb031e88ef3"},
-    {file = "nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a"},
-    {file = "nvidia_nvtx_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:641dccaaa1139f3ffb0d3164b4b84f9d253397e38246a4f2f36728b48566d485"},
-]
-
 [[package]]
 name = "opentelemetry-api"
 version = "1.32.0"
@@ -2650,63 +2457,6 @@ files = [
     {file = "tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff"},
 ]
-
-[[package]]
-name = "torch"
-version = "2.6.0"
-description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
-optional = false
-python-versions = ">=3.9.0"
-groups = ["main"]
-files = [
-    {file = "torch-2.6.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:6860df13d9911ac158f4c44031609700e1eba07916fff62e21e6ffa0a9e01961"},
-    {file = "torch-2.6.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:c4f103a49830ce4c7561ef4434cc7926e5a5fe4e5eb100c19ab36ea1e2b634ab"},
-    {file = "torch-2.6.0-cp310-cp310-win_amd64.whl", hash = "sha256:56eeaf2ecac90da5d9e35f7f35eb286da82673ec3c582e310a8d1631a1c02341"},
-    {file = "torch-2.6.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:09e06f9949e1a0518c5b09fe95295bc9661f219d9ecb6f9893e5123e10696628"},
-    {file = "torch-2.6.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:7979834102cd5b7a43cc64e87f2f3b14bd0e1458f06e9f88ffa386d07c7446e1"},
-    {file = "torch-2.6.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:ccbd0320411fe1a3b3fec7b4d3185aa7d0c52adac94480ab024b5c8f74a0bf1d"},
-    {file = "torch-2.6.0-cp311-cp311-win_amd64.whl", hash = "sha256:46763dcb051180ce1ed23d1891d9b1598e07d051ce4c9d14307029809c4d64f7"},
-    {file = "torch-2.6.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:94fc63b3b4bedd327af588696559f68c264440e2503cc9e6954019473d74ae21"},
-    {file = "torch-2.6.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:2bb8987f3bb1ef2675897034402373ddfc8f5ef0e156e2d8cfc47cacafdda4a9"},
-    {file = "torch-2.6.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:b789069020c5588c70d5c2158ac0aa23fd24a028f34a8b4fcb8fcb4d7efcf5fb"},
-    {file = "torch-2.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:7e1448426d0ba3620408218b50aa6ada88aeae34f7a239ba5431f6c8774b1239"},
-    {file = "torch-2.6.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:9a610afe216a85a8b9bc9f8365ed561535c93e804c2a317ef7fabcc5deda0989"},
-    {file = "torch-2.6.0-cp313-cp313-manylinux1_x86_64.whl", hash = "sha256:4874a73507a300a5d089ceaff616a569e7bb7c613c56f37f63ec3ffac65259cf"},
-    {file = "torch-2.6.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:a0d5e1b9874c1a6c25556840ab8920569a7a4137afa8a63a32cee0bc7d89bd4b"},
-    {file = "torch-2.6.0-cp313-cp313-win_amd64.whl", hash = "sha256:510c73251bee9ba02ae1cb6c9d4ee0907b3ce6020e62784e2d7598e0cfa4d6cc"},
-    {file = "torch-2.6.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:ff96f4038f8af9f7ec4231710ed4549da1bdebad95923953a25045dcf6fd87e2"},
-    {file = "torch-2.6.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:9ea955317cfcd3852b1402b62af258ce735c2edeee42ca9419b6bc889e5ae053"},
-    {file = "torch-2.6.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:bb2c6c3e65049f081940f5ab15c9136c7de40d3f01192541c920a07c7c585b7e"},
-    {file = "torch-2.6.0-cp39-cp39-win_amd64.whl", hash = "sha256:683410f97984103148e31b38a8631acf31c3034c020c0f4d26171e7626d8317a"},
-    {file = "torch-2.6.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:265f70de5fd45b864d924b64be1797f86e76c8e48a02c2a3a6fc7ec247d2226c"},
-]
-
-[package.dependencies]
-filelock = "*"
-fsspec = "*"
-jinja2 = "*"
-networkx = "*"
-nvidia-cublas-cu12 = {version = "12.4.5.8", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-nvidia-cuda-cupti-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-nvidia-cuda-nvrtc-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-nvidia-cuda-runtime-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-nvidia-cudnn-cu12 = {version = "9.1.0.70", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-nvidia-cufft-cu12 = {version = "11.2.1.3", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-nvidia-curand-cu12 = {version = "10.3.5.147", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-nvidia-cusolver-cu12 = {version = "11.6.1.9", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-nvidia-cusparse-cu12 = {version = "12.3.1.170", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-nvidia-cusparselt-cu12 = {version = "0.6.2", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-nvidia-nccl-cu12 = {version = "2.21.5", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-nvidia-nvjitlink-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-nvidia-nvtx-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-setuptools = {version = "*", markers = "python_version >= \"3.12\""}
-sympy = {version = "1.13.1", markers = "python_version >= \"3.9\""}
-triton = {version = "3.2.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-typing-extensions = ">=4.10.0"
-
-[package.extras]
-opt-einsum = ["opt-einsum (>=3.3)"]
-optree = ["optree (>=0.13.0)"]
-
 [[package]]
 name = "tqdm"
 version = "4.67.1"
@ -22,10 +22,9 @@ opentelemetry-instrumentation-grpc = "^0.53b0"
|
||||
hf-transfer = "^0.1.9"
|
||||
sentencepiece = "^0.2.0"
|
||||
peft = "^0.15"
|
||||
optimum-habana = "1.17"
|
||||
transformers = "^4.49"
|
||||
transformers = "^4.52.4"
|
||||
numpy = "^1.26"
|
||||
accelerate = "^0.33"
|
||||
accelerate = "^1.7.0"
|
||||
outlines= { version = "^0.0.36", optional = true }
|
||||
prometheus-client = "^0.21.1"
|
||||
py-cpuinfo = "^9.0.0"
|
||||
|
@@ -1,4 +1,4 @@
-accelerate==0.33.0 ; python_version >= "3.9" and python_version < "3.13"
+accelerate==1.7.0 ; python_version >= "3.9" and python_version < "3.13"
 annotated-types==0.7.0 ; python_version >= "3.9" and python_version < "3.13"
 attrs==25.3.0 ; python_version >= "3.9" and python_version < "3.13"
 certifi==2025.1.31 ; python_version >= "3.9" and python_version < "3.13"
@@ -36,19 +36,6 @@ nest-asyncio==1.6.0 ; python_version >= "3.9" and python_version < "3.13"
 networkx==3.2.1 ; python_version >= "3.9" and python_version < "3.13"
 numba==0.60.0 ; python_version >= "3.9" and python_version < "3.13"
 numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
-nvidia-cublas-cu12==12.4.5.8 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
-nvidia-cuda-cupti-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
-nvidia-cuda-nvrtc-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
-nvidia-cuda-runtime-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
-nvidia-cudnn-cu12==9.1.0.70 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
-nvidia-cufft-cu12==11.2.1.3 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
-nvidia-curand-cu12==10.3.5.147 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
-nvidia-cusolver-cu12==11.6.1.9 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
-nvidia-cusparse-cu12==12.3.1.170 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
-nvidia-cusparselt-cu12==0.6.2 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
-nvidia-nccl-cu12==2.21.5 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
-nvidia-nvjitlink-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
-nvidia-nvtx-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
 opentelemetry-api==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-exporter-otlp-proto-common==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-exporter-otlp-proto-grpc==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -59,7 +46,6 @@ opentelemetry-instrumentation==0.53b0 ; python_version >= "3.9" and python_versi
 opentelemetry-proto==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-sdk==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-semantic-conventions==0.53b0 ; python_version >= "3.9" and python_version < "3.13"
-optimum-habana==1.17.0 ; python_version >= "3.9" and python_version < "3.13"
 optimum==1.24.0 ; python_version >= "3.9" and python_version < "3.13"
 outlines==0.0.36 ; python_version >= "3.9" and python_version < "3.13"
 packaging==24.2 ; python_version >= "3.9" and python_version < "3.13"
@@ -88,9 +74,8 @@ shellingham==1.5.4 ; python_version >= "3.9" and python_version < "3.13"
 sympy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
 threadpoolctl==3.6.0 ; python_version >= "3.9" and python_version < "3.13"
 tokenizers==0.21.1 ; python_version >= "3.9" and python_version < "3.13"
-torch==2.6.0 ; python_version >= "3.9" and python_version < "3.13"
 tqdm==4.67.1 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.49.0 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.52.4 ; python_version >= "3.9" and python_version < "3.13"
 triton==3.2.0 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
 typer==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
 typing-extensions==4.13.2 ; python_version >= "3.9" and python_version < "3.13"
@@ -1,6 +1,4 @@
 import os
-import psutil
-import signal
 import sys
 import typer

@@ -115,80 +113,19 @@ def serve(
         raise RuntimeError(
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
         )
-
-    logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype))
-
-    if sharded and os.getenv("ATTENTION", "default") not in {"paged"}:
-        tgi_file = Path(__file__).resolve().parent / "tgi_service.py"
-        num_shard = int(os.getenv("WORLD_SIZE", "1"))
-        logger.info("CLI SHARDED = {}".format(num_shard))
-        import subprocess
-
-        cmd = (
-            f"deepspeed --num_nodes 1 --num_gpus {num_shard} --no_local_rank {tgi_file}"
-        )
-        cmd += f" --model_id {model_id} --revision {revision} --sharded {sharded}"
-        cmd += f" --dtype {dtype} --trust_remote_code {trust_remote_code} --uds_path {uds_path}"
-        cmd += f" --quantize {quantize} --max_input_tokens {max_input_tokens}"
-        if speculate is not None:
-            cmd += f"--speculate {speculate}"
-        logger.info("CLI server start deepspeed ={} ".format(cmd))
-        sys.stdout.flush()
-        sys.stderr.flush()
-        with subprocess.Popen(cmd, shell=True, executable="/bin/bash") as proc:
-            do_terminate = False
-            current_handler = signal.getsignal(signal.SIGTERM)
-
-            def terminate_handler(sig, frame):
-                nonlocal do_terminate
-                do_terminate = True
-                if callable(current_handler):
-                    current_handler(sig, frame)
-
-            signal.signal(signal.SIGTERM, terminate_handler)
-
-            finished = False
-            while not finished:
-                try:
-                    if do_terminate:
-                        parent = psutil.Process(proc.pid)
-                        all_procs = parent.children(recursive=True) + [parent]
-                        for p in all_procs:
-                            try:
-                                p.terminate()
-                            except psutil.NoSuchProcess:
-                                pass
-                        _, alive = psutil.wait_procs(all_procs, timeout=30)
-                        for p in alive:
-                            p.kill()
-
-                        do_terminate = False
-
-                    proc.wait(timeout=3)
-                except subprocess.TimeoutExpired:
-                    pass
-                else:
-                    finished = True
-
-            sys.stdout.flush()
-            sys.stderr.flush()
-            if proc.returncode != 0:
-                logger.error(f"{cmd} exited with status = {proc.returncode}")
-            return proc.returncode
-    else:
-        server.serve(
-            model_id,
-            lora_adapters,
-            revision,
-            sharded,
-            quantize,
-            speculate,
-            dtype,
-            kv_cache_dtype,
-            trust_remote_code,
-            uds_path,
-            max_input_tokens,
-        )
+    server.serve(
+        model_id,
+        lora_adapters,
+        revision,
+        sharded,
+        quantize,
+        speculate,
+        dtype,
+        kv_cache_dtype,
+        trust_remote_code,
+        uds_path,
+        max_input_tokens,
+    )


 @app.command()
@@ -1,53 +0,0 @@
-# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
-
-import os
-import habana_frameworks.torch as htorch
-
-quant_config = os.getenv("QUANT_CONFIG", "")
-is_quantization_enabled = quant_config != ""
-
-if is_quantization_enabled:
-    os.environ.setdefault("ENABLE_EXPERIMENTAL_FLAGS", "true")
-    os.environ.setdefault("USE_DEFAULT_QUANT_PARAM", "true")
-    os.environ.setdefault("UPDATE_GRAPH_OUTPUT_MME", "false")
-    os.environ.setdefault("ENABLE_CALC_DYNAMIC_RANGE", "false")
-    os.environ.setdefault("UPDATE_MME_OUTPUT_PRECISION_FILTER", "v_proj,matmul_av")
-    os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE")
-
-
-def patch_scoped_linear_all_reduce(model):
-    from deepspeed.module_inject.layers import LinearAllreduce
-    from optimum.habana.transformers.models.modeling_all_models import (
-        ScopedLinearAllReduce,
-    )
-
-    for name, module in model.named_children():
-        if type(module) is LinearAllreduce:
-            SL = ScopedLinearAllReduce(mod=module)
-            setattr(model, name, SL)
-        patch_scoped_linear_all_reduce(module)
-
-
-def setup_quantization(model):
-    if is_quantization_enabled:
-        htorch.core.quantization._mark_params_as_const(model)
-        htorch.core.quantization._check_params_as_const(model)
-        htorch.core.hpu_initialize(model)
-    return model
-
-
-def prepare_model_for_quantization(model):
-    if is_quantization_enabled:
-        if model.config.model_type in [
-            "llama",
-            "falcon",
-            "qwen2",
-            "starcoder2",
-            "gemma",
-        ]:
-            patch_scoped_linear_all_reduce(model)
-        from neural_compressor.torch.quantization import FP8Config, convert
-
-        config = FP8Config.from_json_file(quant_config)
-        model = convert(model, config)
-    return model
@@ -11,6 +11,7 @@ from .hpu import (
     attention,
     paged_attention,
     paged_attention_mla,
+    set_block_mapping,
 )


@@ -22,6 +23,7 @@ __all__ = [
     "get_kv_scales",
     "paged_attention",
     "paged_attention_mla",
+    "set_block_mapping",
     "SUPPORTS_WINDOWING",
     "KVCache",
     "KVCompressCache",
@@ -8,6 +8,7 @@ from habana_frameworks.torch.hpex.kernels import FusedSDPA
 from vllm_hpu_extension.utils import ModuleFusedSDPA
 import os
+from text_generation_server.models.globals import BLOCK_SIZE
 import math

 SUPPORTS_WINDOWING = False

@@ -106,6 +107,21 @@ def attention(
     return attn_output


+def set_block_mapping(hpu_attention_meta: HPUPagedAttentionMetadata, batch_size):
+    block_mapping = torch.nn.functional.one_hot(
+        hpu_attention_meta.block_groups, num_classes=batch_size
+    )
+    dtype = hpu_attention_meta.block_usage.dtype
+    device = hpu_attention_meta.block_usage.device
+    mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
+    mask = mask >= hpu_attention_meta.block_usage.unsqueeze(-1)
+    attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
+    hpu_attention_meta = hpu_attention_meta._replace(
+        attn_bias=attn_bias, block_mapping=block_mapping.to(dtype)
+    )
+    return hpu_attention_meta
+
+
 def paged_attention(
     query: torch.Tensor,
     kv_cache: KVCache,
@@ -176,4 +192,10 @@ def paged_attention_mla(
     return output.view(batch_size, head_num, -1)


-__all__ = ["SUPPORTS_WINDOWING", "attention", "paged_attention", "paged_attention_mla"]
+__all__ = [
+    "SUPPORTS_WINDOWING",
+    "attention",
+    "paged_attention",
+    "paged_attention_mla",
+    "set_block_mapping",
+]
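As an aside, the new `set_block_mapping` helper is easiest to follow on toy tensors: `block_groups` maps each KV-cache block to the batch entry that owns it, and `block_usage` says how many slots of each block are filled, so the one-hot expansion plus the `arange >= usage` comparison yields a per-block bias that masks out unused slots. The sketch below reproduces that logic standalone with plain PyTorch; the stand-in namedtuple, the assumed `BLOCK_SIZE` of 4, and the float32 dtype are illustrative choices, not TGI's actual `HPUPagedAttentionMetadata`.

```python
# Standalone sketch of the block-mapping / attention-bias computation above.
# Assumptions: BLOCK_SIZE=4 and a minimal namedtuple standing in for
# HPUPagedAttentionMetadata; field names mirror the diff but are illustrative.
import math
from collections import namedtuple

import torch

BLOCK_SIZE = 4
Meta = namedtuple("Meta", ["block_groups", "block_usage", "block_mapping", "attn_bias"])


def set_block_mapping(meta, batch_size):
    # One row per cache block, one column per sequence in the batch.
    block_mapping = torch.nn.functional.one_hot(meta.block_groups, num_classes=batch_size)
    dtype, device = torch.float32, meta.block_usage.device
    # Slot positions >= block_usage are padding inside the block; bias them to -inf.
    mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
    mask = mask >= meta.block_usage.unsqueeze(-1)
    attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
    return meta._replace(attn_bias=attn_bias, block_mapping=block_mapping.to(dtype))


# Three cache blocks: blocks 0 and 1 belong to sequence 0, block 2 to sequence 1.
meta = Meta(
    block_groups=torch.tensor([0, 0, 1]),
    block_usage=torch.tensor([4, 2, 3]),  # block 1 is only half full
    block_mapping=None,
    attn_bias=None,
)
meta = set_block_mapping(meta, batch_size=2)
print(meta.block_mapping)  # shape (3, 2): which sequence owns each block
print(meta.attn_bias)      # shape (3, 4): 0 for valid slots, -inf for padding
```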
@@ -36,9 +36,7 @@ class PositionRotaryEmbedding(nn.Module):
         self._sin_k_cached = None
         self.scaling_factor = scaling_factor
         self.dynamic_args = None
-        self._update_cos_sin_cache(
-            torch.float32, inv_freq.device, max_position_embeddings
-        )
+        self.max_position_embeddings = max_position_embeddings

     def forward(
         self,
@@ -270,7 +268,9 @@ class PositionRotaryEmbedding(nn.Module):
         self._sin_cached = torch.sin(freqs).to(dtype)

     def get_cos_sin(self, position_ids: torch.Tensor):
-
+        self._update_cos_sin_cache(
+            torch.float32, position_ids.device, seqlen=self.max_position_embeddings
+        )
         cos = torch.index_select(self._cos_cached, 0, position_ids)
         sin = torch.index_select(self._sin_cached, 0, position_ids)

@@ -298,9 +298,6 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
         self._cos_k_cached = None
         self._sin_k_cached = None
         self.dynamic_args = None
-        self._update_cos_sin_cache(
-            torch.float32, short_inv_freq.device, max_position_embeddings
-        )

     def _update_cos_sin_cache(self, dtype, device, seqlen):
         # Reset the tables if the sequence length has changed,
@@ -354,9 +351,6 @@ class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
         self._cos_k_cached = None
         self._sin_k_cached = None
         self.dynamic_args = None
-        self._update_cos_sin_cache(
-            torch.float32, short_inv_freq.device, max_position_embeddings
-        )

     def _update_cos_sin_cache(self, dtype, device, seqlen):
         if (
@@ -598,6 +592,9 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
         position_ids: torch.Tensor,
     ):
         slen = position_ids.shape[0]
+        self._update_cos_sin_cache(
+            torch.float32, position_ids.device, seqlen=self.max_position_embeddings
+        )

         cos = self._cos_cached[position_ids].gather(1, self._sections[:slen])
         sin = self._sin_cached[position_ids].gather(1, self._sections[:slen])
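The net effect of these rotary-embedding hunks is that the cos/sin tables are no longer built eagerly in every constructor; instead each `get_cos_sin` call first makes sure the cache covers `max_position_embeddings` on the right device, then gathers the rows it needs. Below is a compact sketch of that lazily-refreshed lookup pattern; the class, the head dimension of 8, and the base of 10000.0 are assumptions for illustration and omit TGI's scaling variants.

```python
# Simplified sketch of a lazily-refreshed rotary cos/sin cache.
import torch


class TinyRotaryCache:
    def __init__(self, dim: int = 8, base: float = 10000.0, max_position_embeddings: int = 32):
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.max_position_embeddings = max_position_embeddings
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        # Rebuild only if the table is missing, too short, or on the wrong device.
        if self._cos_cached is None or seqlen > self._seq_len_cached or self._cos_cached.device != device:
            t = torch.arange(seqlen, device=device, dtype=torch.float32)
            freqs = torch.outer(t, self.inv_freq.to(device))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)
            self._seq_len_cached = seqlen

    def get_cos_sin(self, position_ids: torch.Tensor):
        # Ensure the table covers the full context, then gather the requested rows.
        self._update_cos_sin_cache(torch.float32, position_ids.device, seqlen=self.max_position_embeddings)
        cos = torch.index_select(self._cos_cached, 0, position_ids)
        sin = torch.index_select(self._sin_cached, 0, position_ids)
        return cos, sin


rope = TinyRotaryCache()
cos, sin = rope.get_cos_sin(torch.tensor([0, 1, 2, 5]))
print(cos.shape, sin.shape)  # torch.Size([4, 4]) each: positions x dim/2
```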
@@ -5,7 +5,6 @@ import os

 from loguru import logger
 from transformers.configuration_utils import PretrainedConfig
-from transformers.models.auto import modeling_auto
 from huggingface_hub import hf_hub_download, HfApi
 from typing import Optional
 from pathlib import Path
@@ -36,14 +35,10 @@ __all__ = [
     "Seq2SeqLM",
     "get_model_with_lora_adapters",
 ]
-from text_generation_server.models.globals import ATTENTION
-
 VLM_BATCH_TYPES = set()
 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

-FLASH_ATTENTION = False
-if ATTENTION == "paged":
-    FLASH_ATTENTION = True
+FLASH_ATTENTION = True

 try:
     from text_generation_server.models.flash_causal_lm import FlashCausalLM
@@ -83,9 +78,6 @@ try:
     from text_generation_server.models.custom_modeling.flash_neox_modeling import (
         FlashGPTNeoXForCausalLM,
     )
-    from text_generation_server.models.pali_gemma import (
-        PaliGemmaBatch,
-    )
     from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
         PaliGemmaForConditionalGeneration,
     )
@@ -156,7 +148,6 @@ if FLASH_ATTENTION:
     )

     VLM_BATCH_TYPES = {
-        PaliGemmaBatch,
         FlashVlmCausalLMBatch,
         FlashMllamaCausalLMBatch,
     }
@@ -642,6 +633,7 @@ def get_model(
             default_dtype=torch.bfloat16,
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,
+            support_chunking=False,
         )
     elif model_type == BAICHUAN:
         return FlashCausalLM(
@@ -791,6 +783,8 @@ def get_model(
             kv_cache_dtype=kv_cache_dtype,
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,
+            # TODO: Fix bug in rust image_text_replacement implementation
+            support_chunking=False,
         )
     elif model_type == QWEN2_5_VL:
         return FlashVlmCausalLM(
@@ -806,6 +800,8 @@ def get_model(
             lora_adapter_ids=lora_adapter_ids,
             config_class=Qwen2_5_VLConfig,
             processor_class=Qwen2_5_VLProcessor,
+            # TODO: Fix bug in rust image_text_replacement implementation
+            support_chunking=False,
         )
     elif model_type == QWEN3:
         return FlashCausalLM(
@@ -843,6 +839,7 @@ def get_model(
             default_dtype=torch.bfloat16,
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,
+            support_chunking=False,
         )
     elif model_type == IDEFICS2:
         return FlashVlmCausalLM(
@@ -887,7 +884,6 @@ def get_model(
             default_dtype=torch.bfloat16,
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,
-            batch_class=PaliGemmaBatch,
         )
     elif model_type == LLAVA_NEXT:
         return FlashVlmCausalLM(
@@ -901,72 +897,6 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )

-    from text_generation_server.models.causal_lm import CausalLM
-    from text_generation_server.models.vlm_causal_lm import VlmCausalLM
-    from text_generation_server.models.custom_modeling.mllama import (
-        MllamaForConditionalGeneration,
-    )
-    from text_generation_server.models.custom_modeling.llava_next import (
-        LlavaNextForConditionalGeneration,
-    )
-    from text_generation_server.models.vlm_causal_lm import (
-        VlmCausalLMBatch,
-    )
-
-    VLM_BATCH_TYPES.add(VlmCausalLMBatch)
-
-    from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
-
-    adapt_transformers_to_gaudi()
-    if SDP_ON_BF16 == 1:
-        torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
-    if model_type == "gpt_bigcode":
-        from text_generation_server.models.starcoder import StarCoder
-
-        return StarCoder(model_id=model_id, revision=revision, dtype=dtype)
-    if model_type == "bloom":
-        from text_generation_server.models.bloom import BLOOM
-
-        return BLOOM(
-            model_id=model_id,
-            revision=revision,
-            speculator=speculator,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
-
-    if model_type == "llava_next":
-        return VlmCausalLM(
-            model_class=LlavaNextForConditionalGeneration,
-            model_id=model_id,
-            revision=revision,
-            quantize=None,
-            speculator=speculator,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
-
-    if model_type == "mllama":
-        return VlmCausalLM(
-            model_class=MllamaForConditionalGeneration,
-            model_id=model_id,
-            revision=revision,
-            quantize=None,
-            speculator=speculator,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
-
-    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-        return CausalLM(
-            model_id,
-            revision,
-            quantize=quantize,
-            speculator=speculator,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
-
     raise ValueError(f"Unsupported model type {model_type}")

@@ -1,52 +0,0 @@
-# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
-
-import torch
-
-from typing import Optional, Type
-
-from transformers import PreTrainedTokenizerBase
-
-from text_generation_server.models import CausalLM
-from text_generation_server.models.causal_lm import CausalLMBatch
-from text_generation_server.pb import generate_pb2
-
-
-class BloomCausalLMBatch(CausalLMBatch):
-    @classmethod
-    def from_pb(
-        cls,
-        pb: generate_pb2.Batch,
-        tokenizer: PreTrainedTokenizerBase,
-        dtype: torch.dtype,
-        device: torch.device,
-    ) -> "CausalLMBatch":
-        batch = super().from_pb(
-            pb=pb,
-            tokenizer=tokenizer,
-            dtype=dtype,
-            device=device,
-        )
-        batch.keys_head_dim_last = False
-        return batch
-
-
-class BLOOM(CausalLM):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-        super(BLOOM, self).__init__(
-            model_id=model_id,
-            revision=revision,
-            speculator=speculator,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
-
-    @property
-    def batch_type(self) -> Type[CausalLMBatch]:
-        return BloomCausalLMBatch
(File diff suppressed because it is too large.)
@@ -28,6 +28,7 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -415,6 +416,10 @@ class FlashCohereModel(torch.nn.Module):
         seqlen: torch.Tensor,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, input_ids.shape[0]
+            )
         hidden_states = self.embed_tokens(input_ids)

         # Get rotary cos and sin for this forward
@@ -26,6 +26,7 @@ from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -678,6 +679,10 @@ class DbrxModel(torch.nn.Module):
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, input_ids.shape[0]
+            )
         hidden_states = self.embed_tokens(input_ids)

         # Get rotary cos and sin for this forward
@@ -33,6 +33,7 @@ from text_generation_server.layers.attention import (
     Seqlen,
     attention,
     paged_attention,
+    set_block_mapping,
     HPUPagedAttentionMetadata,
 )
 from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
@@ -569,6 +570,10 @@ class DeepseekV2Model(torch.nn.Module):
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, input_ids.shape[0]
+            )
         hidden_states = self.embed_tokens(input_ids)

         # Get rotary cos and sin for this forward
@@ -34,6 +34,7 @@ from text_generation_server.layers.attention import (
     Seqlen,
     attention,
     paged_attention_mla,
+    set_block_mapping,
     HPUPagedAttentionMetadata,
 )
 from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
@@ -645,6 +646,10 @@ class DeepseekV3Model(torch.nn.Module):
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, input_ids.shape[0]
+            )
         hidden_states = self.embed_tokens(input_ids)

         # Get rotary cos and sin for this forward
@@ -28,6 +28,7 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -466,6 +467,10 @@ class FlashGemma2Model(torch.nn.Module):
         adapter_data: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, inputs_embeds.shape[0]
+            )
         hidden_states = inputs_embeds

         # Get rotary cos and sin for this forward
@@ -28,6 +28,7 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -388,6 +389,10 @@ class FlashGemmaModel(torch.nn.Module):
         adapter_data: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, inputs_embeds.shape[0]
+            )
         hidden_states = inputs_embeds

         # Get rotary cos and sin for this forward
@@ -27,6 +27,7 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -383,6 +384,10 @@ class FlashGPT2Model(torch.nn.Module):
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, inputs_embeds.shape[0]
+            )
         hidden_states = inputs_embeds

         residual = None
@@ -28,6 +28,7 @@ from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -324,6 +325,10 @@ class FlashGPTJModel(torch.nn.Module):
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, input_ids.shape[0]
+            )
         hidden_states = self.wte(input_ids)

         # Get rotary cos and sin for this forward
@@ -43,12 +43,12 @@ from text_generation_server.layers.layernorm import FastRMSNorm
 from text_generation_server.layers.attention import (
     KVCache,
     paged_attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
 from text_generation_server.models.custom_modeling.flash_llama_modeling import (
     FlashLlamaAttention,
-    LlamaMLP,
 )


@@ -444,7 +444,7 @@ class Llama4TextDecoderLayer(nn.Module):
         if self.is_moe_layer:  # the 128E model interleaves dense / sparse
             self.feed_forward = Llama4TextMoe(f"{prefix}.feed_forward", config, weights)
         else:
-            self.feed_forward = LlamaMLP(f"{prefix}.feed_forward", config, weights)
+            self.feed_forward = Llama4TextMLP(f"{prefix}.feed_forward", config, weights)

         self.input_layernorm = FastRMSNorm.load(
             prefix=f"{prefix}.input_layernorm",
@@ -549,6 +549,10 @@ class Llama4TextModel(nn.Module):
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, inputs_embeds.shape[0]
+            )

         hidden_states = inputs_embeds
         bs = seqlen.input_lengths.shape[0]
@@ -1352,55 +1356,36 @@ class Llama4ForConditionalGeneration(nn.Module):
         hidden_state = self.vision_model(pixel_values)
         return hidden_state

-    def forward(
+    def get_vision_embeds(
         self,
-        input_ids: torch.LongTensor = None,
+        pixel_values: torch.FloatTensor,
+        pixel_attention_mask: Optional[torch.FloatTensor] = None,
+        image_sizes: Optional[torch.Tensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+    ):
+        image_features = self.get_image_features(
+            pixel_values=pixel_values,
+            vision_feature_layer=self.config.vision_config.vision_feature_layer,
+            vision_feature_select_strategy=self.config.vision_config.vision_feature_select_strategy,
+            image_sizes=image_sizes,
+        )
+        vision_flat = image_features.view(-1, image_features.size(-1))
+        image_features = self.multi_modal_projector(vision_flat)
+        return image_features
+
+    def get_inputs_embeds(
+        self,
+        input_ids: torch.Tensor,
+        vision_embeds: torch.Tensor = None,
-        pixel_values: torch.FloatTensor = None,
-        pixel_attention_mask=None,
-        position_ids: Optional[torch.LongTensor] = None,
-        cu_seqlen_prefill: Optional[torch.Tensor] = None,
-        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = None,
-        slots: torch.Tensor = None,
-        seqlen: Seqlen = None,
-        hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        vision_feature_layer: Optional[Union[int, List[int]]] = None,
-        vision_feature_select_strategy: Optional[str] = None,
-        image_sizes: torch.Tensor = None,
-        lm_head_indices: Optional[torch.Tensor] = None,
-        adapter_data: Optional[torch.Tensor] = None,
-        **lm_kwargs,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-
-        def _get_padding_mask(input_ids, pad_token_id=0):
-            return (input_ids != pad_token_id).long()
-
-        attention_mask = _get_padding_mask(input_ids)
-        attention_mask = attention_mask.view(seqlen.input_lengths.shape[0], -1)
+        image_sizes: Optional[torch.LongTensor] = None,
+    ):
         inputs_embeds = self.text_model.model.embed_tokens(input_ids)
-        vision_feature_layer = (
-            vision_feature_layer
-            if vision_feature_layer is not None
-            else self.config.vision_config.vision_feature_layer
-        )
-        vision_feature_select_strategy = (
-            vision_feature_select_strategy
-            if vision_feature_select_strategy is not None
-            else self.config.vision_config.vision_feature_select_strategy
-        )
-
-        if pixel_values is not None:
-            image_features = self.get_image_features(
-                pixel_values=pixel_values,
-                vision_feature_layer=vision_feature_layer,
-                vision_feature_select_strategy=vision_feature_select_strategy,
-                image_sizes=image_sizes,
-            )
+        if vision_embeds is not None:
             # When we generate, we don't want to replace the potential image_token_id that we generated by images
             # that simply don't exist
             original_inputs_embeds_shape = inputs_embeds.shape
-
-            vision_flat = image_features.view(-1, image_features.size(-1))
-            projected_vision_flat = self.multi_modal_projector(vision_flat)
-
             special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
                 -1
             )
@@ -1410,19 +1395,33 @@ class Llama4ForConditionalGeneration(nn.Module):
             final_mask_1d = final_mask[..., 0].reshape(-1)
             num_tokens_to_fill = final_mask_1d.sum()

-            if num_tokens_to_fill != projected_vision_flat.size(0):
+            if num_tokens_to_fill != vision_embeds.size(0):
                 raise ValueError(
                     f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
-                    f"but multi_modal_projector returned {projected_vision_flat.size(0)}"
+                    f"but multi_modal_projector returned {vision_embeds.size(0)}"
                 )

             expanded_mask = final_mask_1d.unsqueeze(-1).expand(
                 -1, inputs_embeds.size(-1)
             )
-            inputs_embeds = inputs_embeds.masked_scatter(
-                expanded_mask, projected_vision_flat
-            )
+            inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, vision_embeds)
             inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)
+        return inputs_embeds

+    def forward(
+        self,
+        inputs_embeds: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
+        cu_seqlen_prefill: Optional[torch.Tensor] = None,
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = None,
+        slots: torch.Tensor = None,
+        seqlen: Seqlen = None,
+        hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
+        lm_head_indices: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
+        **lm_kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+
         logits, speculative_logits = self.text_model(
             inputs_embeds,
@ -35,6 +35,7 @@ from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoE
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@ -143,12 +144,14 @@ class FlashLlamaAttention(torch.nn.Module):
config.num_key_value_heads = getattr(
config, "num_key_value_heads", config.num_attention_heads
)
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
base=config.rope_theta,
device=weights.device,
)

if config.model_type != "llama4_text":
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
base=config.rope_theta,
device=weights.device,
)

# `config.attention_multiplier` is used in Granite
self.softmax_scale = getattr(
@ -547,6 +550,11 @@ class FlashLlamaModel(torch.nn.Module):
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
cross_attention_states=None,
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)

hidden_states = inputs_embeds

# Get rotary cos and sin for this forward
@ -163,9 +163,114 @@ class FlashLlavaNextForConditionalGeneration(nn.Module):
)
return inputs_embeds

def forward(
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
pixel_attention_mask: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
# 1. Extract the input embeddings

# 2. Merge text and images
num_images, num_patches, channels, height, width = pixel_values.shape
pixel_values = pixel_values.view(
num_images * num_patches, channels, height, width
)
image_features = self.vision_tower(pixel_values)

# selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
# Already done within the clip model
selected_image_feature = image_features.last_hidden_state

if self.config.vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif self.config.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise RuntimeError(
f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
)

image_features = self.multi_modal_projector(selected_image_feature)

# split up image_features for each of the individual images
# hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
# if we assume each image has 5 image features (base image + 4 patches)
split_sizes = [num_patches] * num_images
image_features = torch.split(image_features, split_sizes, dim=0)

# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
height = width = (
self.config.vision_config.image_size // self.config.vision_config.patch_size
)

new_image_features = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]

if height * width != base_image_feature.shape[0]:
raise ValueError(
"The number of patches is not consistent with the image size."
)

# Dimensions are intentionally swapped to be bug-compatible with
# upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
image_feature = image_feature.view(
num_patch_height, num_patch_width, height, width, -1
)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature, image_sizes[image_idx])
image_feature = torch.cat(
(
image_feature,
self.image_newline[:, None, None].expand(
*image_feature.shape[:-1], 1
),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
else:
image_feature = image_feature[0]
image_feature = torch.cat(
(image_feature, self.image_newline[None]), dim=0
)
new_image_features.append(image_feature)
image_features = torch.stack(new_image_features, dim=0)
return image_features.view(-1, image_features.shape[-1])

def get_inputs_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
pixel_values: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)

if vision_embeds is not None:
# When we generate, we don't want to replace the potential image_token_id that we generated by images
# that simply don't exist
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, vision_embeds
)
return inputs_embeds

def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@ -173,101 +278,9 @@ class FlashLlavaNextForConditionalGeneration(nn.Module):
|
||||
seqlen: Seqlen,
|
||||
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
# Unused for this model
|
||||
pixel_attention_mask=None,
|
||||
image_sizes: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||
if pixel_values is not None and len(pixel_values) > 0:
|
||||
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
|
||||
# 1. Extract the input embeddings
|
||||
|
||||
# 2. Merge text and images
|
||||
num_images, num_patches, channels, height, width = pixel_values.shape
|
||||
pixel_values = pixel_values.view(
|
||||
num_images * num_patches, channels, height, width
|
||||
)
|
||||
image_features = self.vision_tower(pixel_values)
|
||||
|
||||
# selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
|
||||
# Already done within the clip model
|
||||
selected_image_feature = image_features.last_hidden_state
|
||||
|
||||
if self.config.vision_feature_select_strategy == "default":
|
||||
selected_image_feature = selected_image_feature[:, 1:]
|
||||
elif self.config.vision_feature_select_strategy == "full":
|
||||
selected_image_feature = selected_image_feature
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
|
||||
)
|
||||
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
|
||||
# split up image_features for each of the individual images
|
||||
# hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
|
||||
# if we assume each image has 5 image features (base image + 4 patches)
|
||||
split_sizes = [num_patches] * num_images
|
||||
image_features = torch.split(image_features, split_sizes, dim=0)
|
||||
|
||||
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
|
||||
height = width = (
|
||||
self.config.vision_config.image_size
|
||||
// self.config.vision_config.patch_size
|
||||
)
|
||||
|
||||
new_image_features = []
|
||||
for image_idx, image_feature in enumerate(image_features):
|
||||
if image_feature.shape[0] > 1:
|
||||
base_image_feature = image_feature[0]
|
||||
image_feature = image_feature[1:]
|
||||
|
||||
if height * width != base_image_feature.shape[0]:
|
||||
raise ValueError(
|
||||
"The number of patches is not consistent with the image size."
|
||||
)
|
||||
|
||||
# Dimensions are intentionally swapped to be bug-compatible with
|
||||
# upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
|
||||
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
|
||||
image_sizes[image_idx],
|
||||
self.config.image_grid_pinpoints,
|
||||
self.config.vision_config.image_size,
|
||||
)
|
||||
image_feature = image_feature.view(
|
||||
num_patch_height, num_patch_width, height, width, -1
|
||||
)
|
||||
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
|
||||
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
|
||||
image_feature = unpad_image(image_feature, image_sizes[image_idx])
|
||||
image_feature = torch.cat(
|
||||
(
|
||||
image_feature,
|
||||
self.image_newline[:, None, None].expand(
|
||||
*image_feature.shape[:-1], 1
|
||||
),
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
|
||||
image_feature = torch.cat(
|
||||
(base_image_feature, image_feature), dim=0
|
||||
)
|
||||
else:
|
||||
image_feature = image_feature[0]
|
||||
image_feature = torch.cat(
|
||||
(image_feature, self.image_newline[None]), dim=0
|
||||
)
|
||||
new_image_features.append(image_feature)
|
||||
image_features = torch.stack(new_image_features, dim=0)
|
||||
|
||||
inputs_embeds = self._merge_input_ids_with_image_features(
|
||||
input_ids, inputs_embeds, image_features
|
||||
)
|
||||
|
||||
hidden_states = self.text_model.model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
|
@ -30,6 +30,7 @@ from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
set_block_mapping,
|
||||
Seqlen,
|
||||
HPUPagedAttentionMetadata,
|
||||
)
|
||||
@ -396,6 +397,10 @@ class MistralModel(torch.nn.Module):
|
||||
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if hpu_attention_meta is not None:
|
||||
hpu_attention_meta = set_block_mapping(
|
||||
hpu_attention_meta, inputs_embeds.shape[0]
|
||||
)
|
||||
hidden_states = inputs_embeds
|
||||
# Get rotary cos and sin for this forward
|
||||
# Avoid to index in each layer
|
||||
|
@ -37,6 +37,7 @@ from text_generation_server.layers.attention import (
|
||||
Seqlen,
|
||||
attention,
|
||||
paged_attention,
|
||||
set_block_mapping,
|
||||
HPUPagedAttentionMetadata,
|
||||
)
|
||||
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||
@ -446,6 +447,10 @@ class MixtralModel(torch.nn.Module):
|
||||
seqlen: Seqlen,
|
||||
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||
) -> torch.Tensor:
|
||||
if hpu_attention_meta is not None:
|
||||
hpu_attention_meta = set_block_mapping(
|
||||
hpu_attention_meta, input_ids.shape[0]
|
||||
)
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
|
||||
# Get rotary cos and sin for this forward
|
||||
@ -505,7 +510,6 @@ class FlashMixtralForCausalLM(torch.nn.Module):
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
position_ids,
|
||||
|
@ -38,6 +38,7 @@ from text_generation_server.models.custom_modeling.flash_llama_modeling import (
|
||||
)
|
||||
from habana_frameworks.torch.hpex.kernels import FusedSDPA
|
||||
from vllm_hpu_extension.utils import ModuleFusedSDPA
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
def _prepare_aspect_ratio_attention_mask(
|
||||
@ -236,10 +237,19 @@ class MllamaVisionSdpaAttention(nn.Module):
|
||||
key = key.transpose(1, 2)
|
||||
value = value.transpose(1, 2)
|
||||
|
||||
attn_output = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask
|
||||
fsdpa_op = ModuleFusedSDPA(FusedSDPA)
|
||||
attn_output = fsdpa_op(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
scale=None,
|
||||
softmax_mode="None",
|
||||
recompute_mode=None,
|
||||
valid_sequence_lengths=None,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(batch_size, q_seq_len, -1)
|
||||
|
||||
@ -320,6 +330,9 @@ class MllamaVisionEncoder(nn.Module):
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
encoder_states = [hidden_states]
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
for encoder_layer in self.layers:
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
@ -328,6 +341,8 @@ class MllamaVisionEncoder(nn.Module):
|
||||
|
||||
hidden_states = layer_outputs
|
||||
encoder_states.append(hidden_states)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
return hidden_states, encoder_states
|
||||
|
||||
@ -699,8 +714,6 @@ class MllamaTextCrossAttention(nn.Module):
|
||||
|
||||
# key_states = key_states.repeat(1, self.num_key_value_groups, 1)
|
||||
# value_states = value_states.repeat(1, self.num_key_value_groups, 1)
|
||||
|
||||
causal = False
|
||||
# logger.info(
|
||||
# f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}"
|
||||
# )
|
||||
@ -715,7 +728,7 @@ class MllamaTextCrossAttention(nn.Module):
|
||||
value_states,
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
is_causal=causal,
|
||||
is_causal=False,
|
||||
scale=None,
|
||||
softmax_mode="None",
|
||||
recompute_mode=None,
|
||||
|
@ -29,6 +29,7 @@ from typing import Optional, List, Tuple
|
||||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
set_block_mapping,
|
||||
Seqlen,
|
||||
HPUPagedAttentionMetadata,
|
||||
)
|
||||
@ -354,6 +355,10 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
||||
seqlen: Seqlen,
|
||||
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||
) -> torch.Tensor:
|
||||
if hpu_attention_meta is not None:
|
||||
hpu_attention_meta = set_block_mapping(
|
||||
hpu_attention_meta, input_ids.shape[0]
|
||||
)
|
||||
hidden_states = self.embed_in(input_ids)
|
||||
|
||||
# Get rotary cos and sin for this forward
|
||||
|
@ -62,10 +62,40 @@ class PaliGemmaForConditionalGeneration(nn.Module):
|
||||
self.pad_token_id = (
|
||||
config.pad_token_id if config.pad_token_id is not None else -1
|
||||
)
|
||||
self.dtype = weights.dtype
|
||||
|
||||
def get_vision_embeds(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
pixel_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
image_sizes: Optional[torch.Tensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
pixel_values = pixel_values.to(dtype=self.dtype)
|
||||
image_outputs = self.vision_tower(pixel_values)
|
||||
last_hidden_state = self.post_vision_tower_layernorm(
|
||||
image_outputs.last_hidden_state
|
||||
)
|
||||
image_features = self.multi_modal_projector(last_hidden_state)
|
||||
image_features = image_features.view(-1, image_features.shape[-1])
|
||||
return image_features
|
||||
|
||||
def get_inputs_embeds(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
vision_embeds: torch.Tensor = None,
|
||||
):
|
||||
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||
|
||||
if vision_embeds is not None:
|
||||
mask = input_ids == self.config.image_token_index
|
||||
inputs_embeds[mask] = vision_embeds
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
@ -73,32 +103,13 @@ class PaliGemmaForConditionalGeneration(nn.Module):
|
||||
seqlen: Seqlen,
|
||||
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
# Unused here
|
||||
pixel_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
image_sizes: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||
# TODO This is odd but apparently pali gemma position ids start at 1.
|
||||
if cu_seqlen_prefill is not None:
|
||||
position_ids += 1
|
||||
|
||||
if pixel_values is not None:
|
||||
pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
|
||||
image_outputs = self.vision_tower(pixel_values)
|
||||
last_hidden_state = self.post_vision_tower_layernorm(
|
||||
image_outputs.last_hidden_state
|
||||
)
|
||||
image_features = self.multi_modal_projector(last_hidden_state)
|
||||
|
||||
# mask where image or padding tokens
|
||||
mask = input_ids == self.config.image_token_index
|
||||
|
||||
# insert image features into input embeddings
|
||||
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
|
||||
|
||||
hidden_states = self.text_model.model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
position_ids=position_ids,
|
||||
|
@ -9,6 +9,7 @@ from typing import Optional, List, Tuple
|
||||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
set_block_mapping,
|
||||
Seqlen,
|
||||
HPUPagedAttentionMetadata,
|
||||
)
|
||||
@ -347,6 +348,10 @@ class FlashPhiModel(torch.nn.Module):
|
||||
seqlen: Seqlen,
|
||||
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||
) -> torch.Tensor:
|
||||
if hpu_attention_meta is not None:
|
||||
hpu_attention_meta = set_block_mapping(
|
||||
hpu_attention_meta, input_ids.shape[0]
|
||||
)
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
|
||||
# Get rotary cos and sin for this forward
|
||||
|
@ -8,6 +8,7 @@ from typing import Optional, List, Tuple
|
||||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
set_block_mapping,
|
||||
Seqlen,
|
||||
HPUPagedAttentionMetadata,
|
||||
)
|
||||
@ -288,6 +289,10 @@ class Qwen2Model(torch.nn.Module):
|
||||
seqlen: Seqlen,
|
||||
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||
) -> torch.Tensor:
|
||||
if hpu_attention_meta is not None:
|
||||
hpu_attention_meta = set_block_mapping(
|
||||
hpu_attention_meta, inputs_embeds.shape[0]
|
||||
)
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
||||
@ -359,7 +364,6 @@ class Qwen2ForCausalLM(torch.nn.Module):
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
hidden_states = self.model(
|
||||
|
@ -18,6 +18,7 @@ import habana_frameworks.torch as htorch
|
||||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
set_block_mapping,
|
||||
Seqlen,
|
||||
HPUPagedAttentionMetadata,
|
||||
)
|
||||
@ -266,7 +267,10 @@ class Qwen3Model(nn.Module):
|
||||
seqlen: Seqlen,
|
||||
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||
) -> torch.Tensor:
|
||||
|
||||
if hpu_attention_meta is not None:
|
||||
hpu_attention_meta = set_block_mapping(
|
||||
hpu_attention_meta, inputs_embeds.shape[0]
|
||||
)
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
@ -334,7 +338,6 @@ class Qwen3ForCausalLM(nn.Module):
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
hidden_states = self.model(
|
||||
|
@ -18,6 +18,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.layers.attention import (
|
||||
attention,
|
||||
paged_attention,
|
||||
set_block_mapping,
|
||||
Seqlen,
|
||||
HPUPagedAttentionMetadata,
|
||||
)
|
||||
@ -628,6 +629,10 @@ class FlashRWModel(FlashRWPreTrainedModel):
|
||||
seqlen: Seqlen,
|
||||
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||
) -> torch.Tensor:
|
||||
if hpu_attention_meta is not None:
|
||||
hpu_attention_meta = set_block_mapping(
|
||||
hpu_attention_meta, input_ids.shape[0]
|
||||
)
|
||||
hidden_states = self.word_embeddings(input_ids)
|
||||
|
||||
# Get rotary cos and sin for this forward
|
||||
|
@ -8,6 +8,7 @@ from typing import Optional, List, Tuple
|
||||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
set_block_mapping,
|
||||
Seqlen,
|
||||
HPUPagedAttentionMetadata,
|
||||
)
|
||||
@ -437,6 +438,10 @@ class FlashSantacoderModel(nn.Module):
|
||||
seqlen: Seqlen,
|
||||
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||
) -> torch.Tensor:
|
||||
if hpu_attention_meta is not None:
|
||||
hpu_attention_meta = set_block_mapping(
|
||||
hpu_attention_meta, input_ids.shape[0]
|
||||
)
|
||||
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
|
||||
|
||||
if self.process_group.size() > 1:
|
||||
|
@ -29,6 +29,7 @@ from typing import Optional, List, Tuple
|
||||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
set_block_mapping,
|
||||
Seqlen,
|
||||
HPUPagedAttentionMetadata,
|
||||
)
|
||||
@ -511,6 +512,10 @@ class Starcoder2Model(torch.nn.Module):
|
||||
adapter_data,
|
||||
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||
) -> torch.Tensor:
|
||||
if hpu_attention_meta is not None:
|
||||
hpu_attention_meta = set_block_mapping(
|
||||
hpu_attention_meta, input_ids.shape[0]
|
||||
)
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
|
||||
# Get rotary cos and sin for this forward
|
||||
@ -584,7 +589,6 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
position_ids,
|
||||
|
@ -734,9 +734,107 @@ class Idefics2ForConditionalGeneration(nn.Module):
|
||||
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
def get_vision_embeds(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
pixel_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
image_sizes: Optional[torch.Tensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
assert pixel_values is not None
|
||||
batch_size, num_images, num_channels, height, width = pixel_values.shape
|
||||
all_states = []
|
||||
all_pixel_values = pixel_values
|
||||
all_pixel_mask = pixel_attention_mask
|
||||
for i in range(batch_size):
|
||||
pixel_values = all_pixel_values.to(dtype=self.dtype) # fp16 compatibility
|
||||
pixel_values = pixel_values[i : i + 1]
|
||||
pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
|
||||
|
||||
# Remove padding images - padding images are full 0.
|
||||
nb_values_per_image = pixel_values.shape[1:].numel()
|
||||
real_images_inds = (pixel_values == 0.0).sum(
|
||||
dim=(-1, -2, -3)
|
||||
) != nb_values_per_image
|
||||
pixel_values = pixel_values[real_images_inds].contiguous()
|
||||
|
||||
# Handle the vision attention mask
|
||||
if pixel_attention_mask is None:
|
||||
pixel_attention_mask = torch.ones(
|
||||
size=(
|
||||
pixel_values.size(0),
|
||||
pixel_values.size(2),
|
||||
pixel_values.size(3),
|
||||
),
|
||||
dtype=torch.bool,
|
||||
device=pixel_values.device,
|
||||
)
|
||||
else:
|
||||
# Remove padding images from the mask
|
||||
pixel_attention_mask = all_pixel_mask[i : i + 1]
|
||||
pixel_attention_mask = pixel_attention_mask.view(
|
||||
1 * num_images, *pixel_attention_mask.shape[2:]
|
||||
)
|
||||
pixel_attention_mask = pixel_attention_mask[
|
||||
real_images_inds
|
||||
].contiguous()
|
||||
|
||||
patch_size = self.config.vision_config.patch_size
|
||||
"""
|
||||
patches_subgrid = pixel_attention_mask.unfold(
|
||||
dimension=1, size=patch_size, step=patch_size
|
||||
)
|
||||
patches_subgrid = patches_subgrid.unfold(
|
||||
dimension=2, size=patch_size, step=patch_size
|
||||
)
|
||||
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
|
||||
"""
|
||||
# hpu does not support unfold
|
||||
conv_kernel = torch.ones(
|
||||
[1, 1, patch_size, patch_size],
|
||||
dtype=pixel_values.dtype,
|
||||
device=pixel_values.device,
|
||||
)
|
||||
patches_subgrid = torch.nn.functional.conv2d(
|
||||
pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype),
|
||||
conv_kernel,
|
||||
stride=patch_size,
|
||||
).squeeze(1)
|
||||
patch_attention_mask = torch.gt(patches_subgrid, 0)
|
||||
|
||||
# Get sequence from the vision encoder
|
||||
image_hidden_states = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
patch_attention_mask=patch_attention_mask,
|
||||
)
|
||||
|
||||
# Modality projection & resampling
|
||||
image_hidden_states = self.connector(
|
||||
image_hidden_states,
|
||||
attention_mask=patch_attention_mask.view(pixel_values.size(0), -1),
|
||||
)
|
||||
all_states.append(image_hidden_states)
|
||||
image_hidden_states = torch.stack(all_states, dim=0)
|
||||
return image_hidden_states.view(-1, image_hidden_states.shape[-1])
|
||||
|
||||
def get_inputs_embeds(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
vision_embeds: torch.Tensor = None,
|
||||
):
|
||||
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||
|
||||
if vision_embeds is not None:
|
||||
# When we generate, we don't want to replace the potential image_token_id that we generated by images
|
||||
# that simply don't exist
|
||||
inputs_embeds = self._merge_input_ids_with_image_features(
|
||||
input_ids, inputs_embeds, vision_embeds
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
@ -744,98 +842,9 @@ class Idefics2ForConditionalGeneration(nn.Module):
|
||||
seqlen: Seqlen,
|
||||
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
pixel_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
# Unused here
|
||||
image_sizes: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||
if pixel_values is not None:
|
||||
batch_size, num_images, num_channels, height, width = pixel_values.shape
|
||||
all_states = []
|
||||
all_pixel_values = pixel_values
|
||||
all_pixel_mask = pixel_attention_mask
|
||||
for i in range(batch_size):
|
||||
pixel_values = all_pixel_values.to(
|
||||
dtype=self.dtype
|
||||
) # fp16 compatibility
|
||||
pixel_values = pixel_values[i : i + 1]
|
||||
pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
|
||||
|
||||
# Remove padding images - padding images are full 0.
|
||||
nb_values_per_image = pixel_values.shape[1:].numel()
|
||||
real_images_inds = (pixel_values == 0.0).sum(
|
||||
dim=(-1, -2, -3)
|
||||
) != nb_values_per_image
|
||||
pixel_values = pixel_values[real_images_inds].contiguous()
|
||||
|
||||
# Handle the vision attention mask
|
||||
if pixel_attention_mask is None:
|
||||
pixel_attention_mask = torch.ones(
|
||||
size=(
|
||||
pixel_values.size(0),
|
||||
pixel_values.size(2),
|
||||
pixel_values.size(3),
|
||||
),
|
||||
dtype=torch.bool,
|
||||
device=pixel_values.device,
|
||||
)
|
||||
else:
|
||||
# Remove padding images from the mask
|
||||
pixel_attention_mask = all_pixel_mask[i : i + 1]
|
||||
pixel_attention_mask = pixel_attention_mask.view(
|
||||
1 * num_images, *pixel_attention_mask.shape[2:]
|
||||
)
|
||||
pixel_attention_mask = pixel_attention_mask[
|
||||
real_images_inds
|
||||
].contiguous()
|
||||
|
||||
patch_size = self.config.vision_config.patch_size
|
||||
"""
|
||||
patches_subgrid = pixel_attention_mask.unfold(
|
||||
dimension=1, size=patch_size, step=patch_size
|
||||
)
|
||||
patches_subgrid = patches_subgrid.unfold(
|
||||
dimension=2, size=patch_size, step=patch_size
|
||||
)
|
||||
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
|
||||
"""
|
||||
# hpu does not support unfold
|
||||
conv_kernel = torch.ones(
|
||||
[1, 1, patch_size, patch_size],
|
||||
dtype=pixel_values.dtype,
|
||||
device=pixel_values.device,
|
||||
)
|
||||
patches_subgrid = torch.nn.functional.conv2d(
|
||||
pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype),
|
||||
conv_kernel,
|
||||
stride=patch_size,
|
||||
).squeeze(1)
|
||||
patch_attention_mask = torch.eq(
|
||||
patches_subgrid, (patch_size * patch_size)
|
||||
)
|
||||
|
||||
# Get sequence from the vision encoder
|
||||
image_hidden_states = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
patch_attention_mask=patch_attention_mask,
|
||||
)
|
||||
|
||||
# Modality projection & resampling
|
||||
image_hidden_states = self.connector(
|
||||
image_hidden_states,
|
||||
attention_mask=patch_attention_mask.view(pixel_values.size(0), -1),
|
||||
)
|
||||
all_states.append(image_hidden_states)
|
||||
image_hidden_states = torch.stack(all_states, dim=0)
|
||||
# When we generate, we don't want to replace the potential image_token_id that we generated by images
|
||||
# that simply don't exist
|
||||
inputs_embeds = self._merge_input_ids_with_image_features(
|
||||
input_ids, inputs_embeds, image_hidden_states
|
||||
)
|
||||
|
||||
hidden_states = self.text_model.model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
position_ids=position_ids,
|
||||
|
@ -477,9 +477,107 @@ class Idefics3ForConditionalGeneration(nn.Module):
|
||||
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
def get_vision_embeds(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
pixel_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
image_sizes: Optional[torch.Tensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
batch_size, num_images, num_channels, height, width = pixel_values.shape
|
||||
all_states = []
|
||||
all_pixel_values = pixel_values
|
||||
all_pixel_mask = pixel_attention_mask
|
||||
for i in range(batch_size):
|
||||
pixel_values = all_pixel_values.to(dtype=self.dtype) # fp16 compatibility
|
||||
pixel_values = pixel_values[i : i + 1]
|
||||
pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
|
||||
|
||||
# Remove padding images - padding images are full 0.
|
||||
nb_values_per_image = pixel_values.shape[1:].numel()
|
||||
real_images_inds = (pixel_values == 0.0).sum(
|
||||
dim=(-1, -2, -3)
|
||||
) != nb_values_per_image
|
||||
pixel_values = pixel_values[real_images_inds].contiguous()
|
||||
# Handle the vision attention mask
|
||||
if pixel_attention_mask is None:
|
||||
pixel_attention_mask = torch.ones(
|
||||
size=(
|
||||
pixel_values.size(0),
|
||||
pixel_values.size(2),
|
||||
pixel_values.size(3),
|
||||
),
|
||||
dtype=torch.bool,
|
||||
device=pixel_values.device,
|
||||
)
|
||||
else:
|
||||
# Remove padding images from the mask
|
||||
pixel_attention_mask = all_pixel_mask[i : i + 1]
|
||||
pixel_attention_mask = pixel_attention_mask.view(
|
||||
1 * num_images, *pixel_attention_mask.shape[2:]
|
||||
)
|
||||
pixel_attention_mask = pixel_attention_mask[
|
||||
real_images_inds
|
||||
].contiguous()
|
||||
|
||||
patch_size = self.config.vision_config.patch_size
|
||||
|
||||
"""
|
||||
patches_subgrid = pixel_attention_mask.unfold(
|
||||
dimension=1, size=patch_size, step=patch_size
|
||||
)
|
||||
patches_subgrid = patches_subgrid.unfold(
|
||||
dimension=2, size=patch_size, step=patch_size
|
||||
)
|
||||
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
|
||||
"""
|
||||
# hpu does not support unfold
|
||||
conv_kernel = torch.ones(
|
||||
[1, 1, patch_size, patch_size],
|
||||
dtype=pixel_values.dtype,
|
||||
device=pixel_values.device,
|
||||
)
|
||||
patches_subgrid = torch.nn.functional.conv2d(
|
||||
pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype),
|
||||
conv_kernel,
|
||||
stride=patch_size,
|
||||
).squeeze(1)
|
||||
patch_attention_mask = torch.gt(patches_subgrid, 0)
|
||||
|
||||
# Get sequence from the vision encoder
|
||||
image_hidden_states = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
patch_attention_mask=patch_attention_mask,
|
||||
)
|
||||
|
||||
# Modality projection & resampling
|
||||
image_hidden_states = self.connector(
|
||||
image_hidden_states,
|
||||
)
|
||||
|
||||
all_states.append(image_hidden_states)
|
||||
image_hidden_states = torch.stack(all_states, dim=0)
|
||||
|
||||
return image_hidden_states.view(-1, image_hidden_states.shape[-1])
|
||||
|
||||
def get_inputs_embeds(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
vision_embeds: torch.Tensor = None,
|
||||
):
|
||||
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||
|
||||
if vision_embeds is not None:
|
||||
# When we generate, we don't want to replace the potential image_token_id that we generated by images
|
||||
# that simply don't exist
|
||||
inputs_embeds = self._merge_input_ids_with_image_features(
|
||||
input_ids, inputs_embeds, vision_embeds
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
@ -487,99 +585,10 @@ class Idefics3ForConditionalGeneration(nn.Module):
|
||||
seqlen: Seqlen,
|
||||
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
pixel_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
# Unused here
|
||||
image_sizes: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
video_grid_thw: Optional[torch.LongTensor] = None,
|
||||
cross_attention_states: Optional[torch.Tensor] = None,
|
||||
image_indices=None,
|
||||
):
|
||||
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||
if pixel_values is not None:
|
||||
batch_size, num_images, num_channels, height, width = pixel_values.shape
|
||||
all_states = []
|
||||
all_pixel_values = pixel_values
|
||||
all_pixel_mask = pixel_attention_mask
|
||||
for i in range(batch_size):
|
||||
pixel_values = all_pixel_values.to(
|
||||
dtype=self.dtype
|
||||
) # fp16 compatibility
|
||||
pixel_values = pixel_values[i : i + 1]
|
||||
pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
|
||||
|
||||
# Remove padding images - padding images are full 0.
|
||||
nb_values_per_image = pixel_values.shape[1:].numel()
|
||||
real_images_inds = (pixel_values == 0.0).sum(
|
||||
dim=(-1, -2, -3)
|
||||
) != nb_values_per_image
|
||||
pixel_values = pixel_values[real_images_inds].contiguous()
|
||||
# Handle the vision attention mask
|
||||
if pixel_attention_mask is None:
|
||||
pixel_attention_mask = torch.ones(
|
||||
size=(
|
||||
pixel_values.size(0),
|
||||
pixel_values.size(2),
|
||||
pixel_values.size(3),
|
||||
),
|
||||
dtype=torch.bool,
|
||||
device=pixel_values.device,
|
||||
)
|
||||
else:
|
||||
# Remove padding images from the mask
|
||||
pixel_attention_mask = all_pixel_mask[i : i + 1]
|
||||
pixel_attention_mask = pixel_attention_mask.view(
|
||||
1 * num_images, *pixel_attention_mask.shape[2:]
|
||||
)
|
||||
pixel_attention_mask = pixel_attention_mask[
|
||||
real_images_inds
|
||||
].contiguous()
|
||||
|
||||
patch_size = self.config.vision_config.patch_size
|
||||
"""
|
||||
patches_subgrid = pixel_attention_mask.unfold(
|
||||
dimension=1, size=patch_size, step=patch_size
|
||||
)
|
||||
patches_subgrid = patches_subgrid.unfold(
|
||||
dimension=2, size=patch_size, step=patch_size
|
||||
)
|
||||
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
|
||||
"""
|
||||
# hpu does not support unfold
|
||||
conv_kernel = torch.ones(
|
||||
[1, 1, patch_size, patch_size],
|
||||
dtype=pixel_values.dtype,
|
||||
device=pixel_values.device,
|
||||
)
|
||||
patches_subgrid = torch.nn.functional.conv2d(
|
||||
pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype),
|
||||
conv_kernel,
|
||||
stride=patch_size,
|
||||
).squeeze(1)
|
||||
patch_attention_mask = torch.eq(
|
||||
patches_subgrid, (patch_size * patch_size)
|
||||
)
|
||||
|
||||
# Get sequence from the vision encoder
|
||||
image_hidden_states = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
patch_attention_mask=patch_attention_mask,
|
||||
)
|
||||
|
||||
# Modality projection & resampling
|
||||
image_hidden_states = self.connector(
|
||||
image_hidden_states,
|
||||
)
|
||||
|
||||
all_states.append(image_hidden_states)
|
||||
image_hidden_states = torch.stack(all_states, dim=0)
|
||||
|
||||
inputs_embeds = self._merge_input_ids_with_image_features(
|
||||
input_ids, inputs_embeds, image_hidden_states
|
||||
)
|
||||
|
||||
hidden_states = self.text_model.model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
position_ids=position_ids,
|
||||
|
@ -1,467 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch Llava-NeXT model."""
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
import numpy as np
|
||||
|
||||
from loguru import logger
|
||||
from transformers.models.llava_next.modeling_llava_next import (
|
||||
unpad_image,
|
||||
)
|
||||
from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration
|
||||
from transformers.image_processing_utils import select_best_resolution
|
||||
|
||||
|
||||
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
"""
|
||||
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
|
||||
|
||||
Args:
|
||||
image_size (`tuple`):
|
||||
The size of the input image in the format (width, height).
|
||||
grid_pinpoints (`List`):
|
||||
A list containing possible resolutions. Each item in the list should be a tuple or list
|
||||
of the form `(height, width)`.
|
||||
patch_size (`int`):
|
||||
The size of each image patch.
|
||||
|
||||
Returns:
|
||||
tuple: The shape of the image patch grid in the format (width, height).
|
||||
"""
|
||||
if not isinstance(grid_pinpoints, list):
|
||||
raise ValueError("grid_pinpoints should be a list of tuples or lists")
|
||||
|
||||
height, width = select_best_resolution(image_size, grid_pinpoints)
|
||||
return height // patch_size, width // patch_size
|
||||
|
||||
|
||||
# Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L79
|
||||
def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
|
||||
"""
|
||||
Calculate the number of patches after the preprocessing for images of any resolution.
|
||||
|
||||
Args:
|
||||
image_size (`torch.LongTensor` or `np.ndarray` or `Tuple[int, int]`):
|
||||
The size of the input image in the format (height, width).
|
||||
grid_pinpoints (`List`):
|
||||
A list containing possible resolutions. Each item in the list should be a tuple or list
|
||||
of the form `(height, width)`.
|
||||
patch_size (`int`):
|
||||
The size of each image patch.
|
||||
|
||||
Returns:
|
||||
int: the number of patches
|
||||
"""
|
||||
if not isinstance(grid_pinpoints, list):
|
||||
raise TypeError("grid_pinpoints should be a list of tuples or lists")
|
||||
|
||||
# ! VERY IMPORTANT: if image_size is a tensor, it must be converted into a tuple, otherwise the calculation will be wrong
|
||||
if not isinstance(image_size, (list, tuple)):
|
||||
if not isinstance(image_size, (torch.Tensor, np.ndarray)):
|
||||
raise TypeError(
|
||||
f"image_size invalid type {type(image_size)} with value {image_size}"
|
||||
)
|
||||
image_size = image_size.tolist()
|
||||
|
||||
best_resolution = select_best_resolution(image_size, grid_pinpoints)
|
||||
height, width = best_resolution
|
||||
num_patches = 0
|
||||
# consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1
|
||||
for i in range(0, height, patch_size):
|
||||
for j in range(0, width, patch_size):
|
||||
num_patches += 1
|
||||
# add the base patch
|
||||
num_patches += 1
|
||||
return num_patches
|
||||
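For reference, a small worked example of the count this function produces, using assumed values (best resolution 672x672 and a crop size of 336), which matches the "base image + 4 patches" case mentioned elsewhere in this diff:

def count_patches(height: int, width: int, patch_size: int) -> int:
    # same counting loop as image_size_to_num_patches above
    num_patches = 0
    for _ in range(0, height, patch_size):      # ceil(height / patch_size) steps
        for _ in range(0, width, patch_size):   # ceil(width / patch_size) steps
            num_patches += 1
    return num_patches + 1                      # plus the base (downscaled) image

assert count_patches(672, 672, 336) == 5        # 2 * 2 tiles + base image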
|
||||
|
||||
class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
image_sizes: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
vision_feature_layer: Optional[int] = None,
|
||||
vision_feature_select_strategy: Optional[str] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
token_idx: Optional[torch.Tensor] = None,
|
||||
use_flash_attention: Optional[bool] = True,
|
||||
flash_attention_recompute: Optional[bool] = True,
|
||||
):
|
||||
|
||||
if token_idx is not None:
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
token_idx=token_idx,
|
||||
use_flash_attention=use_flash_attention,
|
||||
flash_attention_recompute=flash_attention_recompute,
|
||||
)
|
||||
|
||||
logits = outputs[0]
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return output
|
||||
|
||||
return outputs
|
||||
|
||||
# Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L411
|
||||
def pack_image_features(
|
||||
self,
|
||||
image_features,
|
||||
image_sizes,
|
||||
vision_feature_select_strategy,
|
||||
image_newline=None,
|
||||
):
|
||||
"""
|
||||
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
|
||||
|
||||
Args:
|
||||
image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
|
||||
List of image feature tensor, each contains all the visual feature of all patches.
|
||||
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
|
||||
Actual image size of each image (H, W).
|
||||
vision_feature_select_strategy (`str`)
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
image_newline (`torch.Tensor` of shape `(embed_dim)`)
|
||||
New line embedding vector.
|
||||
Returns:
|
||||
image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
|
||||
feature_lens (`List[int]`)
|
||||
token length of each image in image_features
|
||||
"""
|
||||
new_image_features = []
|
||||
feature_lens = []
|
||||
for image_idx, image_feature in enumerate(image_features):
|
||||
if image_feature.shape[0] > 1:
|
||||
base_image_feature = image_feature[0]
|
||||
image_feature = image_feature[1:]
|
||||
height = width = (
|
||||
self.config.vision_config.image_size
|
||||
// self.config.vision_config.patch_size
|
||||
)
|
||||
|
||||
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
|
||||
image_sizes[image_idx],
|
||||
self.config.image_grid_pinpoints,
|
||||
self.config.vision_config.image_size,
|
||||
)
|
||||
|
||||
if (
|
||||
np.prod(image_feature.shape)
|
||||
% (num_patch_height * num_patch_width * height * width)
|
||||
!= 0
|
||||
and vision_feature_select_strategy == "default"
|
||||
):
|
||||
logger.warning_once(
|
||||
"Image feature shape does not line up with the provided patch size. "
|
||||
"You may be using the `default` vision_feature_select_strategy with a"
|
||||
" visual encoder that does not have CLS."
|
||||
)
|
||||
|
||||
image_feature = image_feature.view(
|
||||
num_patch_height, num_patch_width, height, width, -1
|
||||
)
|
||||
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
|
||||
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
|
||||
image_feature = unpad_image(image_feature, image_sizes[image_idx])
|
||||
if image_newline is not None:
|
||||
image_feature = torch.cat(
|
||||
(
|
||||
image_feature,
|
||||
image_newline[:, None, None]
|
||||
.expand(*image_feature.shape[:-1], 1)
|
||||
.to(image_feature.device, image_feature.dtype),
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
|
||||
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
|
||||
else:
|
||||
image_feature = image_feature[0]
|
||||
if image_newline is not None:
|
||||
image_feature = torch.cat(
|
||||
(image_feature, image_newline[None].to(image_feature)), dim=0
|
||||
)
|
||||
new_image_features.append(image_feature)
|
||||
feature_lens.append(image_feature.size(0))
|
||||
image_features = torch.cat(new_image_features, dim=0)
|
||||
feature_lens = torch.tensor(
|
||||
feature_lens, dtype=torch.long, device=image_features.device
|
||||
)
|
||||
return image_features, feature_lens
|
||||
|
||||
# Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L479
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
image_sizes: torch.Tensor,
|
||||
vision_feature_layer: Union[int, List[int]],
|
||||
vision_feature_select_strategy: str,
|
||||
):
|
||||
"""
|
||||
Obtains image last hidden states from the vision tower and apply multimodal projection.
|
||||
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_patches, channels, height, width)`)
|
||||
The tensors corresponding to the input images.
|
||||
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
|
||||
Actual image size of each image (H, W).
|
||||
vision_feature_layer (`Union[int, List[int]]`):
|
||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||
the vision feature of the corresponding indices will be concatenated to form the
|
||||
vision features.
|
||||
vision_feature_select_strategy (`str`):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Can be one of `"default"` or `"full"`
|
||||
Returns:
|
||||
image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches
|
||||
and are of shape `(num_patches, image_length, embed_dim)`).
|
||||
"""
|
||||
# ! infer image_num_patches from image_sizes
|
||||
image_num_patches = [
|
||||
image_size_to_num_patches(
|
||||
image_size=imsize,
|
||||
grid_pinpoints=self.config.image_grid_pinpoints,
|
||||
patch_size=self.config.vision_config.image_size,
|
||||
)
|
||||
for imsize in image_sizes
|
||||
]
|
||||
if pixel_values.dim() == 5:
|
||||
# stacked if input is (batch_size, num_patches, num_channels, height, width)
|
||||
_pixel_values_list = [
|
||||
pix_val[:num_patch]
|
||||
for pix_val, num_patch in zip(pixel_values, image_num_patches)
|
||||
]
|
||||
pixel_values = torch.cat(_pixel_values_list, dim=0)
|
||||
elif pixel_values.dim() != 4:
|
||||
# otherwise has to be stacked from list of (num_patches, num_channels, height, width)
|
||||
raise ValueError(
|
||||
f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions"
|
||||
)
|
||||
|
||||
image_features = self.vision_tower(pixel_values, output_hidden_states=True)
|
||||
# If we have one vision feature layer, return the corresponding hidden states,
|
||||
# otherwise, select the hidden states of each feature layer and concatenate them
|
||||
if isinstance(vision_feature_layer, int):
|
||||
selected_image_feature = image_features.hidden_states[vision_feature_layer]
|
||||
else:
|
||||
hs_pool = [
|
||||
image_features.hidden_states[layer_idx]
|
||||
for layer_idx in vision_feature_layer
|
||||
]
|
||||
selected_image_feature = torch.cat(hs_pool, dim=-1)
|
||||
|
||||
if vision_feature_select_strategy == "default":
|
||||
selected_image_feature = selected_image_feature[:, 1:]
|
||||
elif vision_feature_select_strategy == "full":
|
||||
selected_image_feature = selected_image_feature
|
||||
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
image_features = torch.split(image_features, image_num_patches, dim=0)
|
||||
return image_features
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
pixel_values=None,
|
||||
image_sizes=None,
|
||||
attention_mask=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L635
|
||||
The only differences are:
|
||||
- add new args token_idx
|
||||
- add the process of merging images into inputs_embeds
|
||||
"""
|
||||
token_idx = kwargs.get("token_idx", None)
|
||||
if token_idx is None:
|
||||
return super().prepare_inputs_for_generation(
|
||||
input_ids=input_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
pixel_values=pixel_values,
|
||||
image_sizes=image_sizes,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
use_flash_attention = kwargs.get("use_flash_attention", True)
|
||||
flash_attention_recompute = kwargs.get("flash_attention_recompute", True)
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
labels = kwargs.get("labels", None)
|
||||
if (
|
||||
past_key_values is None
|
||||
and pixel_values is not None
|
||||
and input_ids.shape[1] != 1
|
||||
):
|
||||
vision_feature_select_strategy = kwargs.get(
|
||||
"vision_feature_select_strategy", None
|
||||
)
|
||||
vision_feature_layer = kwargs.get("vision_feature_layer", None)
|
||||
vision_feature_select_strategy = (
|
||||
vision_feature_select_strategy
|
||||
if vision_feature_select_strategy is not None
|
||||
else self.config.vision_feature_select_strategy
|
||||
)
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer
|
||||
if vision_feature_layer is not None
|
||||
else self.config.vision_feature_layer
|
||||
)
|
||||
|
||||
# 1. Extract the input embeddings
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
# 2. Merge text and images
|
||||
image_features = self.get_image_features(
|
||||
pixel_values,
|
||||
image_sizes,
|
||||
vision_feature_layer=vision_feature_layer,
|
||||
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
)
|
||||
|
||||
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
|
||||
image_features, feature_lens = self.pack_image_features(
|
||||
image_features,
|
||||
image_sizes,
|
||||
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
image_newline=self.image_newline,
|
||||
)
|
||||
|
||||
special_image_mask = (
|
||||
input_ids == self.config.image_token_index
|
||||
).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
|
||||
inputs_embeds.device
|
||||
)
|
||||
if inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||
n_image_features = image_features.shape[0]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
|
||||
image_features = image_features.to(
|
||||
inputs_embeds.device, inputs_embeds.dtype
|
||||
)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(
|
||||
special_image_mask, image_features
|
||||
)
|
||||
|
||||
# In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
|
||||
# generation with cache
|
||||
elif past_key_values is not None:
|
||||
seq_len = input_ids.shape[1]
|
||||
pad_len = seq_len - token_idx.item()
|
||||
input_ids = torch.index_select(input_ids, 1, token_idx - 1)
|
||||
# Retrieve the first layer to inspect the logits and mask out the hidden states
|
||||
# that are set to 0
|
||||
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
|
||||
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
|
||||
batch_index, non_attended_tokens = torch.where(
|
||||
first_layer_past_key_value.float().sum(-2) == 0
|
||||
)
|
||||
# Get the target length
|
||||
past_length = first_layer_past_key_value.shape[-1]
|
||||
extended_attention_mask = torch.ones(
|
||||
(attention_mask.shape[0], past_length),
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
)
|
||||
# Filter out only the tokens that can be un-attended, this can happen
|
||||
# if one uses Llava + Fused modules where the cache on the
|
||||
# first iteration is already big enough, or if one passes custom cache
|
||||
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
|
||||
new_batch_index = batch_index[valid_indices]
|
||||
new_non_attended_tokens = non_attended_tokens[valid_indices]
|
||||
|
||||
# Zero-out the places where we don't need to attend
|
||||
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
||||
|
||||
attention_mask = extended_attention_mask
|
||||
attention_mask[:, -pad_len:] = 0
|
||||
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values:
|
||||
if token_idx is not None:
|
||||
position_ids = (
|
||||
torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
||||
)
|
||||
else:
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"attention_mask": attention_mask,
|
||||
"token_idx": token_idx,
|
||||
"labels": labels,
|
||||
"use_flash_attention": use_flash_attention,
|
||||
"flash_attention_recompute": flash_attention_recompute,
|
||||
}
|
||||
)
|
||||
|
||||
return model_inputs
|
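Aside: the block above rebuilds position ids from the attention mask with a cumulative sum, and in the cached-decoding branch the current position is just the number of attended tokens minus one. A minimal sketch of that trick, separate from the diff itself (the mask values below are invented):

import torch

# Toy left-padded batch: 0 marks padding, 1 marks real tokens.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

# Running index over attended tokens; padded slots first become -1 ...
position_ids = attention_mask.long().cumsum(-1) - 1
# ... and are then clamped to a harmless value, exactly as in the code above.
position_ids.masked_fill_(attention_mask == 0, 1)

# With a KV cache, the position of the token being decoded is the
# count of attended tokens minus one.
current_pos = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
print(position_ids)  # tensor([[1, 1, 0, 1, 2], [0, 1, 2, 3, 4]])
print(current_pos)   # tensor([[2], [4]])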
@ -1,292 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch Mllama model."""
|
||||
|
||||
from typing import Optional, Tuple, List, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from optimum.habana.transformers.models import GaudiMllamaForConditionalGeneration
|
||||
from optimum.habana.transformers.models.mllama.modeling_mllama import (
|
||||
_prepare_cross_attention_mask,
|
||||
)
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
|
||||
|
||||
class MllamaForConditionalGeneration(GaudiMllamaForConditionalGeneration):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
aspect_ratio_mask: Optional[torch.Tensor] = None,
|
||||
aspect_ratio_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_states: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
num_logits_to_keep: int = 0,
|
||||
token_idx: Optional[torch.Tensor] = None,
|
||||
use_flash_attention: Optional[bool] = True,
|
||||
flash_attention_recompute: Optional[bool] = True,
|
||||
**kwargs,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
"""
|
||||
Copied from MllamaForConditionalGeneration::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2077
|
||||
The only differences are:
|
||||
- add token_idx input
|
||||
- add use_flash_attention and flash_attention_recompute
|
||||
"""
|
||||
full_text_row_masked_out_mask = kwargs.get(
|
||||
"full_text_row_masked_out_mask", None
|
||||
)
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError(
|
||||
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
outputs = self.language_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
cross_attention_states=cross_attention_states,
|
||||
cross_attention_mask=cross_attention_mask,
|
||||
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
inputs_embeds=inputs_embeds,
|
||||
labels=labels,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_attentions=output_attentions,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
num_logits_to_keep=num_logits_to_keep,
|
||||
token_idx=token_idx,
|
||||
use_flash_attention=use_flash_attention,
|
||||
flash_attention_recompute=flash_attention_recompute,
|
||||
)
|
||||
|
||||
logits = outputs[0]
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return output
|
||||
|
||||
return outputs
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids=None,
|
||||
inputs_embeds=None,
|
||||
attention_mask=None,
|
||||
position_ids=None,
|
||||
pixel_values=None,
|
||||
aspect_ratio_ids=None,
|
||||
aspect_ratio_mask=None,
|
||||
cross_attention_mask=None,
|
||||
past_key_values=None,
|
||||
use_cache=False,
|
||||
cache_position=None,
|
||||
num_logits_to_keep=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Copied from MllamaForConditionalGeneration::prepare_inputs_for_generation: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2208
|
||||
The only differences are:
|
||||
- add token_idx handling
|
||||
- add bucket_internal handling
|
||||
- add use_flash_attention and flash_attention_recompute
|
||||
"""
|
||||
|
||||
token_idx = kwargs.get("token_idx", None)
|
||||
if token_idx is None:
|
||||
return super().prepare_inputs_for_generation(
|
||||
input_ids=input_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
pixel_values=pixel_values,
|
||||
aspect_ratio_ids=aspect_ratio_ids,
|
||||
aspect_ratio_mask=aspect_ratio_mask,
|
||||
cross_attention_mask=cross_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
use_flash_attention = kwargs.get("use_flash_attention", True)
|
||||
flash_attention_recompute = kwargs.get("flash_attention_recompute", True)
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
output_attentions = kwargs.get("output_attentions", None)
|
||||
output_hidden_states = kwargs.get("output_hidden_states", None)
|
||||
return_dict = kwargs.get("return_dict", None)
|
||||
labels = kwargs.get("labels", None)
|
||||
cross_attention_states = kwargs.get("cross_attention_states", None)
|
||||
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
bucket_internal = kwargs.get("bucket_internal", None)
|
||||
|
||||
if past_key_values is not None:
|
||||
if token_idx is not None:
|
||||
input_ids = torch.index_select(input_ids, 1, token_idx - 1)
|
||||
elif inputs_embeds is not None: # Exception 1
|
||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
elif (
|
||||
input_ids.shape[1] != cache_position.shape[0]
|
||||
): # Default case (the "else", a no op, is Exception 2)
|
||||
input_ids = input_ids[:, cache_position]
|
||||
elif bucket_internal and token_idx is not None:
|
||||
# for the 1st token we can slice the inputs till token idx for the fwd pass.
|
||||
input_ids = input_ids[:, :token_idx]
|
||||
attention_mask = attention_mask[:, :token_idx]
|
||||
if cross_attention_mask is not None:
|
||||
cross_attention_mask = cross_attention_mask[:, :token_idx, ...]
|
||||
|
||||
# TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values:
|
||||
if token_idx is not None:
|
||||
position_ids = torch.index_select(
|
||||
position_ids, 1, token_idx - 1
|
||||
)
|
||||
else:
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
|
||||
position_ids = position_ids.clone(
|
||||
memory_format=torch.contiguous_format
|
||||
)
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if pixel_values is not None and cross_attention_states is not None:
|
||||
raise ValueError(
|
||||
"`pixel_values` and `cross_attention_states` cannot be provided simultaneously"
|
||||
)
|
||||
|
||||
if pixel_values is not None:
|
||||
if aspect_ratio_ids is None:
|
||||
raise ValueError(
|
||||
"`aspect_ratio_ids` must be provided if `pixel_values` is provided"
|
||||
)
|
||||
# get vision tokens from vision model
|
||||
vision_outputs = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
aspect_ratio_ids=aspect_ratio_ids,
|
||||
aspect_ratio_mask=aspect_ratio_mask,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_attentions=output_attentions,
|
||||
return_dict=return_dict,
|
||||
use_flash_attention=use_flash_attention,
|
||||
)
|
||||
cross_attention_states = vision_outputs[0]
|
||||
cross_attention_states = self.multi_modal_projector(
|
||||
cross_attention_states
|
||||
).reshape(-1, cross_attention_states.shape[-2], self.hidden_size)
|
||||
|
||||
if cross_attention_mask is not None:
|
||||
cross_attention_mask, full_text_row_masked_out_mask = (
|
||||
_prepare_cross_attention_mask(
|
||||
cross_attention_mask,
|
||||
num_vision_tokens=self.vision_model.num_patches,
|
||||
dtype=self.dtype,
|
||||
token_idx=token_idx,
|
||||
)
|
||||
)
|
||||
else:
|
||||
full_text_row_masked_out_mask = None
|
||||
|
||||
if cross_attention_mask is not None:
|
||||
if cache_position is not None:
|
||||
cross_attention_mask = cross_attention_mask[:, :, cache_position]
|
||||
full_text_row_masked_out_mask = full_text_row_masked_out_mask[
|
||||
:, :, cache_position
|
||||
]
|
||||
elif past_key_values is not None:
|
||||
if token_idx is not None:
|
||||
cross_attention_mask = torch.index_select(
|
||||
cross_attention_mask, -2, token_idx - 1
|
||||
)
|
||||
full_text_row_masked_out_mask = torch.index_select(
|
||||
full_text_row_masked_out_mask, -2, token_idx - 1
|
||||
)
|
||||
else:
|
||||
cross_attention_mask = cross_attention_mask[:, :, -1:]
|
||||
full_text_row_masked_out_mask = full_text_row_masked_out_mask[
|
||||
:, :, -1:
|
||||
]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
|
||||
else:
|
||||
# The clone here is for the same reason as for `position_ids`.
|
||||
model_inputs = {
|
||||
"input_ids": input_ids.clone(memory_format=torch.contiguous_format),
|
||||
"inputs_embeds": None,
|
||||
}
|
||||
|
||||
if num_logits_to_keep is not None:
|
||||
model_inputs["num_logits_to_keep"] = num_logits_to_keep
|
||||
|
||||
# keep cache_position implementation as None for HPU
|
||||
cache_position = None
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"attention_mask": attention_mask,
|
||||
"token_idx": token_idx,
|
||||
"labels": labels,
|
||||
"return_dict": kwargs.get("return_dict"),
|
||||
"full_text_row_masked_out_mask": full_text_row_masked_out_mask,
|
||||
"use_flash_attention": use_flash_attention,
|
||||
"cross_attention_mask": cross_attention_mask,
|
||||
"cross_attention_states": cross_attention_states,
|
||||
"output_attentions": output_attentions,
|
||||
"flash_attention_recompute": flash_attention_recompute,
|
||||
}
|
||||
)
|
||||
|
||||
return model_inputs
|
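A note on the `token_idx` convention used throughout these Gaudi models: inputs keep a fixed padded length and `torch.index_select` pulls out only the position currently being decoded, so tensor shapes stay static across decode steps. A small self-contained illustration (shapes and values are invented, not taken from the diff):

import torch

batch, padded_len = 2, 8
input_ids = torch.arange(batch * padded_len).reshape(batch, padded_len)

# token_idx points one past the last generated token (kept as a tensor).
token_idx = torch.tensor([5])

# Select the current token along the sequence dimension: the result is always
# (batch, 1), no matter how far decoding has progressed.
current_ids = torch.index_select(input_ids, 1, token_idx - 1)
print(current_ids.shape)  # torch.Size([2, 1])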
@ -45,11 +45,17 @@ from text_generation_server.layers.attention import (
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
    Qwen2Model,
)
from habana_frameworks.torch.hpex.kernels import (
    RotaryPosEmbeddingMode,
    apply_rotary_pos_emb,
)
import habana_frameworks.torch as htorch

# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py
from typing import Union
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput, VideoInput
from transformers.image_utils import ImageInput
from transformers.video_utils import VideoInput
from transformers.processing_utils import (
    ProcessingKwargs,
    ProcessorMixin,
@ -375,28 +381,6 @@ class Qwen2_5_VLConfig(PretrainedConfig):
|
||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb_vision(
|
||||
tensor: torch.Tensor, freqs: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
orig_dtype = tensor.dtype
|
||||
tensor = tensor.float()
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
||||
sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
||||
output = (tensor * cos) + (rotate_half(tensor) * sin)
|
||||
output = output.to(orig_dtype)
|
||||
return output
|
||||
|
||||
|
||||
class Qwen2_5VLAttention(nn.Module):
|
||||
def __init__(self, *, prefix, config, weights):
|
||||
super().__init__()
|
||||
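The hunk above removes the local rotate-half helpers in favour of Habana's fused `apply_rotary_pos_emb` kernel; the cos/sin tables prepared later are duplicated along the last dimension so they cover the full rotary width that convention expects. For comparison only, a plain-PyTorch reference equivalent to the removed helper (this is not part of the new code path; the second function name is mine):

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Rotate the two halves of the last dimension: (x1, x2) -> (-x2, x1).
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_reference(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # freqs covers half the rotary width; duplicating cos/sin matches rotate_half.
    cos = torch.cat((freqs.cos(), freqs.cos()), dim=-1)
    sin = torch.cat((freqs.sin(), freqs.sin()), dim=-1)
    return (x * cos) + (rotate_half(x) * sin)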
@ -426,7 +410,8 @@ class Qwen2_5VLAttention(nn.Module):
|
||||
self,
|
||||
hidden_state: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
max_seqlen: int,
|
||||
) -> torch.Tensor:
|
||||
# apply the qkv linear layer to the hidden state
|
||||
@ -444,29 +429,37 @@ class Qwen2_5VLAttention(nn.Module):
|
||||
query = query.view(*_shape)
|
||||
key = key.view(*_shape)
|
||||
value = value.view(*_shape)
|
||||
|
||||
# apply rotary positional embeddings
|
||||
query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze(
|
||||
0
|
||||
)
|
||||
key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
||||
rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
|
||||
rotary_dim = cos.shape[-1]
|
||||
query_rot = query[..., :rotary_dim]
|
||||
query_pass = query[..., rotary_dim:]
|
||||
query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
|
||||
query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query.shape))
|
||||
|
||||
# calc maximum sequence length for any batch
|
||||
query = query.contiguous()
|
||||
key = key.contiguous()
|
||||
value = value.contiguous()
|
||||
causal = False
|
||||
key_rot = key[..., :rotary_dim]
|
||||
key_pass = key[..., rotary_dim:]
|
||||
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
|
||||
key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape))
|
||||
|
||||
# execute sdpa
|
||||
query = query.unsqueeze(0).transpose(1, 2)
|
||||
key = key.unsqueeze(0).transpose(1, 2)
|
||||
value = value.unsqueeze(0).transpose(1, 2)
|
||||
causal = False
|
||||
query = query.transpose(0, 1)
|
||||
key = key.transpose(0, 1)
|
||||
value = value.transpose(0, 1)
|
||||
fsdpa_op = ModuleFusedSDPA(FusedSDPA)
|
||||
attention_mask = torch.zeros(
|
||||
[1, max_seqlen, max_seqlen], device=query.device, dtype=torch.bool
|
||||
)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[
|
||||
:, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]
|
||||
] = True
|
||||
attn_output = fsdpa_op(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=None,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=causal,
|
||||
scale=None,
|
||||
@ -474,7 +467,7 @@ class Qwen2_5VLAttention(nn.Module):
|
||||
recompute_mode=None,
|
||||
valid_sequence_lengths=None,
|
||||
)
|
||||
attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
|
||||
attn_output = attn_output.transpose(0, 1)
|
||||
|
||||
# reshape output to original dimensions
|
||||
attn_output = attn_output.reshape(hidden_state.shape[0], -1)
|
||||
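The fused-SDPA path above drops the varlen kernel and instead builds a boolean block-diagonal mask from `cu_seqlens`, so each packed patch sequence only attends to itself. A standalone sketch of that mask construction (the lengths below are made up; this is an illustration, not the exact tensor handed to the kernel):

import torch

# Cumulative lengths for three packed sequences of lengths 3, 2 and 4.
cu_seqlens = torch.tensor([0, 3, 5, 9])
total_tokens = int(cu_seqlens[-1])

# True where attention is allowed: a block-diagonal pattern over packed tokens.
attention_mask = torch.zeros([1, total_tokens, total_tokens], dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
    attention_mask[
        :, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]
    ] = True
print(attention_mask[0].int())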
@ -533,11 +526,9 @@ class Qwen2_5VLVisionBlock(nn.Module):
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
|
||||
) -> torch.Tensor:
|
||||
def forward(self, hidden_states, cu_seqlens, cos, sin, max_seqlen) -> torch.Tensor:
|
||||
norm1_out, _ = self.norm1(hidden_states)
|
||||
attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen)
|
||||
attn_out = self.attn(norm1_out, cu_seqlens, cos, sin, max_seqlen)
|
||||
hidden_states = hidden_states + attn_out
|
||||
norm2_out, _ = self.norm2(hidden_states)
|
||||
mlp_out = self.mlp(norm2_out)
|
||||
@ -608,7 +599,7 @@ class Qwen2_5VisionModel(nn.Module):
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
# import ipdb; ipdb.set_trace()
|
||||
|
||||
self.temporal_patch_size = config.temporal_patch_size
|
||||
self.spatial_patch_size = config.spatial_patch_size
|
||||
self.in_channels = config.in_channels
|
||||
@ -736,6 +727,10 @@ class Qwen2_5VisionModel(nn.Module):
|
||||
)
|
||||
|
||||
rotary_pos_emb = rotary_pos_emb.to(device=hidden_states.device)
|
||||
cos = rotary_pos_emb.cos()
|
||||
sin = rotary_pos_emb.sin()
|
||||
cos = torch.cat((cos, cos), dim=-1).unsqueeze(1)
|
||||
sin = torch.cat((sin, sin), dim=-1).unsqueeze(1)
|
||||
|
||||
cu_window_seqlens = torch.tensor(
|
||||
cu_window_seqlens,
|
||||
@ -754,6 +749,9 @@ class Qwen2_5VisionModel(nn.Module):
|
||||
max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
|
||||
|
||||
# iteratively apply the blocks to the hidden states
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
for layer_num, block in enumerate(self.blocks):
|
||||
# NOTE: qwen2_5_vl.py has a concept of full attention blocks
|
||||
# that are applied at specific layers.
|
||||
@ -762,9 +760,9 @@ class Qwen2_5VisionModel(nn.Module):
|
||||
else:
|
||||
cu_seqlens_now = cu_window_seqlens
|
||||
|
||||
hidden_states = block(
|
||||
hidden_states, cu_seqlens_now, rotary_pos_emb, max_seqlen
|
||||
)
|
||||
hidden_states = block(hidden_states, cu_seqlens_now, cos, sin, max_seqlen)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
# apply the final patch merger to the hidden states
|
||||
hidden_states = self.merger(hidden_states)
|
||||
@ -886,9 +884,6 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
|
||||
full_llm_pos_ids_list = [
|
||||
item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
|
||||
]
|
||||
# import ipdb
|
||||
|
||||
# ipdb.set_trace()
|
||||
max_s = full_llm_pos_ids_list[-1].max() + 1
|
||||
final_text_len = input_ids_len - vision_ends[-1]
|
||||
if final_text_len > 0:
|
||||
@ -900,9 +895,33 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
|
||||
)
|
||||
return position_ids
|
||||
|
||||
def forward(
|
||||
def get_vision_embeds(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
pixel_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
image_sizes: Optional[torch.Tensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
|
||||
return image_embeds
|
||||
|
||||
def get_inputs_embeds(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
vision_embeds: torch.Tensor = None,
|
||||
):
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
# apply the visual model to the pixel values if they are provided
|
||||
if vision_embeds is not None:
|
||||
mask = torch.where(input_ids == self.image_token_id)
|
||||
inputs_embeds[mask] = vision_embeds
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
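`get_inputs_embeds` above splices the precomputed vision embeddings into the text embedding sequence wherever the image placeholder token occurs. A toy illustration of that indexing pattern (vocab size, hidden size and the placeholder id are invented):

import torch
import torch.nn as nn

IMAGE_TOKEN_ID = 3  # hypothetical placeholder id
embed = nn.Embedding(10, 4)

input_ids = torch.tensor([[1, 3, 3, 2]])  # two image placeholder slots
vision_embeds = torch.randn(2, 4)         # one row per placeholder token

with torch.no_grad():
    inputs_embeds = embed(input_ids)
    mask = torch.where(input_ids == IMAGE_TOKEN_ID)  # (batch, position) indices
    inputs_embeds[mask] = vision_embeds              # rows are consumed in order
print(inputs_embeds.shape)  # torch.Size([1, 4, 4])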
def forward(
|
||||
self,
|
||||
inputs_embeds: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
@ -910,26 +929,10 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
|
||||
seqlen: Seqlen,
|
||||
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||
lm_head_indices: Optional[torch.Tensor],
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
# Unused in this model
|
||||
video_grid_thw: Optional[torch.LongTensor] = None,
|
||||
pixel_attention_mask=None,
|
||||
image_sizes: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
cross_attention_states: Optional[torch.Tensor] = None,
|
||||
image_indices=None,
|
||||
):
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
# apply the visual model to the pixel values if they are provided
|
||||
if pixel_values is not None and len(pixel_values) > 0:
|
||||
if pixel_values is not None:
|
||||
image_embeds = self.visual(
|
||||
pixel_values, grid_thw=image_grid_thw
|
||||
).squeeze(0)
|
||||
mask = torch.where(input_ids == self.image_token_id)
|
||||
inputs_embeds[mask] = image_embeds
|
||||
|
||||
hidden_states = self.text_model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
|
@ -44,28 +44,11 @@ from text_generation_server.layers.attention import (
|
||||
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
|
||||
Qwen2Model,
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb_vision(
|
||||
tensor: torch.Tensor, freqs: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
orig_dtype = tensor.dtype
|
||||
tensor = tensor.float()
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
||||
sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
||||
output = (tensor * cos) + (rotate_half(tensor) * sin)
|
||||
output = output.to(orig_dtype)
|
||||
return output
|
||||
from habana_frameworks.torch.hpex.kernels import (
|
||||
RotaryPosEmbeddingMode,
|
||||
apply_rotary_pos_emb,
|
||||
)
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
class Qwen2VLAttention(nn.Module):
|
||||
@ -96,7 +79,8 @@ class Qwen2VLAttention(nn.Module):
|
||||
self,
|
||||
hidden_state: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
max_seqlen: int,
|
||||
) -> torch.Tensor:
|
||||
# apply the qkv linear layer to the hidden state
|
||||
@ -116,27 +100,36 @@ class Qwen2VLAttention(nn.Module):
|
||||
value = value.view(*_shape)
|
||||
|
||||
# apply rotary positional embeddings
|
||||
query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze(
|
||||
0
|
||||
)
|
||||
key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
||||
rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
|
||||
rotary_dim = cos.shape[-1]
|
||||
query_rot = query[..., :rotary_dim]
|
||||
query_pass = query[..., rotary_dim:]
|
||||
query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
|
||||
query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query.shape))
|
||||
|
||||
# calc maximum sequence length for any batch
|
||||
query = query.contiguous()
|
||||
key = key.contiguous()
|
||||
value = value.contiguous()
|
||||
causal = False
|
||||
key_rot = key[..., :rotary_dim]
|
||||
key_pass = key[..., rotary_dim:]
|
||||
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
|
||||
key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape))
|
||||
|
||||
# execute sdpa
|
||||
query = query.unsqueeze(0).transpose(1, 2)
|
||||
key = key.unsqueeze(0).transpose(1, 2)
|
||||
value = value.unsqueeze(0).transpose(1, 2)
|
||||
causal = False
|
||||
query = query.transpose(0, 1)
|
||||
key = key.transpose(0, 1)
|
||||
value = value.transpose(0, 1)
|
||||
fsdpa_op = ModuleFusedSDPA(FusedSDPA)
|
||||
attention_mask = torch.zeros(
|
||||
[1, max_seqlen, max_seqlen], device=query.device, dtype=torch.bool
|
||||
)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[
|
||||
:, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]
|
||||
] = True
|
||||
attn_output = fsdpa_op(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=None,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=causal,
|
||||
scale=None,
|
||||
@ -144,7 +137,7 @@ class Qwen2VLAttention(nn.Module):
|
||||
recompute_mode=None,
|
||||
valid_sequence_lengths=None,
|
||||
)
|
||||
attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
|
||||
attn_output = attn_output.transpose(0, 1)
|
||||
# reshape output to original dimensions
|
||||
attn_output = attn_output.reshape(hidden_state.shape[0], -1)
|
||||
attn_output = self.proj(attn_output)
|
||||
@ -193,11 +186,9 @@ class Qwen2VLVisionBlock(nn.Module):
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
|
||||
) -> torch.Tensor:
|
||||
def forward(self, hidden_states, cu_seqlens, cos, sin, max_seqlen) -> torch.Tensor:
|
||||
norm1_out, residual = self.norm1(hidden_states)
|
||||
attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen)
|
||||
attn_out = self.attn(norm1_out, cu_seqlens, cos, sin, max_seqlen)
|
||||
hidden_states = attn_out + residual
|
||||
norm2_out, residual = self.norm2(hidden_states)
|
||||
hidden_states = hidden_states + self.mlp(norm2_out)
|
||||
@ -330,6 +321,11 @@ class Qwen2VisionModel(nn.Module):
|
||||
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
||||
rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, hidden_states.dtype)
|
||||
|
||||
cos = rotary_pos_emb.cos()
|
||||
sin = rotary_pos_emb.sin()
|
||||
cos = torch.cat((cos, cos), dim=-1).unsqueeze(1)
|
||||
sin = torch.cat((sin, sin), dim=-1).unsqueeze(1)
|
||||
|
||||
# create a cu_seqlens tensor to be used in the attention mask
|
||||
cu_seqlens = torch.repeat_interleave(
|
||||
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
||||
@ -337,8 +333,13 @@ class Qwen2VisionModel(nn.Module):
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||
max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
|
||||
# iteratively apply the blocks to the hidden states
|
||||
lazy_mode = htorch.utils.internal.is_lazy()
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
for block in self.blocks:
|
||||
hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen)
|
||||
hidden_states = block(hidden_states, cu_seqlens, cos, sin, max_seqlen)
|
||||
if lazy_mode:
|
||||
htorch.core.mark_step()
|
||||
|
||||
# apply the final patch merger to the hidden states
|
||||
hidden_states = self.merger(hidden_states)
|
||||
@ -474,9 +475,33 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
)
|
||||
return position_ids
|
||||
|
||||
def forward(
|
||||
def get_vision_embeds(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
pixel_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
image_sizes: Optional[torch.Tensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
|
||||
return image_embeds
|
||||
|
||||
def get_inputs_embeds(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
vision_embeds: torch.Tensor = None,
|
||||
):
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
# apply the visual model to the pixel values if they are provided
|
||||
if vision_embeds is not None:
|
||||
mask = torch.where(input_ids == self.image_token_id)
|
||||
inputs_embeds[mask] = vision_embeds
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
@ -484,26 +509,10 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
seqlen: Seqlen,
|
||||
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
|
||||
lm_head_indices: Optional[torch.Tensor],
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
video_grid_thw: Optional[torch.LongTensor] = None,
|
||||
pixel_attention_mask=None,
|
||||
image_sizes: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
cross_attention_states: Optional[torch.Tensor] = None,
|
||||
image_indices=None,
|
||||
):
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
# apply the visual model to the pixel values if they are provided
|
||||
if pixel_values is not None and len(pixel_values) > 0:
|
||||
if pixel_values is not None:
|
||||
image_embeds = self.visual(
|
||||
pixel_values, grid_thw=image_grid_thw
|
||||
).squeeze(0)
|
||||
mask = torch.where(input_ids == self.image_token_id)
|
||||
inputs_embeds[mask] = image_embeds
|
||||
|
||||
hidden_states = self.text_model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
position_ids=position_ids,
|
||||
|
@ -153,19 +153,14 @@ def prepare_for_decode(
|
||||
block_list_device = _async_h2d_tensor_copy(block_list)
|
||||
block_groups_device = _async_h2d_tensor_copy(block_groups)
|
||||
block_usage_device = _async_h2d_tensor_copy(block_usage)
|
||||
block_mapping = torch.nn.functional.one_hot(
|
||||
block_groups_device, num_classes=batch_size
|
||||
)
|
||||
mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
|
||||
mask = mask >= block_usage_device.unsqueeze(-1)
|
||||
attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
|
||||
|
||||
return trim_attn_metadata(
|
||||
HPUPagedAttentionMetadata(
|
||||
block_list=block_list_device,
|
||||
block_groups=block_groups_device,
|
||||
block_usage=block_usage_device,
|
||||
block_mapping=block_mapping.to(dtype),
|
||||
attn_bias=attn_bias,
|
||||
block_mapping=None,
|
||||
attn_bias=None,
|
||||
)
|
||||
)
|
||||
|
||||
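The hunk above stops filling `block_mapping` and `attn_bias` here (they are now passed as None and derived later), but the removed lines show the underlying idea: one-hot encode which batch entry owns each KV-cache block, and turn per-block usage counts into an additive bias that blanks out unused slots. A self-contained sketch of those two constructions, with invented values:

import math
import torch

BLOCK_SIZE = 4
batch_size = 2
block_groups = torch.tensor([0, 0, 1])  # owner of each of the 3 cache blocks
block_usage = torch.tensor([4, 2, 3])   # filled slots per block

# One-hot matrix mapping blocks to batch entries.
block_mapping = torch.nn.functional.one_hot(block_groups, num_classes=batch_size)

# Slots beyond the used length get -inf so softmax ignores them.
mask = torch.arange(0, BLOCK_SIZE, dtype=torch.int32).unsqueeze(0)
mask = mask >= block_usage.unsqueeze(-1)
attn_bias = torch.zeros_like(mask, dtype=torch.float32).masked_fill_(mask, -math.inf)
print(block_mapping)
print(attn_bias)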
@ -428,10 +423,8 @@ class FlashCausalLMBatch(Batch):
|
||||
for i, input_ids in enumerate(all_input_ids):
|
||||
all_input_ids_tensor[i, : len(input_ids)] = input_ids
|
||||
|
||||
# Create tensors on device
|
||||
all_input_ids_tensor = torch.tensor(
|
||||
all_input_ids_tensor, dtype=torch.int64, device=device
|
||||
)
|
||||
# put on cpu temporarily, move to hpu in prepare_for_prefill
|
||||
all_input_ids_tensor = torch.tensor(all_input_ids_tensor, dtype=torch.int64)
|
||||
|
||||
top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64)
|
||||
|
||||
@ -701,7 +694,9 @@ class FlashCausalLMBatch(Batch):
|
||||
|
||||
@classmethod
|
||||
@tracer.start_as_current_span("concatenate")
|
||||
def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
|
||||
def concatenate(
|
||||
cls, batches: List["FlashCausalLMBatch"], padded_total_bs: int = 0
|
||||
) -> "FlashCausalLMBatch":
|
||||
# Batch attributes
|
||||
requests = []
|
||||
requests_idx_mapping = {}
|
||||
@ -750,7 +745,10 @@ class FlashCausalLMBatch(Batch):
|
||||
adapter_meta = None
|
||||
adapter_segment_builder = None
|
||||
else:
|
||||
input_ids = batches[0].input_ids.new_empty(total_batch_size)
|
||||
if padded_total_bs == batches[0].input_ids.shape[0]:
|
||||
input_ids = batches[0].input_ids
|
||||
else:
|
||||
input_ids = batches[0].input_ids.new_empty(total_batch_size)
|
||||
if (
|
||||
batches[0].position_ids is not None
|
||||
and batches[0].position_ids.dim() == 2
|
||||
@ -784,9 +782,7 @@ class FlashCausalLMBatch(Batch):
|
||||
block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
|
||||
(total_batch_size, max_blocks)
|
||||
)
|
||||
all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
|
||||
(total_batch_size, max_length)
|
||||
)
|
||||
all_input_ids_tensor = batches[0].all_input_ids_tensor
|
||||
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
|
||||
total_batch_size,
|
||||
)
|
||||
@ -829,9 +825,12 @@ class FlashCausalLMBatch(Batch):
|
||||
|
||||
index = torch.tensor(list(range(start_index, end_index)), device="cpu")
|
||||
top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor)
|
||||
all_input_ids_tensor[
|
||||
start_index:end_index, : batch.all_input_ids_tensor.shape[1]
|
||||
] = batch.all_input_ids_tensor[:valid_bsize, :max_length]
|
||||
if i > 0:
|
||||
all_input_ids_tensor.index_copy_(
|
||||
0,
|
||||
index.to(batch.all_input_ids_tensor.device),
|
||||
batch.all_input_ids_tensor[:valid_bsize, :],
|
||||
)
|
||||
|
||||
block_tables_tensor[
|
||||
start_index:end_index, : batch.block_tables_tensor.shape[1]
|
||||
@ -851,9 +850,10 @@ class FlashCausalLMBatch(Batch):
|
||||
)
|
||||
|
||||
if not prefilling:
|
||||
input_ids.index_copy_(
|
||||
0, index.to(input_ids.device), batch.input_ids[:valid_bsize]
|
||||
)
|
||||
if padded_total_bs != batches[0].input_ids.shape[0] or i > 0:
|
||||
input_ids.index_copy_(
|
||||
0, index.to(input_ids.device), batch.input_ids[:valid_bsize]
|
||||
)
|
||||
position_ids.index_copy_(0, index, batch.position_ids[:valid_bsize])
|
||||
slot_indices.index_copy_(
|
||||
0, index, batch.slot_indices + cumulative_slots
|
||||
@ -987,7 +987,6 @@ class FlashCausalLMBatch(Batch):
|
||||
else:
|
||||
padded_bs = self.input_ids.shape[0]
|
||||
slots = self.slots[self.slot_indices]
|
||||
extra_pad = padded_bs - self.input_ids.shape[0]
|
||||
|
||||
self.hpu_attn_meta = prepare_for_decode(
|
||||
dtype,
|
||||
@ -998,17 +997,29 @@ class FlashCausalLMBatch(Batch):
|
||||
padded_bs,
|
||||
bucketing_ctx,
|
||||
)
|
||||
self.input_ids = F.pad(self.input_ids, (0, extra_pad), value=0)
|
||||
self.position_ids = F.pad(self.position_ids, (0, extra_pad), value=1)
|
||||
self.input_ids = F.pad(
|
||||
self.input_ids, (0, padded_bs - self.input_ids.shape[0]), value=0
|
||||
)
|
||||
|
||||
if self.position_ids.dim() == 2:
|
||||
# Qwen VL case
|
||||
self.position_ids = F.pad(
|
||||
self.position_ids,
|
||||
(0, 0, 0, padded_bs - self.position_ids.shape[0]),
|
||||
value=1,
|
||||
)
|
||||
else:
|
||||
self.position_ids = F.pad(
|
||||
self.position_ids, (0, padded_bs - self.position_ids.shape[0]), value=1
|
||||
)
|
||||
self.input_lengths_tensor = F.pad(
|
||||
self.input_lengths_tensor, (0, extra_pad), value=0
|
||||
self.input_lengths_tensor,
|
||||
(0, padded_bs - self.input_lengths_tensor.shape[0]),
|
||||
value=0,
|
||||
)
|
||||
self.cache_lengths_tensor = F.pad(
|
||||
self.cache_lengths_tensor, (0, extra_pad), value=0
|
||||
)
|
||||
self.all_input_ids_tensor = F.pad(
|
||||
self.all_input_ids_tensor,
|
||||
(0, 0, 0, extra_pad),
|
||||
self.cache_lengths_tensor,
|
||||
(0, padded_bs - self.cache_lengths_tensor.shape[0]),
|
||||
value=0,
|
||||
)
|
||||
next_token_chooser_parameters = []
|
||||
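The padding above distinguishes 1D position ids from the 2D ones used by the Qwen VL models because `F.pad` reads its padding sizes from the last dimension backwards. A tiny reminder of that convention (shapes are illustrative only):

import torch
import torch.nn.functional as F

padded_bs = 4
pos_1d = torch.ones(2, dtype=torch.long)     # shape (bs,)
pos_2d = torch.ones(2, 3, dtype=torch.long)  # shape (bs, extra), e.g. mrope sections

# 1D: pad the only dimension on the right up to the bucketed batch size.
pos_1d = F.pad(pos_1d, (0, padded_bs - pos_1d.shape[0]), value=1)

# 2D: the leading (0, 0) leaves the last dim untouched, the second pair pads dim 0.
pos_2d = F.pad(pos_2d, (0, 0, 0, padded_bs - pos_2d.shape[0]), value=1)
print(pos_1d.shape, pos_2d.shape)  # torch.Size([4]) torch.Size([4, 3])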
@ -1028,7 +1039,9 @@ class FlashCausalLMBatch(Batch):
|
||||
fsm_grammar_states,
|
||||
)
|
||||
|
||||
def prepare_for_prefill(self, max_padded_input_len, max_padded_bs):
|
||||
def prepare_for_prefill(
|
||||
self, max_padded_input_len, max_padded_bs, max_total_tokens
|
||||
):
|
||||
# Prepare values if we need to continue prefilling
|
||||
# Speculation must be ignored while we prefill even with chunking
|
||||
# it simplifies everything
|
||||
@ -1044,7 +1057,7 @@ class FlashCausalLMBatch(Batch):
|
||||
# need extra pad to match warmup seq
|
||||
extra_pad = max_padded_input_len - self.max_input_length
|
||||
extra_pad_bs = max_padded_bs - len(self)
|
||||
device = self.all_input_ids_tensor.device
|
||||
device = "hpu"
|
||||
if isinstance(self.input_ids, list) and len(self) > 1:
|
||||
input_ids_padded_length = []
|
||||
input_ids = []
|
||||
@ -1062,8 +1075,19 @@ class FlashCausalLMBatch(Batch):
|
||||
input_ids = [0] * extra_pad + input_ids
|
||||
self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
|
||||
else:
|
||||
self.input_ids = F.pad(self.input_ids, (extra_pad, 0), value=0)
|
||||
input_ids_padded_length.extend([extra_pad] * len(self))
|
||||
input_ids = self.input_ids.new_zeros(max_padded_input_len * len(self))
|
||||
src_pos = 0
|
||||
for i in range(len(self)):
|
||||
end_pos = (i + 1) * max_padded_input_len
|
||||
start_pos = end_pos - self.input_lengths[i]
|
||||
input_ids[start_pos:end_pos] = self.input_ids[
|
||||
src_pos : src_pos + self.input_lengths[i]
|
||||
]
|
||||
input_ids_padded_length.append(
|
||||
max_padded_input_len - self.input_lengths[i]
|
||||
)
|
||||
src_pos += self.input_lengths[i]
|
||||
self.input_ids = input_ids
|
||||
|
||||
self.input_ids = F.pad(
|
||||
self.input_ids, (0, extra_pad_bs * max_padded_input_len), value=0
|
||||
@ -1288,12 +1312,17 @@ class FlashCausalLMBatch(Batch):
|
||||
self.prefill_next_token_indices = (
|
||||
self.prefill_next_token_indices + input_ids_padded_length_tensor
|
||||
)
|
||||
|
||||
self.all_input_ids_tensor = F.pad(
|
||||
self.all_input_ids_tensor,
|
||||
(0, 0, 0, extra_pad_bs),
|
||||
value=0,
|
||||
all_input_ids_tensor = torch.zeros(
|
||||
(max_padded_bs, max(max_total_tokens, self.all_input_ids_tensor.shape[-1])),
|
||||
dtype=torch.int64,
|
||||
device="hpu",
|
||||
)
|
||||
for i in range(len(self)):
|
||||
all_input_ids_tensor[i, : self.all_input_ids_tensor.shape[-1]] = (
|
||||
self.all_input_ids_tensor[i]
|
||||
)
|
||||
self.all_input_ids_tensor = all_input_ids_tensor
|
||||
|
||||
next_token_chooser_parameters = []
|
||||
next_token_chooser_parameters.extend([r.parameters for r in self.requests])
|
||||
pad_next_token_chooser_parameters(next_token_chooser_parameters, max_padded_bs)
|
||||
@ -1448,7 +1477,7 @@ class FlashCausalLM(Model):
|
||||
if head_size is None:
|
||||
# Some models use GQA and different sizes for o_proj
|
||||
# and q_proj, that allows for that.
|
||||
if hasattr(config, "head_dim"):
|
||||
if getattr(config, "head_dim", None) is not None:
|
||||
self.head_size = config.head_dim
|
||||
else:
|
||||
self.head_size = config.hidden_size // config.num_attention_heads
|
||||
@ -1459,6 +1488,8 @@ class FlashCausalLM(Model):
|
||||
self.kv_cache = []
|
||||
self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
|
||||
self.bucketing_ctx = None
|
||||
self.max_total_tokens = None
|
||||
self.max_input_tokens = None
|
||||
htorch.core.hpu_set_env()
|
||||
if htorch.utils.internal.is_lazy():
|
||||
htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=True)
|
||||
@ -1564,6 +1595,14 @@ class FlashCausalLM(Model):
|
||||
logger.info,
|
||||
f"Free memory on device {self.device}: {format_bytes(free_memory)} used_for_graph: {format_bytes(mem_used_from_graph)} ratio {graph_reserved_mem} reserved_for_runtime: {format_bytes(self.mem_reserved)}",
|
||||
)
|
||||
if max_total_tokens is None:
|
||||
max_total_tokens = sum(batch.input_lengths)
|
||||
|
||||
if max_input_tokens is None:
|
||||
max_input_tokens = max_total_tokens - 1
|
||||
|
||||
self.max_total_tokens = max_total_tokens
|
||||
self.max_input_tokens = max_input_tokens
|
||||
try:
|
||||
self.init_kv_cache(
|
||||
batch.num_blocks,
|
||||
@ -1597,11 +1636,6 @@ class FlashCausalLM(Model):
|
||||
)
|
||||
|
||||
log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}")
|
||||
if max_total_tokens is None:
|
||||
max_total_tokens = sum(batch.input_lengths)
|
||||
|
||||
if max_input_tokens is None:
|
||||
max_input_tokens = max_total_tokens - 1
|
||||
|
||||
self.kv_cache = []
|
||||
empty_cache()
|
||||
@ -2017,7 +2051,9 @@ class FlashCausalLM(Model):
|
||||
accepted_ids,
|
||||
speculative_ids,
|
||||
) = batch.next_token_chooser(
|
||||
batch.all_input_ids_tensor[:, : batch.max_current_length],
|
||||
batch.all_input_ids_tensor[
|
||||
: batch.next_token_logits.shape[0], : batch.max_current_length
|
||||
],
|
||||
batch.next_token_logits,
|
||||
speculate,
|
||||
batch.speculative_ids,
|
||||
@ -2031,14 +2067,29 @@ class FlashCausalLM(Model):
|
||||
accepted_ids,
|
||||
)
|
||||
if batch.valid_indices is not None:
|
||||
next_token_logprobs = next_token_logprobs.cpu()
|
||||
accepted_ids = accepted_ids.cpu()
|
||||
batch.all_input_ids_tensor = batch.all_input_ids_tensor[
|
||||
batch.valid_indices
|
||||
]
|
||||
next_input_ids = next_input_ids[batch.valid_indices]
|
||||
next_token_logprobs = next_token_logprobs[batch.valid_indices]
|
||||
accepted_ids = accepted_ids[batch.valid_indices]
|
||||
# TODO speculative decoding handling missing
|
||||
index = torch.arange(
|
||||
0,
|
||||
len(batch.valid_indices),
|
||||
device=batch.all_input_ids_tensor.device,
|
||||
)
|
||||
batch.all_input_ids_tensor.index_copy_(
|
||||
0, index, batch.all_input_ids_tensor[batch.valid_indices]
|
||||
)
|
||||
padded_total_bs = self.bucketing_ctx.get_padded_decode_batch_size(
|
||||
len(batch.valid_indices)
|
||||
)
|
||||
next_input_ids.index_copy_(
|
||||
0, index, next_input_ids[batch.valid_indices]
|
||||
)
|
||||
next_input_ids = next_input_ids[:padded_total_bs]
|
||||
|
||||
next_token_logprobs.index_copy_(
|
||||
0, index, next_token_logprobs[batch.valid_indices]
|
||||
)
|
||||
accepted_ids.index_copy_(
|
||||
0, index, accepted_ids[batch.valid_indices]
|
||||
)
|
||||
if speculative_ids is not None:
|
||||
speculative_ids = speculative_ids[batch.valid_indices]
|
||||
batch.top_n_tokens_tensor = batch.top_n_tokens_tensor[
|
||||
@ -2106,10 +2157,13 @@ class FlashCausalLM(Model):
|
||||
batch.slot_indices += accepted_ids[: len(batch)]
|
||||
else:
|
||||
index = batch.cache_lengths_tensor + batch.input_lengths_tensor
|
||||
index = F.pad(
|
||||
index, (0, next_input_ids.shape[0] - index.shape[0]), value=0
|
||||
)
|
||||
index = index.to(batch.all_input_ids_tensor.device)
|
||||
batch_idx = torch.arange(
|
||||
0,
|
||||
batch.all_input_ids_tensor.shape[0],
|
||||
index.shape[0],
|
||||
dtype=torch.long,
|
||||
device=batch.all_input_ids_tensor.device,
|
||||
)
|
||||
@ -2197,7 +2251,18 @@ class FlashCausalLM(Model):
|
||||
htorch.core.mark_step()
|
||||
# Stage 2. Prepare new batch for speculative scheduling
|
||||
if len(batches) > 1:
|
||||
batch = self.batch_type.concatenate(batches)
|
||||
if self.bucketing_ctx is not None:
|
||||
total_batch_size = 0
|
||||
for b in batches:
|
||||
total_batch_size += len(b)
|
||||
padded_total_bs = self.bucketing_ctx.get_padded_decode_batch_size(
|
||||
total_batch_size
|
||||
)
|
||||
batch = self.batch_type.concatenate(
|
||||
batches, padded_total_bs=padded_total_bs
|
||||
)
|
||||
else:
|
||||
batch = self.batch_type.concatenate(batches)
|
||||
else:
|
||||
batch = batches[0]
|
||||
prefill = batch.prefilling
|
||||
@ -2208,13 +2273,18 @@ class FlashCausalLM(Model):
|
||||
batch.max_input_length
|
||||
),
|
||||
self.bucketing_ctx.get_padded_prompt_batch_size(len(batch)),
|
||||
self.max_total_tokens,
|
||||
)
|
||||
else:
|
||||
batch.prepare_for_prefill(batch.max_input_length, len(batch))
|
||||
batch.prepare_for_prefill(
|
||||
batch.max_input_length, len(batch), self.max_total_tokens
|
||||
)
|
||||
else:
|
||||
batch.prepare_for_decode(
|
||||
self.dtype, self.use_contiguous_pa, self.bucketing_ctx
|
||||
)
|
||||
if hasattr(self, "set_inputs_embeds") and callable(self.set_inputs_embeds):
|
||||
self.set_inputs_embeds(batch)
|
||||
prefill_logprobs = batch.prefill_next_token_indices is not None
|
||||
# Update adapter indices for speculative tokens (if present)
|
||||
adapter_meta = batch.adapter_meta
|
||||
|
@ -1,7 +1,7 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
|
||||
from dataclasses import dataclass
|
||||
from opentelemetry import trace
|
||||
from typing import Iterable, Optional, Tuple, List, Type, Dict
|
||||
|
||||
@ -119,17 +119,17 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
return height // patch_size, width // patch_size
|
||||
|
||||
|
||||
def image_text_replacement(processor, image_input, config, image_id: int) -> str:
|
||||
def image_text_replacement(processor, image_input, config) -> str:
|
||||
if config.model_type == "idefics2":
|
||||
image_seq_len = 64
|
||||
image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}"
|
||||
if processor.image_processor.do_image_splitting:
|
||||
image_str *= 5
|
||||
return image_str
|
||||
return image_str, IDEFICS2_FAKE_TOKEN
|
||||
if config.model_type == "idefics3":
|
||||
# TODO: implement this in a more general way
|
||||
n_rows = image_input["rows"][0][image_id]
|
||||
n_cols = image_input["cols"][0][image_id]
|
||||
n_rows = image_input["rows"][0][0]
|
||||
n_cols = image_input["cols"][0][0]
|
||||
image_seq_len = int(
|
||||
((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
|
||||
/ (config.scale_factor**2)
|
||||
@ -142,41 +142,41 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
|
||||
image_token=IDEFICS3_IMAGE_TOKEN,
|
||||
global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,
|
||||
)
|
||||
return image_str
|
||||
return image_str, IDEFICS3_FAKE_IMAGE_TOKEN
|
||||
elif config.model_type == "llava_next":
|
||||
height, width = image_input["image_sizes"][image_id]
|
||||
height, width = image_input["image_sizes"][0]
|
||||
num_features = get_number_of_features(height, width, config)
|
||||
|
||||
log_master(
|
||||
logger.info,
|
||||
f"Found {num_features} features in image of resolution {height}x{width}",
|
||||
)
|
||||
return "<image>" * num_features
|
||||
return "<image>" * num_features, "<image>"
|
||||
|
||||
elif config.model_type == "paligemma":
|
||||
return "<image>" * config.text_config.num_image_tokens
|
||||
return "<image>" * config.text_config.num_image_tokens, "<image>"
|
||||
elif config.model_type == "qwen2_vl":
|
||||
grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
|
||||
grid_t, grid_h, grid_w = image_input["image_grid_thw"][0]
|
||||
num_pads = grid_t * grid_h * grid_w // 4
|
||||
padding = "<|image_pad|>" * num_pads
|
||||
return f"<|vision_start|>{padding}<|vision_end|>"
|
||||
return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
|
||||
elif config.model_type == "qwen2_5_vl":
|
||||
grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
|
||||
grid_t, grid_h, grid_w = image_input["image_grid_thw"][0]
|
||||
num_pads = grid_t * grid_h * grid_w // 4
|
||||
padding = "<|image_pad|>" * num_pads
|
||||
return f"<|vision_start|>{padding}<|vision_end|>"
|
||||
return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
|
||||
elif config.model_type == "gemma3":
|
||||
# TODO: get correct number of features via reviewing the Gemma3 architecture
|
||||
# and calculating the number of image tokens
|
||||
num_pads = 256
|
||||
padding = "<image_soft_token>" * num_pads
|
||||
return f"\n\n<start_of_image>{padding}<end_of_image>\n\n"
|
||||
return f"\n\n<start_of_image>{padding}<end_of_image>\n\n", "<start_of_image>"
|
||||
elif config.model_type == "llama4":
|
||||
patch_size = config.vision_config.patch_size
|
||||
pixel_shuffle_ratio = config.vision_config.pixel_shuffle_ratio
|
||||
downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))
|
||||
aspect_ratios = image_input["aspect_ratios"][image_id]
|
||||
image_height, image_width = image_input["pixel_values"][image_id].shape[-2:]
|
||||
aspect_ratios = image_input["aspect_ratios"][0]
|
||||
image_height, image_width = image_input["pixel_values"][0].shape[-2:]
|
||||
|
||||
num_patches_per_chunk = int(
|
||||
(image_height // patch_size)
|
||||
@ -187,7 +187,7 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
|
||||
aspect_ratios, num_patches_per_chunk
|
||||
)
|
||||
|
||||
return tokens_for_this_image
|
||||
return tokens_for_this_image, "<|image_start|>"
|
||||
else:
|
||||
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
|
||||
|
||||
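`image_text_replacement` now returns both the replacement string and the token string that marks the start of an image block; callers can turn that marker into tensor offsets with a simple equality search over the tokenized prompt. A rough sketch of that lookup (the ids are invented; the real logic lives in `get_image_positions` further down):

import torch

input_ids = torch.tensor([5, 9, 7, 11, 11, 11, 8, 7, 11, 11, 6])
img_start_token = 7  # hypothetical id of the image-start marker

# Offsets of every image block in the flattened prompt.
start_positions = torch.where(input_ids.eq(img_start_token))[0]
print(start_positions)  # tensor([2, 7])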
@ -200,6 +200,27 @@ def image_text_replacement_fixup(config, text: str) -> str:
    return text


def preprocess_text(config, text: str) -> str:
    if config.model_type == "paligemma":
        return "<bos>" + text + "\n"
    return text


def preprocess_image(config, img):
    model_type = config.model_type

    if model_type in {"qwen2_vl", "qwen2_5_vl"} and img.width <= 20:
        img = img.resize((img.width * 2, img.height * 2))
    if model_type == "paligemma":
        img = img.convert("RGB")

    if model_type not in {"llava_next", "gemma3", "llama4"}:
        # TODO: check if this is needed
        img = [img]

    return img


def get_unpadded_features(
|
||||
original_height: int,
|
||||
original_width: int,
|
||||
@ -254,105 +275,259 @@ def get_number_of_features(height: int, width: int, config) -> int:
    return unpadded_features + newline_features + base_features


def scatter_image_embeds(
    embeds: torch.Tensor, is_embed: Optional[torch.Tensor]
) -> torch.Tensor:
    if is_embed is None:
        return embeds

    placeholders = embeds.new_full(
        (is_embed.shape[0], embeds.shape[-1]),
        fill_value=torch.nan,
    )
    placeholders[is_embed.to(embeds.device)] = embeds
    return placeholders


def gather_image_embeds(
    embeds: torch.Tensor, is_embed: Optional[torch.Tensor]
) -> Optional[torch.Tensor]:
    if is_embed is None:
        return embeds
    sel = embeds[is_embed.to(embeds.device)]
    return sel if sel.numel() else None
|
||||
|
||||
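A quick usage example for the two helpers above: `scatter_image_embeds` re-inflates a compact embedding tensor to placeholder-aligned rows (padding the gaps with NaN), and `gather_image_embeds` inverts it. Shapes are illustrative; the helpers referenced are the ones defined just above:

import torch

# 5 placeholder slots, of which 3 actually carry image embeddings.
is_embed = torch.tensor([True, False, True, True, False])
embeds = torch.randn(3, 8)

placeholders = scatter_image_embeds(embeds, is_embed)    # (5, 8), NaN rows at 1 and 4
recovered = gather_image_embeds(placeholders, is_embed)  # (3, 8), equal to embeds
assert torch.equal(recovered, embeds)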
|
||||
@dataclass
|
||||
class ImagePositions:
|
||||
offset: int
|
||||
length: int
|
||||
id: int
|
||||
num_placeholder_tokens: int
|
||||
is_embed: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
class FlashVlmCausalLMBatch(FlashCausalLMBatch):
|
||||
image_inputs: Optional[List[List[Dict[str, torch.Tensor]]]]
|
||||
image_positions: Optional[List[List[ImagePositions]]]
|
||||
encoder_cache: Optional[List[Dict[int, torch.Tensor]]]
|
||||
pixel_values: Optional[List[torch.Tensor]]
|
||||
pixel_attention_mask: Optional[List[torch.Tensor]]
|
||||
image_sizes: Optional[List[Tuple[int, int]]]
|
||||
image_grid_thw: Optional[torch.Tensor]
|
||||
cache_entries_to_free: List[Tuple[int, int]]
|
||||
has_image_inputs: bool = False
|
||||
inputs_embeds: Optional[torch.Tensor] = None
|
||||
|
||||
@classmethod
|
||||
@tracer.start_as_current_span("concatenate")
|
||||
def concatenate(cls, batches):
|
||||
batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches)
|
||||
def concatenate(cls, batches, padded_total_bs: int = 0):
|
||||
batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches, padded_total_bs)
|
||||
batch.image_inputs = []
|
||||
batch.image_positions = []
|
||||
batch.encoder_cache = []
|
||||
for b in batches:
|
||||
if b.image_inputs is not None:
|
||||
batch.image_inputs.extend(b.image_inputs)
|
||||
else:
|
||||
batch.image_inputs.append(None)
|
||||
if b.image_positions is not None:
|
||||
batch.image_positions.extend(b.image_positions)
|
||||
else:
|
||||
batch.image_positions.append(None)
|
||||
if b.encoder_cache is not None:
|
||||
batch.encoder_cache.extend(b.encoder_cache)
|
||||
else:
|
||||
batch.encoder_cache.append(None)
|
||||
|
||||
batch.pixel_values = None
|
||||
batch.pixel_attention_mask = None
|
||||
batch.image_sizes = None
|
||||
batch.image_grid_thw = None
|
||||
batch.inputs_embeds = None
|
||||
# To be filled in prepare_for_prefill
|
||||
batch.has_image_inputs = False
|
||||
batch.cache_entries_to_free = []
|
||||
return batch
|
||||
|
||||
@tracer.start_as_current_span("filter")
|
||||
def filter(self, request_ids: List[int]):
|
||||
if len(request_ids) == 0:
|
||||
raise ValueError("Batch must have at least one request")
|
||||
|
||||
image_inputs = []
|
||||
image_positions = []
|
||||
encoder_cache = []
|
||||
|
||||
for request_id in request_ids:
|
||||
idx = self.requests_idx_mapping[request_id]
|
||||
image_inputs.append(self.image_inputs[idx])
|
||||
image_positions.append(self.image_positions[idx])
|
||||
encoder_cache.append(self.encoder_cache[idx])
|
||||
|
||||
batch = super().filter(request_ids)
|
||||
batch.pixel_values = None
|
||||
batch.pixel_attention_mask = None
|
||||
batch.image_sizes = None
|
||||
batch.image_grid_thw = None
|
||||
batch.inputs_embeds = None
|
||||
batch.image_inputs = image_inputs
|
||||
batch.image_positions = image_positions
|
||||
batch.encoder_cache = encoder_cache
|
||||
|
||||
# To be filled in prepare_for_prefill
|
||||
batch.has_image_inputs = False
|
||||
batch.cache_entries_to_free = []
|
||||
return batch
|
||||
|
||||
@classmethod
|
||||
def batch_tokenized_inputs(
|
||||
cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
|
||||
):
|
||||
# Process images first. We need all of them so that the processor
|
||||
# can make the image splits the same size. And we need the final
|
||||
# sizes to insert correct number of image tokens.
|
||||
images = []
|
||||
kwargs = {}
|
||||
if (
|
||||
hasattr(processor, "image_processor_class")
|
||||
and processor.image_processor_class == "Idefics3ImageProcessor"
|
||||
):
|
||||
kwargs["return_row_col_info"] = True
|
||||
|
||||
max_length = 0
|
||||
vocab = tokenizer.get_vocab()
|
||||
|
||||
if not hasattr(config, "image_token_index"):
|
||||
config.image_token_index = config.image_token_id
|
||||
|
||||
batch_tokenized_inputs: List[List[int]] = []
|
||||
batch_image_inputs: List[Optional[List[dict]]] = []
|
||||
batch_image_positions: List[Optional[List[ImagePositions]]] = []
|
||||
|
||||
for r in requests:
|
||||
text_parts = []
|
||||
image_inputs = []
|
||||
image_texts = []
|
||||
|
||||
image_id = 0
|
||||
|
||||
for chunk in r.input_chunks.chunks:
|
||||
chunk_type = chunk.WhichOneof("chunk")
|
||||
if chunk_type == "text":
|
||||
pass
|
||||
text = preprocess_text(config, chunk.text)
|
||||
text_parts.append(text)
|
||||
elif chunk_type == "image":
|
||||
image = Image.open(BytesIO(chunk.image.data))
|
||||
# qwen2_vl expects images to be greater than 20 pixels, this is for warmup since the
|
||||
# default warmup image is 20x20
|
||||
if config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
|
||||
if image.width <= 20:
|
||||
w = image.width * 2
|
||||
h = image.height * 2
|
||||
image = image.resize((w, h))
|
||||
img = Image.open(BytesIO(chunk.image.data))
|
||||
img = preprocess_image(config, img)
|
||||
|
||||
if config.model_type == "llava_next":
|
||||
images.append(image)
|
||||
elif config.model_type == "gemma3":
|
||||
images.append(image)
|
||||
elif config.model_type == "llama4":
|
||||
images.append(image)
|
||||
else:
|
||||
images.append([image])
|
||||
image_input = processor.image_processor(
|
||||
[img], return_tensors="pt", **kwargs
|
||||
)
|
||||
image_inputs.append(image_input)
|
||||
|
||||
img_text, img_start_token_str = image_text_replacement(
|
||||
processor, image_input, config
|
||||
)
|
||||
text_parts.append(img_text)
|
||||
|
||||
image_texts.append([image_id, img_start_token_str, img_text])
|
||||
image_id += 1
|
||||
else:
|
||||
raise RuntimeError(f"Invalid chunk type {chunk_type}")
|
||||
|
||||
if images:
|
||||
kwargs = {}
|
||||
if (
|
||||
hasattr(processor, "image_processor_class")
|
||||
and processor.image_processor_class == "Idefics3ImageProcessor"
|
||||
):
|
||||
kwargs["return_row_col_info"] = True
|
||||
|
||||
image_inputs = processor.image_processor(
|
||||
images, return_tensors="pt", **kwargs
|
||||
)
|
||||
else:
|
||||
image_inputs = None
|
||||
|
||||
batch_tokenized_inputs = []
|
||||
max_length = 0
|
||||
image_id = 0
|
||||
for r in requests:
|
||||
full_text = ""
|
||||
for chunk in r.input_chunks.chunks:
|
||||
chunk_type = chunk.WhichOneof("chunk")
|
||||
if chunk_type == "text":
|
||||
full_text += chunk.text
|
||||
elif chunk_type == "image":
|
||||
full_text += image_text_replacement(
|
||||
processor, image_inputs, config, image_id
|
||||
)
|
||||
image_id += 1
|
||||
|
||||
full_text = image_text_replacement_fixup(config, full_text)
|
||||
full_text = image_text_replacement_fixup(config, "".join(text_parts))
|
||||
input_ids = tokenizer(
|
||||
full_text,
|
||||
truncation=True,
|
||||
max_length=r.truncate,
|
||||
add_special_tokens=r.add_special_tokens,
|
||||
add_special_tokens=(
|
||||
r.add_special_tokens if config.model_type != "paligemma" else False
|
||||
),
|
||||
)["input_ids"]
|
||||
max_length = max(max_length, len(input_ids))
|
||||
batch_tokenized_inputs.append(input_ids)
|
||||
|
||||
return batch_tokenized_inputs, image_inputs
|
||||
if len(image_inputs) > 0:
|
||||
img_start_token = vocab[image_texts[0][1]]
|
||||
image_positions = cls.get_image_positions(
|
||||
input_ids, image_texts, img_start_token, config, tokenizer
|
||||
)
|
||||
else:
|
||||
image_inputs = None
|
||||
image_positions = None
|
||||
|
||||
batch_tokenized_inputs.append(input_ids)
|
||||
batch_image_inputs.append(image_inputs)
|
||||
batch_image_positions.append(image_positions)
|
||||
|
||||
return batch_tokenized_inputs, batch_image_inputs, batch_image_positions
|
||||
|
||||
@classmethod
|
||||
def get_image_positions(
|
||||
cls,
|
||||
input_ids: List[int],
|
||||
image_texts: List[Tuple[int, str, str]],
|
||||
img_start_token: int,
|
||||
config,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
) -> List[ImagePositions]:
|
||||
image_positions = []
|
||||
num_images = len(image_texts)
|
||||
|
||||
input_ids_t = torch.as_tensor(input_ids)
|
||||
img_start_token_pos = torch.where(input_ids_t.eq(img_start_token))[0]
|
||||
num_tokens = input_ids_t.numel()
|
||||
|
||||
last_pos = 0
|
||||
for i in range(num_images):
|
||||
image_id, img_start_token_str, img_text = image_texts[i]
|
||||
img_text = image_text_replacement_fixup(config, img_text)
|
||||
|
||||
if config.model_type == "gemma3":
|
||||
img_text = img_text.replace("\n\n", "")
|
||||
|
||||
tokens = tokenizer(img_text, add_special_tokens=False, return_tensors="pt")[
|
||||
"input_ids"
|
||||
][0]
|
||||
length = tokens.numel()
|
||||
|
||||
assert (
|
||||
length <= num_tokens
|
||||
), f"{length} > {num_tokens} Image is truncated, try increasing --max-batch-prefill-tokens"
|
||||
|
||||
pos = torch.searchsorted(img_start_token_pos, last_pos, right=False)
|
||||
index = img_start_token_pos[pos]
|
||||
assert torch.equal(
|
||||
input_ids_t[index : index + length], tokens
|
||||
), "Image tokens not found in input_ids"
|
||||
|
||||
is_embed = tokens == config.image_token_index
|
||||
num_placeholder_tokens = int(is_embed.sum())
|
||||
if num_placeholder_tokens == length:
|
||||
is_embed = None
|
||||
|
||||
pos = ImagePositions(
|
||||
offset=index,
|
||||
length=length,
|
||||
id=image_id,
|
||||
num_placeholder_tokens=num_placeholder_tokens,
|
||||
is_embed=is_embed,
|
||||
)
|
||||
|
||||
image_positions.append(pos)
|
||||
last_pos = index + length
|
||||
|
||||
if (
|
||||
config.model_type == "idefics2"
|
||||
and i + 1 != num_images
|
||||
and input_ids[last_pos] == config.image_token_index
|
||||
):
|
||||
fake_token = last_pos - 1
|
||||
fake_token_index = torch.searchsorted(
|
||||
img_start_token_pos, fake_token, right=False
|
||||
)
|
||||
img_start_token_pos[fake_token_index] = last_pos
|
||||
image_texts[i + 1][2] = image_texts[i + 1][2][
|
||||
len(img_start_token_str) :
|
||||
]
|
||||
|
||||
return image_positions
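A minimal sketch of the lookup above, using made-up token ids (7 marks an image start, 9 a placeholder):

import torch

input_ids = torch.tensor([1, 7, 9, 9, 9, 2, 7, 9, 9, 3])  # 7 = image start, 9 = placeholder
img_start_positions = torch.where(input_ids.eq(7))[0]     # tensor([1, 6])

last_pos = 0
for length in (4, 3):                                      # span lengths, start token included
    slot = torch.searchsorted(img_start_positions, last_pos, right=False)
    offset = int(img_start_positions[slot])
    is_embed = input_ids[offset : offset + length] == 9    # positions that take vision embeds
    print(offset, length, is_embed.tolist())
    last_pos = offset + length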
|
||||
|
||||
@classmethod
|
||||
def from_pb_processor(
|
||||
@ -364,33 +539,164 @@ class FlashVlmCausalLMBatch(FlashCausalLMBatch):
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "FlashVlmCausalLMBatch":
|
||||
batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
|
||||
pb.requests, tokenizer, processor, config
|
||||
batch_tokenized_inputs, image_inputs, image_positions = (
|
||||
cls.batch_tokenized_inputs(pb.requests, tokenizer, processor, config)
|
||||
)
|
||||
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
|
||||
if image_inputs is not None:
|
||||
batch.pixel_values = image_inputs["pixel_values"].to(device=device)
|
||||
if "pixel_attention_mask" in image_inputs:
|
||||
batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
|
||||
device=device
|
||||
)
|
||||
else:
|
||||
batch.pixel_attention_mask = None
|
||||
if "image_sizes" in image_inputs:
|
||||
batch.image_sizes = image_inputs["image_sizes"].to(device=device)
|
||||
else:
|
||||
batch.image_sizes = None
|
||||
if "image_grid_thw" in image_inputs:
|
||||
batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device)
|
||||
else:
|
||||
batch.image_grid_thw = None
|
||||
else:
|
||||
batch.image_inputs = image_inputs
|
||||
batch.image_positions = image_positions
|
||||
batch.encoder_cache = [{} for _ in range(len(pb.requests))]
|
||||
if len(image_inputs):
|
||||
batch.pixel_values = None
|
||||
batch.pixel_attention_mask = None
|
||||
batch.image_sizes = None
|
||||
batch.image_grid_thw = None
|
||||
return batch
|
||||
|
||||
def prepare_for_prefill(
|
||||
self, max_padded_input_len, max_padded_bs, max_total_tokens
|
||||
):
|
||||
super().prepare_for_prefill(
|
||||
max_padded_input_len, max_padded_bs, max_total_tokens
|
||||
)
|
||||
|
||||
self.has_image_inputs = False
|
||||
self.cache_entries_to_free = []
|
||||
|
||||
self.pixel_values = []
|
||||
|
||||
assert (
|
||||
len(self.cache_lengths)
|
||||
== len(self.input_lengths)
|
||||
== len(self.prefilling_mask)
|
||||
), "Mismatch in lengths of cache_lengths, input_lengths, and prefilling_mask"
|
||||
|
||||
for i, (
|
||||
cache_length,
|
||||
input_length,
|
||||
request_prefilling,
|
||||
) in enumerate(
|
||||
zip(
|
||||
self.cache_lengths,
|
||||
self.input_lengths,
|
||||
self.prefilling_mask,
|
||||
)
|
||||
):
|
||||
if not request_prefilling or self.image_positions[i] is None:
|
||||
continue
|
||||
|
||||
for image_position in self.image_positions[i]:
|
||||
if image_position is None:
|
||||
continue
|
||||
start_pos = image_position.offset
|
||||
length = image_position.length
|
||||
|
||||
if start_pos >= cache_length + input_length:
|
||||
# No encoder input required at this step
|
||||
break
|
||||
if start_pos + length <= cache_length:
|
||||
                    # The encoder input is already processed
|
||||
continue
|
||||
|
||||
self.has_image_inputs = True
|
||||
|
||||
if image_position.id not in self.encoder_cache[i]:
|
||||
image_inputs = self.image_inputs[i][image_position.id]
|
||||
self.pixel_values.append((i, image_position.id, image_inputs))
|
||||
|
||||
# Remove the image from the image_inputs
|
||||
self.image_inputs[i][image_position.id] = None
|
||||
|
||||
if not self.has_image_inputs:
|
||||
self.pixel_values = None
|
||||
self.pixel_attention_mask = None
|
||||
self.image_sizes = None
|
||||
self.image_grid_thw = None
|
||||
else:
|
||||
image_grid_thw_list = [
|
||||
x[2]["image_grid_thw"]
|
||||
for x in self.pixel_values
|
||||
if "image_grid_thw" in x[2]
|
||||
]
|
||||
if image_grid_thw_list:
|
||||
self.image_grid_thw = torch.cat(image_grid_thw_list, dim=0)
|
||||
else:
|
||||
self.image_grid_thw = None
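The checks above are an interval intersection between an image's placeholder span and the token window of the current prefill chunk; the same intersection later slices the cached encoder output in gather_vision_embeds. With made-up lengths:

cache_length, input_length = 48, 32   # tokens already prefilled / tokens in this chunk
start_pos, length = 40, 20            # where the image placeholders sit in the prompt

if start_pos >= cache_length + input_length:
    needed = None                     # image starts after this chunk: nothing to encode yet
elif start_pos + length <= cache_length:
    needed = None                     # image fully covered by earlier chunks: cache suffices
else:
    start_idx = max(cache_length - start_pos, 0)
    end_idx = min(cache_length - start_pos + input_length, length)
    needed = (start_idx, end_idx)     # slice of the encoder output used in this step

print(needed)                         # (8, 20)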
|
||||
|
||||
def update_encoder_cache(self, encoder_outputs, request_id, img_pos):
|
||||
self.encoder_cache[request_id][img_pos.id] = scatter_image_embeds(
|
||||
encoder_outputs, img_pos.is_embed
|
||||
)
|
||||
|
||||
def gather_vision_embeds(self):
|
||||
device = self.input_ids.device
|
||||
chunks = []
|
||||
for (
|
||||
i,
|
||||
cache_length,
|
||||
input_length,
|
||||
request_prefilling,
|
||||
) in zip(
|
||||
range(len(self.requests)),
|
||||
self.cache_lengths,
|
||||
self.input_lengths,
|
||||
self.prefilling_mask,
|
||||
):
|
||||
if not request_prefilling or self.image_positions[i] is None:
|
||||
continue
|
||||
|
||||
for image_position in self.image_positions[i]:
|
||||
if image_position is None:
|
||||
continue
|
||||
start_pos = image_position.offset
|
||||
length = image_position.length
|
||||
|
||||
if start_pos >= cache_length + input_length:
|
||||
# No encoder input required at this step
|
||||
break
|
||||
if start_pos + length <= cache_length:
|
||||
                    # The encoder input is already processed
|
||||
continue
|
||||
|
||||
start_idx = max(cache_length - start_pos, 0)
|
||||
end_idx = min(cache_length - start_pos + input_length, length)
|
||||
|
||||
assert (
|
||||
image_position.id in self.encoder_cache[i]
|
||||
), f"image_id {image_position.id} not in encoder_cache {self.encoder_cache[i]}"
|
||||
encoder_output = self.encoder_cache[i][image_position.id]
|
||||
|
||||
is_embed = image_position.is_embed
|
||||
if is_embed is not None:
|
||||
is_embed = is_embed[start_idx:end_idx]
|
||||
|
||||
from loguru import logger
|
||||
|
||||
logger.info(
|
||||
f"image_id {image_position.id} start_idx {start_idx} end_idx {end_idx}, length {length}"
|
||||
)
|
||||
|
||||
embeds = gather_image_embeds(
|
||||
encoder_output[start_idx:end_idx],
|
||||
is_embed=is_embed,
|
||||
)
|
||||
if embeds is not None:
|
||||
chunks.append(embeds)
|
||||
|
||||
if end_idx == length:
|
||||
self.cache_entries_to_free.append((i, image_position.id))
|
||||
self.image_positions[i][image_position.id] = None
|
||||
|
||||
if len(chunks) == 0:
|
||||
return None
|
||||
return torch.cat(chunks, dim=0).to(device)
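scatter_image_embeds and gather_image_embeds are imported helpers not shown in this hunk; a plausible minimal pair, sketched here only to show what the is_embed mask is for (an assumption, not the actual implementation):

import torch
from typing import Optional

def scatter_image_embeds(embeds: torch.Tensor, is_embed: Optional[torch.Tensor]) -> torch.Tensor:
    # Spread vision features over the full placeholder span; non-embed slots stay zero.
    if is_embed is None:
        return embeds
    out = embeds.new_zeros((is_embed.shape[0], embeds.shape[-1]))
    out[is_embed] = embeds
    return out

def gather_image_embeds(chunk: torch.Tensor, is_embed: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    # Keep only the positions that actually receive vision features.
    if is_embed is None:
        return chunk
    selected = chunk[is_embed]
    return selected if selected.numel() else None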
|
||||
|
||||
def free_encoder_cache(self):
|
||||
for i, image_id in self.cache_entries_to_free:
|
||||
self.encoder_cache[i].pop(image_id, None)
|
||||
|
||||
self.cache_entries_to_free = []
|
||||
|
||||
|
||||
class FlashVlmCausalLM(FlashCausalLM):
|
||||
def __init__(
|
||||
@ -402,6 +708,7 @@ class FlashVlmCausalLM(FlashCausalLM):
|
||||
batch_class=FlashVlmCausalLMBatch,
|
||||
revision,
|
||||
trust_remote_code: bool,
|
||||
support_chunking: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
if PREFIX_CACHING:
|
||||
@ -419,8 +726,7 @@ class FlashVlmCausalLM(FlashCausalLM):
|
||||
model_id=model_id,
|
||||
revision=revision,
|
||||
trust_remote_code=trust_remote_code,
|
||||
# FIXME: VLM do not work with context chunking yet
|
||||
support_chunking=False,
|
||||
support_chunking=support_chunking,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -471,9 +777,12 @@ class FlashVlmCausalLM(FlashCausalLM):
|
||||
bucketing_ctx=None,
|
||||
)
|
||||
slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
|
||||
inputs_embeds = self.get_inputs_embeds(
|
||||
input_ids=input_ids.to(self.device),
|
||||
)
|
||||
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
|
||||
self.model.forward(
|
||||
input_ids=_async_h2d_tensor_copy(input_ids),
|
||||
inputs_embeds=inputs_embeds,
|
||||
position_ids=_async_h2d_tensor_copy(position_ids),
|
||||
cu_seqlen_prefill=None,
|
||||
kv_cache=self.kv_cache,
|
||||
@ -481,10 +790,7 @@ class FlashVlmCausalLM(FlashCausalLM):
|
||||
seqlen=trim_seqlen_metadata(seqlen),
|
||||
hpu_attention_meta=hpu_attention_meta,
|
||||
lm_head_indices=None,
|
||||
pixel_values=None,
|
||||
pixel_attention_mask=None,
|
||||
image_sizes=None,
|
||||
image_grid_thw=None,
|
||||
attention_mask=None,
|
||||
)
|
||||
|
||||
def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch):
|
||||
@ -546,6 +852,84 @@ class FlashVlmCausalLM(FlashCausalLM):
|
||||
f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
|
||||
)
|
||||
|
||||
def get_vision_embeds(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
pixel_attention_mask: torch.Tensor,
|
||||
image_sizes: torch.Tensor,
|
||||
image_grid_thw: torch.Tensor,
|
||||
):
|
||||
embeds = self.model.get_vision_embeds(
|
||||
pixel_values=pixel_values,
|
||||
pixel_attention_mask=pixel_attention_mask,
|
||||
image_sizes=image_sizes,
|
||||
image_grid_thw=image_grid_thw,
|
||||
)
|
||||
return embeds
|
||||
|
||||
def get_inputs_embeds(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
vision_embeds: Optional[torch.Tensor] = None,
|
||||
):
|
||||
return self.model.get_inputs_embeds(
|
||||
input_ids=input_ids,
|
||||
vision_embeds=vision_embeds,
|
||||
)
|
||||
|
||||
def encode_images(self, batch):
|
||||
if batch.pixel_values is not None:
|
||||
device = batch.input_ids.device
|
||||
for request_id, image_id, image_input in batch.pixel_values:
|
||||
pixel_values = image_input["pixel_values"].to(device)
|
||||
|
||||
if "pixel_attention_mask" in image_input:
|
||||
pixel_attention_mask = image_input["pixel_attention_mask"].to(
|
||||
device
|
||||
)
|
||||
else:
|
||||
pixel_attention_mask = None
|
||||
|
||||
if "image_sizes" in image_input:
|
||||
image_sizes = image_input["image_sizes"].to(device)
|
||||
else:
|
||||
image_sizes = None
|
||||
|
||||
if "image_grid_thw" in image_input:
|
||||
image_grid_thw = image_input["image_grid_thw"]
|
||||
else:
|
||||
image_grid_thw = None
|
||||
|
||||
encoder_outputs = self.get_vision_embeds(
|
||||
pixel_values=pixel_values,
|
||||
pixel_attention_mask=pixel_attention_mask,
|
||||
image_sizes=image_sizes,
|
||||
image_grid_thw=image_grid_thw,
|
||||
)
|
||||
batch.update_encoder_cache(
|
||||
encoder_outputs,
|
||||
request_id,
|
||||
batch.image_positions[request_id][image_id],
|
||||
)
|
||||
|
||||
batch.pixel_values = None
|
||||
batch.pixel_attention_mask = None
|
||||
batch.image_sizes = None
|
||||
|
||||
def set_inputs_embeds(self, batch):
|
||||
if batch.has_image_inputs:
|
||||
self.encode_images(batch)
|
||||
vision_embeds = batch.gather_vision_embeds()
|
||||
batch.has_image_inputs = False
|
||||
else:
|
||||
vision_embeds = None
|
||||
|
||||
inputs_embeds = self.get_inputs_embeds(
|
||||
batch.input_ids, vision_embeds=vision_embeds
|
||||
)
|
||||
|
||||
batch.inputs_embeds = inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: FlashVlmCausalLMBatch,
|
||||
@ -593,6 +977,7 @@ class FlashVlmCausalLM(FlashCausalLM):
|
||||
position_ids = new_position_ids
|
||||
else:
|
||||
input_ids = batch.input_ids
|
||||
inputs_embeds = batch.inputs_embeds
|
||||
position_ids = batch.position_ids
|
||||
cu_seqlen_prefill = batch.cu_seqlen_prefill
|
||||
kv_cache = self.kv_cache
|
||||
@ -605,10 +990,25 @@ class FlashVlmCausalLM(FlashCausalLM):
|
||||
if self.model.config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
|
||||
if position_ids.dim() == 1 and batch.prefilling:
|
||||
position_ids = self.model.get_position_ids(
|
||||
input_ids, batch.image_grid_thw
|
||||
input_ids.cpu(), batch.image_grid_thw
|
||||
)
|
||||
batch.position_ids = position_ids
|
||||
|
||||
attention_mask = None
|
||||
attention_mask_forward = None
|
||||
if self.model.config.model_type == "gemma3" and cu_seqlen_prefill is not None:
|
||||
attention_mask = self.model.get_attention_mask(
|
||||
input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True
|
||||
)
|
||||
min_dtype = torch.finfo(self.dtype).min
|
||||
attention_mask_forward = torch.where(attention_mask, 0, min_dtype).to(
|
||||
input_ids.device
|
||||
)
|
||||
attention_mask = attention_mask.reshape(-1)
|
||||
if self.model.config.model_type == "llama4":
|
||||
attention_mask = (input_ids != 0).long()
|
||||
attention_mask_forward = attention_mask.view(input_lengths.shape[0], -1)
|
||||
|
||||
if cu_seqlen_prefill is None and self.max_past() is not None:
|
||||
# In decode, not prefill, we're actually overwriting the KV-cache
|
||||
# in a circular buffer mode.
|
||||
@ -639,7 +1039,7 @@ class FlashVlmCausalLM(FlashCausalLM):
|
||||
input_lengths=_async_h2d_tensor_copy(input_lengths),
|
||||
)
|
||||
logits, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
position_ids=_async_h2d_tensor_copy(position_ids),
|
||||
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
|
||||
kv_cache=kv_cache,
|
||||
@ -647,18 +1047,11 @@ class FlashVlmCausalLM(FlashCausalLM):
|
||||
seqlen=trim_seqlen_metadata(seqlen),
|
||||
hpu_attention_meta=batch.hpu_attn_meta,
|
||||
lm_head_indices=_async_h2d_tensor_copy(lm_head_indices),
|
||||
pixel_values=batch.pixel_values,
|
||||
pixel_attention_mask=batch.pixel_attention_mask,
|
||||
image_sizes=batch.image_sizes,
|
||||
image_grid_thw=batch.image_grid_thw,
|
||||
attention_mask=attention_mask_forward,
|
||||
**kwargs,
|
||||
)
|
||||
if batch.pixel_values is not None:
|
||||
batch.pixel_values = None
|
||||
if batch.pixel_attention_mask is not None:
|
||||
batch.pixel_attention_mask = None
|
||||
if batch.image_sizes is not None:
|
||||
batch.image_sizes = None
|
||||
if batch.image_grid_thw is not None:
|
||||
batch.image_grid_thw = None
|
||||
if batch.prefill_cache_indices is not None:
|
||||
batch.prefill_cache_indices = None
|
||||
batch.image_grid_thw = None
|
||||
batch.free_encoder_cache()
|
||||
return logits, speculative_logits
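The gemma3 branch above turns a boolean attention mask into an additive one before the model call; the conversion in isolation, with a toy mask:

import torch

bool_mask = torch.tensor([[True, True, False],
                          [True, False, False]])
dtype = torch.bfloat16
min_dtype = torch.finfo(dtype).min
additive_mask = torch.where(bool_mask, 0, min_dtype).to(dtype)
print(additive_mask)   # 0 where attention is allowed, a very large negative value elsewhere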
|
||||
|
@ -1,156 +0,0 @@
|
||||
import re
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
|
||||
from transformers import (
|
||||
PreTrainedTokenizerBase,
|
||||
)
|
||||
from text_generation_server.models.causal_lm import CausalLMBatch
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.utils import (
|
||||
NextTokenChooser,
|
||||
StoppingCriteria,
|
||||
)
|
||||
from text_generation_server.utils.chunks import concat_text_chunks
|
||||
|
||||
# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py
|
||||
|
||||
# we split individual characters inside special tokens like [START_DNA]
|
||||
CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])")
|
||||
|
||||
# token added to implement a custom sequence tokenization. This token is added at
|
||||
# corpus cleaning step and removed in pretokenization. The digits are added to increase the chance
|
||||
# that they do not occur in the corpus. The digits are escaped so that the token does not appear
|
||||
# literally in the source code in case we ever include it in the training data.
|
||||
SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E"
|
||||
|
||||
|
||||
def _insert_split_marker(m: re.Match):
|
||||
"""
|
||||
Applies split marker based on a regex match of special tokens such as
|
||||
[START_DNA].
|
||||
Parameters
|
||||
----------
|
||||
n : str
|
||||
Input text to split
|
||||
Returns
|
||||
----------
|
||||
str - the text with the split token added
|
||||
"""
|
||||
start_token, _, sequence, end_token = m.groups()
|
||||
sequence = re.sub(r"(.)", rf"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
|
||||
return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}"
|
||||
|
||||
|
||||
def escape_custom_split_sequence(text):
|
||||
"""
|
||||
Applies custom splitting to the text for GALILEO's tokenization
|
||||
Parameters
|
||||
----------
|
||||
text : str
|
||||
Input text to split
|
||||
Returns
|
||||
----------
|
||||
str - the text with the split token added
|
||||
"""
|
||||
return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)
|
||||
|
||||
|
||||
# END CREDIT
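For reference, the escaping prepends the split marker to every character inside a custom sequence so it is tokenized character by character; with a made-up input:

text = "Sequence: [START_DNA]ACGT[END_DNA]"
escaped = escape_custom_split_sequence(text)
# One marker per character plus a trailing one before [END_DNA]:
assert escaped.count(SPLIT_MARKER) == 5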
|
||||
|
||||
|
||||
class GalacticaCausalLMBatch(CausalLMBatch):
|
||||
@classmethod
|
||||
def from_pb(
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "GalacticaCausalLMBatch":
|
||||
inputs = []
|
||||
next_token_choosers = []
|
||||
stopping_criterias = []
|
||||
prefix_offsets = []
|
||||
top_n_tokens = []
|
||||
read_offsets = []
|
||||
requests_idx_mapping = {}
|
||||
|
||||
# Parse batch
|
||||
max_truncation = 0
|
||||
padding_right_offset = 0
|
||||
max_decode_tokens = 0
|
||||
for i, r in enumerate(pb.requests):
|
||||
requests_idx_mapping[r.id] = i
|
||||
# Add escape_custom_split_sequence to the CausalLMBatch logic
|
||||
inputs.append(
|
||||
escape_custom_split_sequence(concat_text_chunks(r.input_chunks.chunks))
|
||||
)
|
||||
next_token_choosers.append(
|
||||
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
|
||||
)
|
||||
stopping_criteria = StoppingCriteria.from_pb(
|
||||
r.stopping_parameters, tokenizer
|
||||
)
|
||||
stopping_criterias.append(stopping_criteria)
|
||||
top_n_tokens.append(r.top_n_tokens)
|
||||
max_truncation = max(max_truncation, r.truncate)
|
||||
max_decode_tokens += stopping_criteria.max_new_tokens
|
||||
padding_right_offset = max(
|
||||
padding_right_offset, stopping_criteria.max_new_tokens
|
||||
)
|
||||
|
||||
tokenized_inputs = tokenizer(
|
||||
inputs,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
return_token_type_ids=False,
|
||||
truncation=True,
|
||||
max_length=max_truncation,
|
||||
).to(device)
|
||||
for _ in pb.requests:
|
||||
input_len = tokenized_inputs["input_ids"].shape[1]
|
||||
prefix_offsets.append(0)
|
||||
read_offsets.append(input_len)
|
||||
|
||||
input_lengths = tokenized_inputs["attention_mask"].sum(1)
|
||||
max_input_length = input_lengths.max()
|
||||
|
||||
input_ids = tokenized_inputs["input_ids"]
|
||||
# Allocate maximum attention_mask
|
||||
attention_mask = input_ids.new_zeros(
|
||||
(pb.size, max_input_length + padding_right_offset)
|
||||
)
|
||||
# Copy tokenizer attention_mask into fully allocated attention_mask
|
||||
attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
|
||||
|
||||
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
|
||||
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
|
||||
top_n_tokens_tensor = torch.tensor(
|
||||
top_n_tokens, device=device, dtype=torch.int64
|
||||
)
|
||||
|
||||
max_tokens = len(inputs) * max_input_length + max_decode_tokens
|
||||
|
||||
return cls(
|
||||
batch_id=pb.id,
|
||||
requests=pb.requests,
|
||||
requests_idx_mapping=requests_idx_mapping,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
all_input_ids=list(all_input_ids),
|
||||
input_lengths=input_lengths.tolist(),
|
||||
prefix_offsets=prefix_offsets,
|
||||
read_offsets=read_offsets,
|
||||
next_token_choosers=next_token_choosers,
|
||||
stopping_criterias=stopping_criterias,
|
||||
top_n_tokens=top_n_tokens,
|
||||
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||
max_input_length=max_input_length.item(),
|
||||
padding_right_offset=padding_right_offset,
|
||||
max_tokens=max_tokens,
|
||||
)
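The position-id computation in from_pb is the usual cumulative-sum trick: sum the attention mask, subtract one, and overwrite padded positions with a harmless constant. For example:

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])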
|
@ -4,14 +4,14 @@ from loguru import logger
from text_generation_server.utils.log import log_master

REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
ATTENTION = os.getenv("ATTENTION", "default")
ATTENTION = os.getenv("ATTENTION", "paged")
# default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
PREFIX_CACHING = os.getenv("PREFIX_CACHING", "0").lower() in {
    "1",
    "true",
}
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
_expected = {"paged", "default"}
_expected = {"paged"}
assert (
    ATTENTION in _expected
), f"Attention is not valid {ATTENTION}, expected {_expected}"
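These globals are read once at import time, so a deployment sets them in the environment before the server starts; the values below are purely illustrative:

import os

os.environ.setdefault("ATTENTION", "paged")    # after this change, "paged" is the only accepted value
os.environ.setdefault("PREFIX_CACHING", "1")   # opt in to prefix caching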
|
@ -1,882 +0,0 @@
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
import torch
|
||||
import time
|
||||
|
||||
from dataclasses import dataclass
|
||||
from opentelemetry import trace
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoProcessor,
|
||||
AutoTokenizer,
|
||||
PreTrainedTokenizerBase,
|
||||
ProcessorMixin,
|
||||
)
|
||||
from typing import Optional, Tuple, List, Type, Dict
|
||||
|
||||
from text_generation_server.models import Model
|
||||
from text_generation_server.models.types import (
|
||||
Batch,
|
||||
Tokens,
|
||||
Generation,
|
||||
GeneratedText,
|
||||
)
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
|
||||
import torch.distributed
|
||||
from text_generation_server.models.custom_modeling.idefics_modeling import (
|
||||
IdeficsForVisionText2Text,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.quantization import get_loader
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class IdeficsCausalLMBatch(Batch):
|
||||
batch_id: int
|
||||
requests: List[generate_pb2.Request]
|
||||
requests_idx_mapping: Dict[int, int]
|
||||
|
||||
# Decoder values
|
||||
input_ids: torch.Tensor
|
||||
attention_mask: torch.Tensor
|
||||
position_ids: torch.Tensor
|
||||
pixel_values: Optional[torch.Tensor]
|
||||
image_hidden_states: Optional[torch.Tensor]
|
||||
image_attention_mask: Optional[torch.Tensor]
|
||||
past_key_values: Optional[List[Tuple]]
|
||||
|
||||
# All tokens
|
||||
all_input_ids: List[torch.Tensor]
|
||||
|
||||
# Lengths of all generations present in the batch
|
||||
input_lengths: List[int]
|
||||
prefix_offsets: List[int]
|
||||
read_offsets: List[int]
|
||||
|
||||
# Generation helpers
|
||||
next_token_choosers: List[NextTokenChooser]
|
||||
stopping_criterias: List[StoppingCriteria]
|
||||
|
||||
# Metadata used for padding
|
||||
max_input_length: int
|
||||
padding_right_offset: int
|
||||
|
||||
# Maximum number of tokens this batch will grow to
|
||||
max_tokens: int
|
||||
|
||||
# Past metadata
|
||||
keys_head_dim_last: bool = True
|
||||
|
||||
def to_pb(self) -> generate_pb2.CachedBatch:
|
||||
return generate_pb2.CachedBatch(
|
||||
id=self.batch_id,
|
||||
request_ids=[r.id for r in self.requests],
|
||||
size=len(self),
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pb(
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "IdeficsCausalLMBatch":
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def from_pb_processor(
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
processor: ProcessorMixin, # Hack
|
||||
config,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "IdeficsCausalLMBatch":
|
||||
inputs = []
|
||||
next_token_choosers = []
|
||||
stopping_criterias = []
|
||||
prefix_offsets = []
|
||||
read_offsets = []
|
||||
requests_idx_mapping = {}
|
||||
|
||||
# Parse batch
|
||||
max_truncation = 0
|
||||
padding_right_offset = 0
|
||||
max_decode_tokens = 0
|
||||
for i, r in enumerate(pb.requests):
|
||||
requests_idx_mapping[r.id] = i
|
||||
inputs.append(r.input_chunks.chunks)
|
||||
next_token_choosers.append(
|
||||
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
|
||||
)
|
||||
stopping_criteria = StoppingCriteria.from_pb(
|
||||
r.stopping_parameters, tokenizer
|
||||
)
|
||||
stopping_criterias.append(stopping_criteria)
|
||||
max_truncation = max(max_truncation, r.truncate)
|
||||
max_decode_tokens += stopping_criteria.max_new_tokens
|
||||
padding_right_offset = max(
|
||||
padding_right_offset, stopping_criteria.max_new_tokens
|
||||
)
|
||||
|
||||
# TODO Check impact on idefics
|
||||
prompts = []
|
||||
for inp in inputs:
|
||||
# Each input is encoded into a list, where each element of this input list is either a string or a URL
|
||||
prompt = []
|
||||
for chunk in inp:
|
||||
chunk_type = chunk.WhichOneof("chunk")
|
||||
if chunk_type == "text":
|
||||
prompt.append(chunk.text)
|
||||
elif chunk_type == "image":
|
||||
image = Image.open(BytesIO(chunk.image.data))
|
||||
prompt.append(image)
|
||||
else:
|
||||
raise RuntimeError(f"Invalid chunk type {chunk_type}")
|
||||
prompts.append(prompt)
|
||||
|
||||
# The processor replaces the call to tokenizer, and
|
||||
# a/ takes care of fetching images from the URL
|
||||
# b/ generate the correct input_ids, attention_mask, pixel_values, image_attention_mask to feed to the model
|
||||
tokenized_inputs = processor(
|
||||
prompts,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=max_truncation,
|
||||
# TODO Check impact on idefics
|
||||
# add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token
|
||||
).to(device)
|
||||
for _ in pb.requests:
|
||||
input_len = tokenized_inputs["input_ids"].shape[1]
|
||||
prefix_offsets.append(
|
||||
input_len - 5
|
||||
) # To decode without potential fallbacks errors
|
||||
read_offsets.append(
|
||||
input_len
|
||||
) # To decode without potential fallbacks errors
|
||||
|
||||
input_lengths = tokenized_inputs["attention_mask"].sum(1)
|
||||
max_input_length = input_lengths.max()
|
||||
|
||||
input_ids = tokenized_inputs["input_ids"]
|
||||
pixel_values = tokenized_inputs.get("pixel_values", None)
|
||||
image_hidden_states = None
|
||||
# Allocate maximum attention_mask
|
||||
attention_mask = input_ids.new_zeros(
|
||||
(pb.size, max_input_length + padding_right_offset)
|
||||
)
|
||||
# Copy tokenizer attention_mask into fully allocated attention_mask
|
||||
attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
|
||||
# Do the same for image_attention_mask
|
||||
if pixel_values is None:
|
||||
image_attention_mask = None
|
||||
else:
|
||||
image_attention_mask = input_ids.new_zeros(
|
||||
(
|
||||
pb.size,
|
||||
max_input_length + padding_right_offset,
|
||||
pixel_values.size(1),
|
||||
)
|
||||
)
|
||||
image_attention_mask[:, :max_input_length, :] = tokenized_inputs[
|
||||
"image_attention_mask"
|
||||
]
|
||||
|
||||
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
|
||||
all_input_ids = tokenized_inputs["input_ids"].T.split(
|
||||
1, dim=1
|
||||
        ) # It's input_ids but split into a tuple of tensors where each tensor is of size (seq_len, 1). It is then transformed into a list
|
||||
|
||||
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
|
||||
|
||||
return cls(
|
||||
batch_id=pb.id,
|
||||
requests=pb.requests,
|
||||
requests_idx_mapping=requests_idx_mapping,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
pixel_values=pixel_values,
|
||||
image_hidden_states=image_hidden_states,
|
||||
image_attention_mask=image_attention_mask,
|
||||
past_key_values=None,
|
||||
all_input_ids=list(all_input_ids),
|
||||
input_lengths=input_lengths.tolist(),
|
||||
prefix_offsets=prefix_offsets,
|
||||
read_offsets=read_offsets,
|
||||
next_token_choosers=next_token_choosers,
|
||||
stopping_criterias=stopping_criterias,
|
||||
max_input_length=max_input_length.item(),
|
||||
padding_right_offset=padding_right_offset,
|
||||
max_tokens=max_tokens,
|
||||
)
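The processor call in from_pb_processor mirrors what a standalone script would do with interleaved text and images; a hedged sketch, where the model id and the dummy image are placeholders:

from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics-9b-instruct")
prompts = [["User: what is in this image?", Image.new("RGB", (224, 224))]]
inputs = processor(prompts, return_tensors="pt", padding=True)
# inputs carries input_ids, attention_mask, pixel_values and image_attention_mask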
|
||||
|
||||
@tracer.start_as_current_span("filter")
|
||||
def filter(self, request_ids: List[int]) -> Optional["IdeficsCausalLMBatch"]:
|
||||
# It deletes requests from the batch. For instance when client lost connection
|
||||
if len(request_ids) == 0:
|
||||
raise ValueError("Batch must have at least one request")
|
||||
if len(request_ids) == len(self):
|
||||
return self
|
||||
|
||||
keep_indices = []
|
||||
|
||||
# New values after filtering
|
||||
requests_idx_mapping = {}
|
||||
requests = []
|
||||
input_lengths = []
|
||||
prefix_offsets = []
|
||||
read_offsets = []
|
||||
all_input_ids = []
|
||||
max_input_length = 0
|
||||
|
||||
next_token_choosers = []
|
||||
stopping_criterias = []
|
||||
|
||||
total_remaining_decode_tokens = 0
|
||||
new_padding_right_offset = 0
|
||||
|
||||
for i, request_id in enumerate(request_ids):
|
||||
idx = self.requests_idx_mapping[request_id]
|
||||
requests_idx_mapping[request_id] = i
|
||||
keep_indices.append(idx)
|
||||
|
||||
requests.append(self.requests[idx])
|
||||
prefix_offsets.append(self.prefix_offsets[idx])
|
||||
read_offsets.append(self.read_offsets[idx])
|
||||
all_input_ids.append(self.all_input_ids[idx])
|
||||
|
||||
request_input_length = self.input_lengths[idx]
|
||||
input_lengths.append(request_input_length)
|
||||
max_input_length = max(max_input_length, request_input_length)
|
||||
|
||||
next_token_choosers.append(self.next_token_choosers[idx])
|
||||
stopping_criteria = self.stopping_criterias[idx]
|
||||
stopping_criterias.append(stopping_criteria)
|
||||
remaining_decode_tokens = (
|
||||
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
|
||||
)
|
||||
total_remaining_decode_tokens += remaining_decode_tokens
|
||||
new_padding_right_offset = max(
|
||||
new_padding_right_offset, remaining_decode_tokens
|
||||
)
|
||||
|
||||
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
|
||||
input_ids = self.input_ids[keep_indices]
|
||||
position_ids = self.position_ids[keep_indices]
|
||||
self.attention_mask = self.attention_mask[
|
||||
keep_indices,
|
||||
-(self.padding_right_offset + max_input_length) : (
|
||||
self.attention_mask.shape[1] - self.padding_right_offset
|
||||
)
|
||||
+ new_padding_right_offset,
|
||||
]
|
||||
# Do the same for pixel_values and image_attention_mask
|
||||
pixel_values = self.pixel_values[keep_indices]
|
||||
self.image_attention_mask = self.image_attention_mask[
|
||||
keep_indices,
|
||||
-(self.padding_right_offset + max_input_length) : (
|
||||
self.image_attention_mask.shape[1] - self.padding_right_offset
|
||||
)
|
||||
+ new_padding_right_offset,
|
||||
:,
|
||||
]
|
||||
if self.image_hidden_states is None:
|
||||
image_hidden_states = None
|
||||
else:
|
||||
image_hidden_states = self.image_hidden_states[keep_indices]
|
||||
|
||||
# Ensure that past_key_values tensors can be updated in-place
|
||||
if type(self.past_key_values[0]) is tuple:
|
||||
self.past_key_values = [list(layer) for layer in self.past_key_values]
|
||||
|
||||
# Update tensors in-place to allow incremental garbage collection
|
||||
past_kv_length = max_input_length - 1
|
||||
for layer in self.past_key_values:
|
||||
past_keys, past_values = layer
|
||||
if len(past_keys.shape) == 3:
|
||||
# Force past to be of dim [self_size, num_heads, ...] for easy indexing
|
||||
past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
|
||||
past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
|
||||
if self.keys_head_dim_last:
|
||||
layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]
|
||||
else:
|
||||
layer[0] = past_keys[keep_indices, :, :, -past_kv_length:]
|
||||
del past_keys
|
||||
layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
|
||||
del past_values
|
||||
|
||||
max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
|
||||
|
||||
self.requests = requests
|
||||
self.requests_idx_mapping = requests_idx_mapping
|
||||
self.input_ids = input_ids
|
||||
self.pixel_values = pixel_values
|
||||
self.image_hidden_states = image_hidden_states
|
||||
self.position_ids = position_ids
|
||||
self.all_input_ids = all_input_ids
|
||||
self.input_lengths = input_lengths
|
||||
self.prefix_offsets = prefix_offsets
|
||||
self.read_offsets = read_offsets
|
||||
self.next_token_choosers = next_token_choosers
|
||||
self.stopping_criterias = stopping_criterias
|
||||
self.max_input_length = max_input_length
|
||||
self.padding_right_offset = new_padding_right_offset
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
@tracer.start_as_current_span("concatenate")
|
||||
def concatenate(
|
||||
cls, batches: List["IdeficsCausalLMBatch"]
|
||||
) -> "IdeficsCausalLMBatch":
|
||||
# It adds new requests to the batch
|
||||
# Used for padding
|
||||
total_batch_size = 0
|
||||
max_input_length = 0
|
||||
max_num_images = 0
|
||||
padding_right_offset = 0
|
||||
for batch in batches:
|
||||
total_batch_size += len(batch)
|
||||
max_input_length = max(max_input_length, batch.max_input_length)
|
||||
max_num_images = max(max_num_images, batch.pixel_values.size(1))
|
||||
padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
|
||||
|
||||
# Batch attributes
|
||||
requests = []
|
||||
requests_idx_mapping = {}
|
||||
input_lengths = []
|
||||
prefix_offsets = []
|
||||
read_offsets = []
|
||||
all_input_ids = []
|
||||
next_token_choosers = []
|
||||
stopping_criterias = []
|
||||
max_tokens = 0
|
||||
|
||||
# Batch tensors
|
||||
input_ids = None
|
||||
attention_mask = None
|
||||
position_ids = None
|
||||
pixel_values = None
|
||||
image_hidden_states = None
|
||||
image_attention_mask = None
|
||||
past_key_values = []
|
||||
|
||||
# Used for slicing correctly inside the tensors
|
||||
# Equivalent to a cumsum on batch sizes
|
||||
start_index = 0
|
||||
for i, batch in enumerate(batches):
|
||||
requests.extend(batch.requests)
|
||||
input_lengths.extend(batch.input_lengths)
|
||||
prefix_offsets.extend(batch.prefix_offsets)
|
||||
read_offsets.extend(batch.read_offsets)
|
||||
all_input_ids.extend(batch.all_input_ids)
|
||||
next_token_choosers.extend(batch.next_token_choosers)
|
||||
stopping_criterias.extend(batch.stopping_criterias)
|
||||
|
||||
if i == 0:
|
||||
requests_idx_mapping = batch.requests_idx_mapping
|
||||
else:
|
||||
# We need to offset the mapping for each batch by the cumulative batch size
|
||||
for k, v in batch.requests_idx_mapping.items():
|
||||
requests_idx_mapping[k] = v + start_index
|
||||
|
||||
# Slicing end index for this batch
|
||||
end_index = start_index + len(batch)
|
||||
|
||||
# We only concatenate batches that did at least one step
|
||||
if batch.past_key_values is None:
|
||||
raise ValueError("only concatenate prefilled batches")
|
||||
|
||||
# Create empty tensor
|
||||
# input_ids is always of shape [batch_size, 1]
|
||||
# We do not need to pad it
|
||||
if input_ids is None:
|
||||
input_ids = batch.input_ids.new_empty((total_batch_size, 1))
|
||||
# Copy to correct indices
|
||||
input_ids[start_index:end_index] = batch.input_ids
|
||||
|
||||
# Create padded tensor
|
||||
if attention_mask is None:
|
||||
attention_mask = batch.attention_mask.new_zeros(
|
||||
(total_batch_size, max_input_length + padding_right_offset),
|
||||
)
|
||||
|
||||
curr_batch_max_num_images = batch.pixel_values.size(1)
|
||||
if pixel_values is None:
|
||||
pixel_values = batch.pixel_values.new_zeros(
|
||||
(total_batch_size, max_num_images, 3, 224, 224)
|
||||
)
|
||||
pixel_values[start_index:end_index, :curr_batch_max_num_images] = (
|
||||
batch.pixel_values
|
||||
)
|
||||
|
||||
if image_attention_mask is None:
|
||||
image_attention_mask = batch.image_attention_mask.new_zeros(
|
||||
(
|
||||
total_batch_size,
|
||||
max_input_length + padding_right_offset,
|
||||
max_num_images,
|
||||
)
|
||||
)
|
||||
|
||||
# We need to slice the attention mask to remove padding from previous steps
|
||||
# and to remove unused allocated space
|
||||
left_offset = max_input_length - batch.max_input_length
|
||||
batch_left_offset = (
|
||||
batch.attention_mask.shape[1]
|
||||
- batch.max_input_length
|
||||
- batch.padding_right_offset
|
||||
)
|
||||
attention_mask[
|
||||
start_index:end_index,
|
||||
left_offset:-padding_right_offset,
|
||||
] = batch.attention_mask[
|
||||
:,
|
||||
batch_left_offset : -batch.padding_right_offset,
|
||||
]
|
||||
image_attention_mask[
|
||||
start_index:end_index,
|
||||
left_offset:-padding_right_offset,
|
||||
:curr_batch_max_num_images,
|
||||
] = batch.image_attention_mask[
|
||||
:, batch_left_offset : -batch.padding_right_offset, :
|
||||
]
|
||||
|
||||
# Create empty tensor
|
||||
# position_ids is always of shape [batch_size, 1]
|
||||
if position_ids is None:
|
||||
position_ids = batch.position_ids.new_empty((total_batch_size, 1))
|
||||
position_ids[start_index:end_index] = batch.position_ids
|
||||
|
||||
# Shenanigans to get dimensions because BLOOM outputs a past with a different shape
|
||||
# BLOOM Keys: [batch_size * num_heads, head_dim, seq_length]
|
||||
# BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
|
||||
# And ensure that we can update tensors in-place
|
||||
if isinstance(batch.past_key_values[0], tuple):
|
||||
batch.past_key_values = [
|
||||
[t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
|
||||
for layer in batch.past_key_values
|
||||
]
|
||||
elif len(batch.past_key_values[0][0].shape) == 3:
|
||||
for layer in batch.past_key_values:
|
||||
for k, t in enumerate(layer):
|
||||
layer[k] = t.view(len(batch), -1, *t.shape[-2:])
|
||||
|
||||
# Add eventual padding tokens that were added while concatenating
|
||||
max_tokens += batch.max_tokens + (
|
||||
max_input_length - batch.max_input_length
|
||||
) * len(batch)
|
||||
|
||||
start_index = end_index
|
||||
|
||||
first_past_kvs = batches[0].past_key_values
|
||||
_, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape
|
||||
|
||||
padded_past_values_shape = (
|
||||
total_batch_size,
|
||||
num_heads,
|
||||
max_input_length - 1,
|
||||
head_dim,
|
||||
)
|
||||
|
||||
if batches[0].keys_head_dim_last:
|
||||
padded_past_keys_shape = padded_past_values_shape
|
||||
else:
|
||||
# seq_length is last for BLOOM
|
||||
padded_past_keys_shape = (
|
||||
total_batch_size,
|
||||
num_heads,
|
||||
head_dim,
|
||||
max_input_length - 1,
|
||||
)
|
||||
|
||||
# Iterate over attention layers
|
||||
# Concatenate past key values layer by layer to allow incremental garbage collection
|
||||
for j in range(len(first_past_kvs)):
|
||||
padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape)
|
||||
start_index = 0
|
||||
for batch in batches:
|
||||
past_keys = batch.past_key_values[j][0]
|
||||
# Clear reference to the original tensor
|
||||
batch.past_key_values[j][0] = None
|
||||
|
||||
# Slicing end index for this batch
|
||||
end_index = start_index + len(batch)
|
||||
# We slice the keys to remove the padding from previous batches
|
||||
past_seq_len = batch.max_input_length - 1
|
||||
if batch.keys_head_dim_last:
|
||||
padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
|
||||
past_keys[:, :, -past_seq_len:, :]
|
||||
)
|
||||
else:
|
||||
# BLOOM case
|
||||
padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
|
||||
past_keys[:, :, :, -past_seq_len:]
|
||||
)
|
||||
del past_keys
|
||||
|
||||
start_index = end_index
|
||||
|
||||
padded_past_values = first_past_kvs[j][1].new_zeros(
|
||||
padded_past_values_shape
|
||||
)
|
||||
start_index = 0
|
||||
for batch in batches:
|
||||
past_values = batch.past_key_values[j][1]
|
||||
# Clear reference to the original tensor
|
||||
batch.past_key_values[j][1] = None
|
||||
|
||||
# Slicing end index for this batch
|
||||
end_index = start_index + len(batch)
|
||||
# We slice the past values to remove the padding from previous batches
|
||||
past_seq_len = batch.max_input_length - 1
|
||||
padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
|
||||
past_values[:, :, -past_seq_len:, :]
|
||||
)
|
||||
del past_values
|
||||
|
||||
# Update values
|
||||
start_index = end_index
|
||||
|
||||
past_key_values.append([padded_past_keys, padded_past_values])
|
||||
|
||||
return cls(
|
||||
batch_id=batches[0].batch_id,
|
||||
requests=requests,
|
||||
requests_idx_mapping=requests_idx_mapping,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
pixel_values=pixel_values,
|
||||
image_hidden_states=image_hidden_states,
|
||||
image_attention_mask=image_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
all_input_ids=all_input_ids,
|
||||
input_lengths=input_lengths,
|
||||
prefix_offsets=prefix_offsets,
|
||||
read_offsets=read_offsets,
|
||||
next_token_choosers=next_token_choosers,
|
||||
stopping_criterias=stopping_criterias,
|
||||
max_input_length=max_input_length,
|
||||
padding_right_offset=padding_right_offset,
|
||||
keys_head_dim_last=batches[0].keys_head_dim_last,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.requests)
|
||||
|
||||
|
||||
class IdeficsCausalLM(Model):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.quantize = quantize
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
device = torch.device("hpu")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
self.device, self.dtype = device, dtype
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
config.quantize = quantize
|
||||
config.speculator = speculator
|
||||
config.vision_config.quantize = quantize
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
weights_loader = get_loader(
|
||||
quantize=quantize, model_id=model_id, revision=revision
|
||||
)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(
|
||||
filenames,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
process_group=self.process_group,
|
||||
weights_loader=weights_loader,
|
||||
)
|
||||
|
||||
model = IdeficsForVisionText2Text(config, weights)
|
||||
|
||||
self.config = config
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super().__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
@property
|
||||
def batch_type(self) -> Type[IdeficsCausalLMBatch]:
|
||||
return IdeficsCausalLMBatch
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
pixel_values,
|
||||
image_hidden_states,
|
||||
image_attention_mask,
|
||||
past_key_values: Optional = None,
|
||||
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||
# Model Forward
|
||||
kwargs = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"pixel_values": pixel_values,
|
||||
"image_hidden_states": image_hidden_states,
|
||||
"image_attention_mask": image_attention_mask,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": True,
|
||||
"return_dict": True,
|
||||
}
|
||||
if self.has_position_ids:
|
||||
kwargs["position_ids"] = position_ids
|
||||
|
||||
outputs, speculative_logits = self.model.forward(**kwargs)
|
||||
return (
|
||||
outputs.logits,
|
||||
speculative_logits,
|
||||
outputs.past_key_values,
|
||||
outputs.image_hidden_states,
|
||||
)
|
||||
|
||||
@tracer.start_as_current_span("generate_token")
|
||||
def generate_token(
|
||||
self, batch: IdeficsCausalLMBatch
|
||||
) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch], Tuple[int, int]]:
|
||||
start = time.time_ns()
|
||||
# slice the attention mask to the correct shape
|
||||
attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
|
||||
if batch.image_attention_mask is None:
|
||||
image_attention_mask = None
|
||||
else:
|
||||
if batch.input_ids.size(1) == 1:
|
||||
# THIS is a hack: when calling idefics.generate, the first time, we need the whole image_attention_mask (size bs x max_seq_len x max_num_images),
|
||||
# but the subsequent times, we only need the last attention mask along the `max_seq_len` dimension
|
||||
                # this is due to the nature of IDEFICS: it's an encoder-decoder, and so when decoding, only the currently generated
|
||||
# token need to attend to the encoder hidden states (i.e. the vision encoder)
|
||||
# Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic
|
||||
image_attention_mask = batch.image_attention_mask[
|
||||
:, -(batch.padding_right_offset + 1)
|
||||
].unsqueeze(1)
|
||||
else:
|
||||
image_attention_mask = batch.image_attention_mask[
|
||||
:, : -batch.padding_right_offset
|
||||
]
|
||||
|
||||
logits, speculative_logits, past, image_hidden_states = self.forward(
|
||||
input_ids=batch.input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=batch.position_ids,
|
||||
pixel_values=batch.pixel_values,
|
||||
image_hidden_states=batch.image_hidden_states,
|
||||
image_attention_mask=image_attention_mask,
|
||||
past_key_values=batch.past_key_values,
|
||||
)
|
||||
# Hardcoded remove image tokens
|
||||
logits[:, 32000:32001] = torch.finfo(logits.dtype).min
|
||||
|
||||
start_decode = time.time_ns()
|
||||
|
||||
# Results
|
||||
generations: List[Generation] = []
|
||||
stopped = True
|
||||
|
||||
# Zipped iterator
|
||||
iterator = zip(
|
||||
batch.requests,
|
||||
batch.input_lengths,
|
||||
batch.prefix_offsets,
|
||||
batch.read_offsets,
|
||||
logits,
|
||||
batch.next_token_choosers,
|
||||
batch.stopping_criterias,
|
||||
batch.all_input_ids,
|
||||
)
|
||||
|
||||
# For each member of the batch
|
||||
for i, (
|
||||
request,
|
||||
input_length,
|
||||
prefix_offset,
|
||||
read_offset,
|
||||
logits,
|
||||
next_token_chooser,
|
||||
stopping_criteria,
|
||||
all_input_ids,
|
||||
) in enumerate(iterator):
|
||||
# Select next token
|
||||
next_token_id, logprobs = next_token_chooser(
|
||||
all_input_ids.view(1, -1), logits[-1:, :]
|
||||
)
|
||||
|
||||
# Append next token to all tokens
|
||||
all_input_ids = torch.cat([all_input_ids, next_token_id])
|
||||
new_input_length = input_length + 1
|
||||
|
||||
# Generated token
|
||||
next_token_logprob = logprobs[-1, next_token_id]
|
||||
next_token_id_squeezed = next_token_id.squeeze()
|
||||
next_token_text, prefix_offset, read_offset = self.decode_token(
|
||||
all_input_ids[:, 0], prefix_offset, read_offset
|
||||
)
|
||||
|
||||
# Evaluate stopping criteria
|
||||
stop, reason = stopping_criteria(
|
||||
next_token_id_squeezed,
|
||||
next_token_text,
|
||||
)
|
||||
|
||||
if not stop:
|
||||
stopped = False
|
||||
|
||||
# Shard generations
|
||||
# All generations will be appended in the rust sharded client
|
||||
if i % self.world_size == self.rank:
|
||||
if stop:
|
||||
# Decode generated tokens
|
||||
output_text, _, _ = self.decode_token(
|
||||
all_input_ids[:, 0],
|
||||
prefix_offset=len(all_input_ids)
|
||||
- stopping_criteria.current_tokens
|
||||
- 1,
|
||||
read_offset=len(all_input_ids)
|
||||
- stopping_criteria.current_tokens,
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
# Get seed
|
||||
if isinstance(next_token_chooser.choice, Sampling):
|
||||
seed = next_token_chooser.choice.seed
|
||||
else:
|
||||
seed = None
|
||||
|
||||
generated_text = GeneratedText(
|
||||
output_text, stopping_criteria.current_tokens, reason, seed
|
||||
)
|
||||
else:
|
||||
generated_text = None
|
||||
|
||||
# Prefill
|
||||
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
|
||||
# Remove generated token to only have prefill and add nan for first prompt token
|
||||
prefill_logprobs = [float("nan")] + torch.log_softmax(
|
||||
logits, -1
|
||||
).gather(1, all_input_ids[1:]).squeeze(1)[
|
||||
-new_input_length:-1
|
||||
].tolist()
|
||||
prefill_token_ids = all_input_ids[-new_input_length:-1]
|
||||
prefill_texts = self.tokenizer.batch_decode(
|
||||
prefill_token_ids,
|
||||
clean_up_tokenization_spaces=False,
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
prefill_tokens = Tokens(
|
||||
prefill_token_ids,
|
||||
prefill_logprobs,
|
||||
prefill_texts,
|
||||
is_special=[],
|
||||
)
|
||||
else:
|
||||
prefill_tokens = None
|
||||
|
||||
top_tokens = None
|
||||
|
||||
generation = Generation(
|
||||
request.id,
|
||||
prefill_tokens,
|
||||
Tokens(
|
||||
[next_token_id_squeezed],
|
||||
[next_token_logprob],
|
||||
[next_token_text],
|
||||
[next_token_id_squeezed.item() in self.all_special_ids],
|
||||
),
|
||||
generated_text,
|
||||
top_tokens,
|
||||
)
|
||||
|
||||
generations.append(generation)
|
||||
|
||||
# Update values
|
||||
batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
|
||||
next_token_id_squeezed.item()
|
||||
)
|
||||
batch.input_ids[i, 0] = next_token_id
|
||||
batch.all_input_ids[i] = all_input_ids
|
||||
batch.input_lengths[i] = new_input_length
|
||||
batch.prefix_offsets[i] = prefix_offset
|
||||
batch.read_offsets[i] = read_offset
|
||||
batch.max_input_length = max(batch.max_input_length, new_input_length)
|
||||
|
||||
# We finished all generations in the batch; there is no next batch
|
||||
if stopped:
|
||||
forward_ns = start_decode - start
|
||||
decode_ns = time.time_ns() - start_decode
|
||||
return generations, None, (forward_ns, decode_ns)
|
||||
|
||||
# Slice unused values from prefill
|
||||
batch.input_ids = batch.input_ids[:, :1]
|
||||
|
||||
# Update attention_mask as we added a new token to input_ids
|
||||
batch.attention_mask[:, -batch.padding_right_offset] = 1
|
||||
batch.image_attention_mask[:, -batch.padding_right_offset, :] = (
|
||||
batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :]
|
||||
)
|
||||
# Decrease right offset
|
||||
batch.padding_right_offset -= 1
|
||||
|
||||
# Update position_ids
|
||||
batch.position_ids = batch.position_ids[:, -1:] + 1
|
||||
|
||||
# Update past key values
|
||||
batch.past_key_values = past
|
||||
batch.image_hidden_states = image_hidden_states
|
||||
|
||||
forward_ns = start_decode - start
|
||||
decode_ns = time.time_ns() - start_decode
|
||||
return generations, batch, (forward_ns, decode_ns)
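The image_attention_mask slicing described in the comments above keeps the full mask for the first (prefill) call and only the row of the token being generated afterwards; a toy version:

import torch

bs, seq, n_img = 2, 6, 3
image_attention_mask = torch.ones(bs, seq, n_img)
padding_right_offset = 2

prefill_mask = image_attention_mask[:, :-padding_right_offset]                    # (2, 4, 3)
decode_mask = image_attention_mask[:, -(padding_right_offset + 1)].unsqueeze(1)   # (2, 1, 3)
print(prefill_mask.shape, decode_mask.shape)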
|
@ -1,814 +0,0 @@
|
||||
import torch
|
||||
import torch.distributed
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
||||
from typing import Optional
|
||||
from text_generation_server.models.custom_modeling.mamba_modeling import (
|
||||
MambaConfig,
|
||||
)
|
||||
from loguru import logger
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.models.globals import CUDA_GRAPHS, MEM_POOL
|
||||
import time
|
||||
from text_generation_server.models.custom_modeling.mamba_modeling import (
|
||||
MambaModel,
|
||||
InferenceParams,
|
||||
)
|
||||
from text_generation_server.models import Model
|
||||
from typing import Any, List, Tuple, Type, Dict
|
||||
from text_generation_server.models.types import (
|
||||
Batch,
|
||||
Tokens,
|
||||
Generation,
|
||||
GeneratedText,
|
||||
)
|
||||
from text_generation_server.utils.chunks import concat_text_chunks
|
||||
from text_generation_server.utils.quantization import get_loader
|
||||
from text_generation_server.utils.tokens import batch_top_tokens, Sampling
|
||||
from dataclasses import dataclass
|
||||
from text_generation_server.utils import NextTokenChooser, StoppingCriteria
|
||||
|
||||
|
||||
def new_inference_params(
|
||||
n_blocks: int,
|
||||
batch_size: int,
|
||||
d_inner: int,
|
||||
d_conv: int,
|
||||
d_state: int,
|
||||
seqlen_offset: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
):
|
||||
max_seqlen = 0
|
||||
conv_states = torch.zeros(
|
||||
(
|
||||
n_blocks,
|
||||
batch_size,
|
||||
d_inner,
|
||||
d_conv,
|
||||
),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
ssm_states = torch.zeros(
|
||||
(
|
||||
n_blocks,
|
||||
batch_size,
|
||||
d_inner,
|
||||
d_state,
|
||||
),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
inference_params = InferenceParams(
|
||||
max_seqlen=max_seqlen,
|
||||
max_batch_size=batch_size,
|
||||
seqlen_offset=seqlen_offset,
|
||||
conv_states=conv_states,
|
||||
ssm_states=ssm_states,
|
||||
)
|
||||
return inference_params
|
||||
|
||||
|
||||
@dataclass
|
||||
class MambaBatch(Batch):
|
||||
batch_id: int
|
||||
requests: List[generate_pb2.Request]
|
||||
requests_idx_mapping: Dict[int, int]
|
||||
|
||||
# Decoder values
|
||||
input_ids: torch.Tensor
|
||||
|
||||
# All tokens
|
||||
all_input_ids: List[torch.Tensor]
|
||||
|
||||
# Lengths of all generations present in the batch
|
||||
input_lengths: List[int]
|
||||
prefix_offsets: List[int]
|
||||
read_offsets: List[int]
|
||||
|
||||
# Generation helpers
|
||||
next_token_choosers: List[NextTokenChooser]
|
||||
stopping_criterias: List[StoppingCriteria]
|
||||
top_n_tokens: List[int]
|
||||
top_n_tokens_tensor: torch.Tensor
|
||||
|
||||
# Metadata used for padding
|
||||
max_input_length: int
|
||||
padding_right_offset: int
|
||||
|
||||
# Maximum number of tokens this batch will grow to
|
||||
max_tokens: int
|
||||
|
||||
# Past metadata
|
||||
keys_head_dim_last: bool = True
|
||||
|
||||
# Inference params
|
||||
inference_params: Optional[Dict[str, Any]] = None
|
||||
|
||||
def to_pb(self) -> generate_pb2.CachedBatch:
|
||||
return generate_pb2.CachedBatch(
|
||||
id=self.batch_id,
|
||||
request_ids=[r.id for r in self.requests],
|
||||
size=len(self),
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pb(
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "MambaBatch":
|
||||
inputs = []
|
||||
next_token_choosers = []
|
||||
stopping_criterias = []
|
||||
top_n_tokens = []
|
||||
prefix_offsets = []
|
||||
read_offsets = []
|
||||
requests_idx_mapping = {}
|
||||
|
||||
# Parse batch
|
||||
max_truncation = 0
|
||||
padding_right_offset = 0
|
||||
max_decode_tokens = 0
|
||||
for i, r in enumerate(pb.requests):
|
||||
requests_idx_mapping[r.id] = i
|
||||
inputs.append(concat_text_chunks(r.input_chunks.chunks))
|
||||
next_token_choosers.append(
|
||||
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
|
||||
)
|
||||
stopping_criteria = StoppingCriteria.from_pb(
|
||||
r.stopping_parameters, tokenizer
|
||||
)
|
||||
stopping_criterias.append(stopping_criteria)
|
||||
top_n_tokens.append(r.top_n_tokens)
|
||||
max_truncation = max(max_truncation, r.truncate)
|
||||
max_decode_tokens += stopping_criteria.max_new_tokens
|
||||
padding_right_offset = max(
|
||||
padding_right_offset, stopping_criteria.max_new_tokens
|
||||
)
|
||||
|
||||
tokenized_inputs = tokenizer(
|
||||
inputs,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
return_token_type_ids=False,
|
||||
truncation=True,
|
||||
max_length=max_truncation,
|
||||
).to(device)
|
||||
for _ in pb.requests:
|
||||
input_len = tokenized_inputs["input_ids"].shape[1]
|
||||
prefix_offsets.append(input_len - 5)
|
||||
read_offsets.append(input_len)
|
||||
|
||||
input_lengths = tokenized_inputs["attention_mask"].sum(1)
|
||||
max_input_length = input_lengths.max()
|
||||
input_ids = tokenized_inputs["input_ids"]
|
||||
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
|
||||
top_n_tokens_tensor = torch.tensor(
|
||||
top_n_tokens, device=device, dtype=torch.int64
|
||||
)
|
||||
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
|
||||
return cls(
|
||||
batch_id=pb.id,
|
||||
requests=pb.requests,
|
||||
requests_idx_mapping=requests_idx_mapping,
|
||||
input_ids=input_ids,
|
||||
# past_input_ids=None,
|
||||
all_input_ids=list(all_input_ids),
|
||||
input_lengths=input_lengths.tolist(),
|
||||
prefix_offsets=prefix_offsets,
|
||||
read_offsets=read_offsets,
|
||||
next_token_choosers=next_token_choosers,
|
||||
stopping_criterias=stopping_criterias,
|
||||
top_n_tokens=top_n_tokens,
|
||||
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||
max_input_length=max_input_length.item(),
|
||||
padding_right_offset=padding_right_offset,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
def filter(self, request_ids: List[int]) -> Optional["MambaBatch"]:
|
||||
if len(request_ids) == 0:
|
||||
raise ValueError("Batch must have at least one request")
|
||||
if len(request_ids) == len(self):
|
||||
return self
|
||||
|
||||
keep_indices = []
|
||||
|
||||
# New values after filtering
|
||||
requests_idx_mapping = {}
|
||||
requests = []
|
||||
input_lengths = []
|
||||
prefix_offsets = []
|
||||
read_offsets = []
|
||||
all_input_ids = []
|
||||
max_input_length = 0
|
||||
|
||||
next_token_choosers = []
|
||||
stopping_criterias = []
|
||||
top_n_tokens = []
|
||||
|
||||
total_remaining_decode_tokens = 0
|
||||
new_padding_right_offset = 0
|
||||
|
||||
indices = []
|
||||
for i, request_id in enumerate(request_ids):
|
||||
idx = self.requests_idx_mapping[request_id]
|
||||
requests_idx_mapping[request_id] = i
|
||||
keep_indices.append(idx)
|
||||
|
||||
requests.append(self.requests[idx])
|
||||
prefix_offsets.append(self.prefix_offsets[idx])
|
||||
read_offsets.append(self.read_offsets[idx])
|
||||
all_input_ids.append(self.all_input_ids[idx])
|
||||
|
||||
request_input_length = self.input_lengths[idx]
|
||||
input_lengths.append(request_input_length)
|
||||
max_input_length = max(max_input_length, request_input_length)
|
||||
indices.append(idx)
|
||||
|
||||
next_token_choosers.append(self.next_token_choosers[idx])
|
||||
stopping_criteria = self.stopping_criterias[idx]
|
||||
stopping_criterias.append(stopping_criteria)
|
||||
top_n_tokens.append(self.top_n_tokens[idx])
|
||||
remaining_decode_tokens = (
|
||||
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
|
||||
)
|
||||
total_remaining_decode_tokens += remaining_decode_tokens
|
||||
new_padding_right_offset = max(
|
||||
new_padding_right_offset, remaining_decode_tokens
|
||||
)
|
||||
|
||||
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
|
||||
input_ids = self.input_ids[keep_indices]
|
||||
|
||||
top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
|
||||
max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
|
||||
|
||||
self.requests = requests
|
||||
self.requests_idx_mapping = requests_idx_mapping
|
||||
self.input_ids = input_ids
|
||||
self.all_input_ids = all_input_ids
|
||||
self.input_lengths = input_lengths
|
||||
self.prefix_offsets = prefix_offsets
|
||||
self.read_offsets = read_offsets
|
||||
self.next_token_choosers = next_token_choosers
|
||||
self.stopping_criterias = stopping_criterias
|
||||
self.top_n_tokens = top_n_tokens
|
||||
self.top_n_tokens_tensor = top_n_tokens_tensor
|
||||
self.max_input_length = max_input_length
|
||||
self.padding_right_offset = new_padding_right_offset
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
# TODO
|
||||
# Kept it simple by just updating the state, maybe updating the other CPU values is necessary.
|
||||
self.inference_params.conv_states = self.inference_params.conv_states[
|
||||
:, indices
|
||||
]
|
||||
self.inference_params.ssm_states = self.inference_params.ssm_states[:, indices]
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def concatenate(cls, batches: List["MambaBatch"]) -> "MambaBatch":
|
||||
# Used for padding
|
||||
total_batch_size = 0
|
||||
max_input_length = 0
|
||||
padding_right_offset = 0
|
||||
for batch in batches:
|
||||
total_batch_size += len(batch)
|
||||
max_input_length = max(max_input_length, batch.max_input_length)
|
||||
padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
|
||||
|
||||
# Batch attributes
|
||||
requests = []
|
||||
requests_idx_mapping = {}
|
||||
input_lengths = []
|
||||
prefix_offsets = []
|
||||
read_offsets = []
|
||||
all_input_ids = []
|
||||
next_token_choosers = []
|
||||
stopping_criterias = []
|
||||
top_n_tokens = []
|
||||
max_tokens = 0
|
||||
seqlen_offset = 0
|
||||
|
||||
(n_blocks, _, d_inner, d_conv) = batches[0].inference_params.conv_states.shape
|
||||
(_, _, _, d_state) = batches[0].inference_params.ssm_states.shape
|
||||
dtype = batches[0].inference_params.conv_states.dtype
|
||||
device = batches[0].inference_params.conv_states.device
|
||||
inference_params = new_inference_params(
|
||||
n_blocks=n_blocks,
|
||||
batch_size=total_batch_size,
|
||||
d_state=d_state,
|
||||
d_conv=d_conv,
|
||||
d_inner=d_inner,
|
||||
seqlen_offset=seqlen_offset,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Batch tensors
|
||||
input_ids = None
|
||||
top_n_tokens_tensor = None
|
||||
|
||||
# Used for slicing correctly inside the tensors
|
||||
# Equivalent to a cumsum on batch sizes
|
||||
start_index = 0
|
||||
for i, batch in enumerate(batches):
|
||||
requests.extend(batch.requests)
|
||||
input_lengths.extend(batch.input_lengths)
|
||||
prefix_offsets.extend(batch.prefix_offsets)
|
||||
read_offsets.extend(batch.read_offsets)
|
||||
all_input_ids.extend(batch.all_input_ids)
|
||||
next_token_choosers.extend(batch.next_token_choosers)
|
||||
stopping_criterias.extend(batch.stopping_criterias)
|
||||
top_n_tokens.extend(batch.top_n_tokens)
|
||||
|
||||
if i == 0:
|
||||
requests_idx_mapping = batch.requests_idx_mapping
|
||||
else:
|
||||
# We need to offset the mapping for each batch by the cumulative batch size
|
||||
for k, v in batch.requests_idx_mapping.items():
|
||||
requests_idx_mapping[k] = v + start_index
|
||||
|
||||
# Slicing end index for this batch
|
||||
end_index = start_index + len(batch)
|
||||
|
||||
# Create empty tensor
|
||||
# input_ids is always of shape [batch_size, 1]
|
||||
# We do not need to pad it
|
||||
if input_ids is None:
|
||||
input_ids = batch.input_ids.new_empty((total_batch_size, 1))
|
||||
# Copy to correct indices
|
||||
input_ids[start_index:end_index] = batch.input_ids
|
||||
|
||||
if top_n_tokens_tensor is None:
|
||||
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
|
||||
total_batch_size,
|
||||
)
|
||||
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
|
||||
|
||||
# Add eventual padding tokens that were added while concatenating
|
||||
max_tokens += batch.max_tokens + (
|
||||
max_input_length - batch.max_input_length
|
||||
) * len(batch)
|
||||
|
||||
inference_params.max_seqlen = max(
|
||||
inference_params.max_seqlen, batch.inference_params.max_seqlen
|
||||
)
|
||||
assert batch.inference_params.seqlen_offset != 0, "Invalid seqlen offset"
|
||||
inference_params.seqlen_offset = max(
|
||||
inference_params.seqlen_offset, batch.inference_params.seqlen_offset
|
||||
)
|
||||
|
||||
inference_params.conv_states[:, start_index:end_index] = (
|
||||
batch.inference_params.conv_states
|
||||
)
|
||||
inference_params.ssm_states[:, start_index:end_index] = (
|
||||
batch.inference_params.ssm_states
|
||||
)
|
||||
|
||||
start_index = end_index
|
||||
|
||||
return cls(
|
||||
batch_id=batches[0].batch_id,
|
||||
requests=requests,
|
||||
requests_idx_mapping=requests_idx_mapping,
|
||||
input_ids=input_ids,
|
||||
all_input_ids=all_input_ids,
|
||||
input_lengths=input_lengths,
|
||||
prefix_offsets=prefix_offsets,
|
||||
read_offsets=read_offsets,
|
||||
next_token_choosers=next_token_choosers,
|
||||
stopping_criterias=stopping_criterias,
|
||||
top_n_tokens=top_n_tokens,
|
||||
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||
max_input_length=max_input_length,
|
||||
padding_right_offset=padding_right_offset,
|
||||
keys_head_dim_last=batches[0].keys_head_dim_last,
|
||||
max_tokens=max_tokens,
|
||||
inference_params=inference_params,
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.requests)
|
||||
|
||||
|
||||
class Mamba(Model):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.quantize = quantize
|
||||
self.process_group, _rank, world_size = initialize_torch_distributed()
|
||||
if world_size > 1:
|
||||
raise RuntimeError("Mamba does not support Tensor Parallelism (TP)")
|
||||
self.cuda_graphs = {}
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
# Bf16 is important. In f16 accumulations in the matmul are causing
|
||||
# differences while the server is under load.
|
||||
# This is detectable by the integration load test
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
if quantize:
|
||||
raise ValueError("quantization is not available on CPU")
|
||||
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"EleutherAI/gpt-neox-20b",
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
config = MambaConfig.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
|
||||
tokenizer.bos_token_id = config.bos_token_id
|
||||
tokenizer.eos_token_id = config.eos_token_id
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
config.quantize = quantize
|
||||
config.speculator = speculator
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
weights_loader = get_loader(
|
||||
quantize=quantize, model_id=model_id, revision=revision
|
||||
)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(
|
||||
filenames,
|
||||
device,
|
||||
dtype,
|
||||
process_group=self.process_group,
|
||||
weights_loader=weights_loader,
|
||||
)
|
||||
model = MambaModel(config, weights)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(Mamba, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
@property
|
||||
def batch_type(self) -> Type[MambaBatch]:
|
||||
return MambaBatch
|
||||
|
||||
def warmup(self, batch) -> Optional[int]:
|
||||
# TODO: implement warmup for Mamba if needed
|
||||
if CUDA_GRAPHS:
|
||||
if self.speculate is None or self.speculate == 0:
|
||||
try:
|
||||
logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
|
||||
# Warmup cuda graphs
|
||||
for bs in CUDA_GRAPHS:
|
||||
self.cuda_graph_warmup(bs)
|
||||
except Exception:
|
||||
logger.exception("Decode cuda graph warmup failed")
|
||||
else:
|
||||
logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
|
||||
|
||||
return None
|
||||
|
||||
def cuda_graph_warmup(self, batch_size: int):
|
||||
input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device)
|
||||
n_blocks = len(self.model.blocks)
|
||||
|
||||
d_state = self.model.config.d_state
|
||||
d_conv = self.model.config.d_conv
|
||||
# Inner takes the expand multiplication
|
||||
d_inner = self.model.config.d_inner
|
||||
|
||||
# Important seqlen_offset to go through the update mechanism with the state
|
||||
seqlen_offset = 1
|
||||
inference_params = new_inference_params(
|
||||
n_blocks=n_blocks,
|
||||
batch_size=batch_size,
|
||||
d_state=d_state,
|
||||
d_conv=d_conv,
|
||||
d_inner=d_inner,
|
||||
seqlen_offset=seqlen_offset,
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
|
||||
torch.cuda.synchronize()
|
||||
# Run once outside to warmup
|
||||
self.model.forward(input_ids=input_ids, inference_params=inference_params)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
with torch.cuda.graph(graph, pool=MEM_POOL):
|
||||
logits, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids, inference_params=inference_params
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
graph_dict = {
|
||||
"input_ids": input_ids,
|
||||
"inference_params": inference_params,
|
||||
"graph": graph,
|
||||
"logits": logits,
|
||||
"speculative_logits": speculative_logits,
|
||||
}
|
||||
self.cuda_graphs[batch_size] = graph_dict
|
||||
|
||||
def tunableop_warmup(self, batch_size: int, seqlen: int):
|
||||
input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device)
|
||||
n_blocks = len(self.model.blocks)
|
||||
|
||||
d_state = self.model.config.d_state
|
||||
d_conv = self.model.config.d_conv
|
||||
# Inner takes the expand multiplication
|
||||
d_inner = self.model.config.d_inner
|
||||
|
||||
# Important seqlen_offset to go through the update mechanism with the state
|
||||
seqlen_offset = 1
|
||||
inference_params = new_inference_params(
|
||||
n_blocks=n_blocks,
|
||||
batch_size=seqlen,
|
||||
d_state=d_state,
|
||||
d_conv=d_conv,
|
||||
d_inner=d_inner,
|
||||
seqlen_offset=seqlen_offset,
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
self.model.forward(input_ids=input_ids, inference_params=inference_params)
|
||||
|
||||
def forward(
|
||||
self, input_ids: torch.Tensor, inference_params: Any
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
bs = input_ids.shape[0]
|
||||
padded_bs = bs
|
||||
if bs == 3:
|
||||
padded_bs = 4
|
||||
elif 3 < bs <= 8:
|
||||
padded_bs = 8
|
||||
elif bs > 8:
|
||||
padded_bs = (bs + 7) // 8 * 8
|
||||
|
||||
# Try to find an associated cuda graph
|
||||
cuda_graph = self.cuda_graphs.get(padded_bs, None)
|
||||
is_prefill = inference_params is None or inference_params.seqlen_offset == 0
|
||||
|
||||
if is_prefill or cuda_graph is None:
|
||||
return self.model(
|
||||
input_ids,
|
||||
inference_params=inference_params,
|
||||
)
|
||||
|
||||
# Copy inputs to the static inputs of the cuda graph
|
||||
# Static inputs are potentially padded
|
||||
cuda_graph["input_ids"][:bs] = input_ids
|
||||
cuda_graph["inference_params"].conv_states[
|
||||
:, :bs
|
||||
] = inference_params.conv_states
|
||||
cuda_graph["inference_params"].ssm_states[:, :bs] = inference_params.ssm_states
|
||||
|
||||
# Replay the graph
|
||||
cuda_graph["graph"].replay()
|
||||
|
||||
inference_params.conv_states.copy_(
|
||||
cuda_graph["inference_params"].conv_states[:, :bs]
|
||||
)
|
||||
inference_params.ssm_states.copy_(
|
||||
cuda_graph["inference_params"].ssm_states[:, :bs]
|
||||
)
|
||||
# Slice output to the correct shape
|
||||
speculative_logits = (
|
||||
cuda_graph["speculative_logits"][:bs]
|
||||
if cuda_graph["speculative_logits"] is not None
|
||||
else None
|
||||
)
|
||||
logits = cuda_graph["logits"][:bs]
|
||||
return logits, speculative_logits
|
||||
|
||||
def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
|
||||
start = time.time_ns()
|
||||
input_ids = (
|
||||
batch.input_ids
|
||||
) # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids
|
||||
|
||||
batch_size, max_seqlen = input_ids.shape
|
||||
# Inference params
|
||||
|
||||
if batch.inference_params is None:
|
||||
# 0 is important here
|
||||
seqlen_offset = 0
|
||||
n_blocks = len(self.model.blocks)
|
||||
d_state = self.model.config.d_state
|
||||
d_conv = self.model.config.d_conv
|
||||
d_inner = self.model.config.d_inner
|
||||
inference_params = new_inference_params(
|
||||
n_blocks=n_blocks,
|
||||
batch_size=batch_size,
|
||||
d_state=d_state,
|
||||
d_conv=d_conv,
|
||||
d_inner=d_inner,
|
||||
seqlen_offset=seqlen_offset,
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
batch.inference_params = inference_params
|
||||
|
||||
# Forward pass
|
||||
logits, speculative_logits = self.forward(
|
||||
input_ids, inference_params=batch.inference_params
|
||||
)
|
||||
|
||||
# batch.inference_params = new_inference_params
|
||||
# Results
|
||||
generations: List[Generation] = []
|
||||
stopped = True
|
||||
|
||||
# Speculation is not active for causal
|
||||
accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
|
||||
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
|
||||
batch.top_n_tokens,
|
||||
batch.top_n_tokens_tensor,
|
||||
torch.log_softmax(logits[:, -1], -1),
|
||||
accepted_ids,
|
||||
)
|
||||
|
||||
start_decode = time.time_ns()
|
||||
|
||||
# Zipped iterator
|
||||
iterator = zip(
|
||||
batch.requests,
|
||||
batch.input_lengths,
|
||||
batch.prefix_offsets,
|
||||
batch.read_offsets,
|
||||
logits,
|
||||
batch.next_token_choosers,
|
||||
batch.stopping_criterias,
|
||||
batch.all_input_ids,
|
||||
batch.top_n_tokens,
|
||||
batch_top_token_ids,
|
||||
batch_top_token_logprobs,
|
||||
)
|
||||
|
||||
# For each member of the batch
|
||||
for i, (
|
||||
request,
|
||||
input_length,
|
||||
prefix_offset,
|
||||
read_offset,
|
||||
logits,
|
||||
next_token_chooser,
|
||||
stopping_criteria,
|
||||
all_input_ids,
|
||||
top_n_tokens,
|
||||
top_token_ids,
|
||||
top_token_logprobs,
|
||||
) in enumerate(iterator):
|
||||
# Select next token
|
||||
next_token_id, logprobs = next_token_chooser(
|
||||
all_input_ids.view(1, -1), logits[-1:, :]
|
||||
)
|
||||
|
||||
# Append next token to all tokens
|
||||
all_input_ids = torch.cat([all_input_ids, next_token_id])
|
||||
new_input_length = input_length + 1
|
||||
|
||||
# Generated token
|
||||
next_token_logprob = logprobs[-1, next_token_id]
|
||||
next_token_id_squeezed = next_token_id.squeeze()
|
||||
next_token_text, prefix_offset, read_offset = self.decode_token(
|
||||
all_input_ids[:, 0], prefix_offset, read_offset
|
||||
)
|
||||
|
||||
# Evaluate stopping criteria
|
||||
stop, reason = stopping_criteria(
|
||||
next_token_id_squeezed,
|
||||
next_token_text,
|
||||
)
|
||||
|
||||
if not stop:
|
||||
stopped = False
|
||||
|
||||
# Shard generations
|
||||
# All generations will be appended in the rust sharded client
|
||||
if i % self.world_size == self.rank:
|
||||
if stop:
|
||||
# Decode generated tokens
|
||||
output_text, _, _ = self.decode_token(
|
||||
all_input_ids[:, 0],
|
||||
prefix_offset=len(all_input_ids)
|
||||
- stopping_criteria.current_tokens
|
||||
- 1,
|
||||
read_offset=len(all_input_ids)
|
||||
- stopping_criteria.current_tokens,
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
# Get seed
|
||||
if isinstance(next_token_chooser.choice, Sampling):
|
||||
seed = next_token_chooser.choice.seed
|
||||
else:
|
||||
seed = None
|
||||
|
||||
generated_text = GeneratedText(
|
||||
output_text, stopping_criteria.current_tokens, reason, seed
|
||||
)
|
||||
else:
|
||||
generated_text = None
|
||||
|
||||
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
|
||||
# Remove generated token to only have prefill and add nan for first prompt token
|
||||
prefill_logprobs = [float("nan")] + torch.log_softmax(
|
||||
logits, -1
|
||||
).gather(1, all_input_ids[1:]).squeeze(1)[
|
||||
-new_input_length:-1
|
||||
].tolist()
|
||||
prefill_token_ids = all_input_ids[-new_input_length:-1]
|
||||
prefill_texts = self.tokenizer.batch_decode(
|
||||
prefill_token_ids,
|
||||
clean_up_tokenization_spaces=False,
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
prefill_tokens = Tokens(
|
||||
prefill_token_ids,
|
||||
prefill_logprobs,
|
||||
prefill_texts,
|
||||
is_special=[],
|
||||
)
|
||||
else:
|
||||
prefill_tokens = None
|
||||
|
||||
if top_n_tokens > 0:
|
||||
toptoken_texts = self.tokenizer.batch_decode(
|
||||
top_token_ids,
|
||||
clean_up_tokenization_spaces=False,
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
special_toptokens = [
|
||||
token_id in self.all_special_ids for token_id in top_token_ids
|
||||
]
|
||||
top_tokens = Tokens(
|
||||
top_token_ids,
|
||||
top_token_logprobs,
|
||||
toptoken_texts,
|
||||
special_toptokens,
|
||||
)
|
||||
else:
|
||||
top_tokens = None
|
||||
|
||||
generation = Generation(
|
||||
request.id,
|
||||
prefill_tokens,
|
||||
Tokens(
|
||||
[next_token_id_squeezed],
|
||||
[next_token_logprob],
|
||||
[next_token_text],
|
||||
[next_token_id_squeezed.item() in self.all_special_ids],
|
||||
),
|
||||
generated_text,
|
||||
top_tokens,
|
||||
)
|
||||
|
||||
generations.append(generation)
|
||||
|
||||
# Update values
|
||||
batch.next_token_choosers[i] = batch.next_token_choosers[
|
||||
i
|
||||
].advance_grammar(next_token_id_squeezed.item())
|
||||
batch.input_ids[i, 0] = next_token_id
|
||||
batch.all_input_ids[i] = all_input_ids
|
||||
batch.input_lengths[i] = new_input_length
|
||||
batch.prefix_offsets[i] = prefix_offset
|
||||
batch.read_offsets[i] = read_offset
|
||||
batch.max_input_length = max(batch.max_input_length, new_input_length)
|
||||
|
||||
# We finished all generations in the batch; there is no next batch
|
||||
if stopped:
|
||||
forward_ns = start_decode - start
|
||||
decode_ns = time.time_ns() - start_decode
|
||||
return generations, None, (forward_ns, decode_ns)
|
||||
|
||||
# Slice unused values from prefill
|
||||
batch.input_ids = batch.input_ids[:, :1]
|
||||
|
||||
forward_ns = start_decode - start
|
||||
decode_ns = time.time_ns() - start_decode
|
||||
return generations, batch, (forward_ns, decode_ns)
|
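Note on the deleted Mamba backend above: its forward pass padded the runtime batch size to a small set of bucket sizes before looking up a captured CUDA graph. A minimal, self-contained sketch of that bucketing rule, assuming nothing beyond the logic shown in the diff (the helper name `padded_batch_size` is ours, not part of the deleted file):

```python
def padded_batch_size(bs: int) -> int:
    # Mirror the bucketing in Mamba.forward: batch sizes 1 and 2 keep their
    # own graphs, 3 is rounded up to 4, anything up to 8 uses the size-8
    # graph, and larger batches are rounded up to the next multiple of 8.
    if bs == 3:
        return 4
    if 3 < bs <= 8:
        return 8
    if bs > 8:
        return (bs + 7) // 8 * 8
    return bs


# Quick self-check of the bucketing behaviour.
assert [padded_batch_size(b) for b in (1, 2, 3, 5, 8, 9, 17)] == [1, 2, 4, 8, 8, 16, 24]
```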
@ -46,10 +46,17 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
aspect_ratio_mask: Optional[torch.Tensor] = None
cross_attention_states: Optional[torch.Tensor] = None

def prepare_for_prefill(
self, max_padded_input_len, max_padded_bs, max_total_tokens
):
super(FlashVlmCausalLMBatch, self).prepare_for_prefill(
max_padded_input_len, max_padded_bs, max_total_tokens
)

@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches):
batch = super().concatenate(batches)
def concatenate(cls, batches, padded_total_bs: int = 0):
batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches, padded_total_bs)
batch.pixel_values = None
batch.pixel_attention_mask = None
@ -73,7 +80,7 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]):
assert self.image_indices is not None
batch = super().filter(request_ids)
batch = super(FlashVlmCausalLMBatch, self).filter(request_ids)
assert self.image_indices is not None
indices = []
for i, request_id in enumerate(request_ids):
@ -99,6 +106,7 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
]
else:
batch.cross_attention_states = None
batch.pixel_values = None
return batch

@classmethod
@ -228,6 +236,10 @@ def generate_cross_attention_states(


class FlashMllamaCausalLM(FlashVlmCausalLM):
def set_inputs_embeds(self, batch):
# Set the input embeddings to None, as we are using the input_ids for the model
batch.inputs_embeds = None

def warmup_decode(
self, batch_size: int, block_num: int, batch: FlashMllamaCausalLMBatch
):
@ -1,71 +0,0 @@
from io import BytesIO
from PIL import Image
import torch
import torch.distributed
from opentelemetry import trace
from typing import Iterable
from text_generation_server.models.flash_vlm_causal_lm import (
FlashVlmCausalLMBatch,
image_text_replacement,
)

from text_generation_server.pb.generate_pb2 import Request

tracer = trace.get_tracer(__name__)


class PaliGemmaBatch(FlashVlmCausalLMBatch):
@classmethod
def batch_tokenized_inputs(
cls, requests: Iterable[Request], tokenizer, processor, config
):
batch_inputs = []
image_inputs = []
max_truncation = 0
for r in requests:
full_text = ""
image_id = 0
for chunk in r.input_chunks.chunks:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
full_text += "<bos>" + chunk.text + "\n"
elif chunk_type == "image":
image = Image.open(BytesIO(chunk.image.data))
# TODO do_convert_RGB should be on by default ?
image = image.convert("RGB")
image_input = processor.image_processor(image, return_tensors="pt")
full_text += image_text_replacement(
processor, image_input, config, image_id
)
image_inputs.append(image_input)
else:
raise RuntimeError(f"Invalid chunk type {chunk_type}")

batch_inputs.append(full_text)
max_truncation = max(max_truncation, r.truncate)

batch_tokenized_inputs = tokenizer(
batch_inputs,
truncation=True,
max_length=max_truncation,
add_special_tokens=False,
)["input_ids"]
if image_inputs:
image_input = image_inputs[0]
new_image_inputs = {
"pixel_values": torch.cat(
[img["pixel_values"] for img in image_inputs], dim=0
),
}
if "pixel_attention_mask" in image_input:
new_image_inputs["pixel_attention_mask"] = torch.cat(
[img["pixel_attention_mask"] for img in image_inputs], dim=0
)
if "image_sizes" in image_input:
new_image_inputs["image_sizes"] = torch.cat(
[img["image_sizes"] for img in image_inputs], dim=0
)
image_inputs = new_image_inputs
else:
image_inputs = None
return batch_tokenized_inputs, image_inputs
@ -1,47 +0,0 @@
import torch
from dataclasses import dataclass
from typing import List, Optional, Type

from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch


@dataclass
class StarCoderCausalLMBatch(CausalLMBatch):
past_key_values: Optional[List[torch.Tensor]]

def detach_kv_cache(self):
past_keys = []
past_values = []
last_dim = int(self.past_key_values[0].size(dim=-1) / 2)
for key_value in self.past_key_values:
past_keys.append(key_value.split((last_dim, last_dim), dim=-1)[0])
past_values.append(key_value.split((last_dim, last_dim), dim=-1)[1])
del self.past_key_values

return past_keys, past_values

def attach_kv_cache(self, past_keys, past_values):
self.past_key_values = [
torch.cat((key, value), dim=-1)
for key, value in zip(past_keys, past_values)
]


class StarCoder(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
):

super(StarCoder, self).__init__(
model_id=model_id,
revision=revision,
dtype=dtype,
)

@property
def batch_type(self) -> Type[CausalLMBatch]:
return StarCoderCausalLMBatch
File diff suppressed because it is too large
@ -7,13 +7,5 @@ if [[ "$*" == *"--sharded true"* ]]; then
echo 'setting PT_HPU_ENABLE_LAZY_COLLECTIVES=1 for sharding'
export PT_HPU_ENABLE_LAZY_COLLECTIVES=1
fi
# Check if ATTENTION environment variable is set to paged
if [[ "$ATTENTION" == "paged" ]]; then
# Check if Llama-4 is in the command line arguments
if [[ "$*" == *"Llama-4"* || "$*" == *"Qwen3"* ]]; then
echo 'ATTENTION=paged and Llama-4 or Qwen3 detected'
pip install git+https://github.com/huggingface/transformers.git@29338949
fi
fi

text-generation-launcher $@
@ -7,7 +7,8 @@ from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
||||
from optimum.neuron.configuration_utils import NeuronConfig
|
||||
from transformers.generation import GenerationConfig
|
||||
|
||||
from optimum.neuron import NeuronModelForCausalLM
|
||||
@ -175,6 +176,12 @@ class Slot:
|
||||
self._generation_config.top_p = request.parameters.top_p
|
||||
if request.parameters.typical_p != 0:
|
||||
self._generation_config.typical_p = request.parameters.typical_p
|
||||
else:
|
||||
# Set the sampling parameters to emulate greedy decoding when using on-device sampling
|
||||
self._generation_config.temperature = 1.0
|
||||
self._generation_config.top_k = 1
|
||||
self._generation_config.top_p = 1.0
|
||||
self._generation_config.typical_p = 1.0
|
||||
if request.parameters.repetition_penalty != 0:
|
||||
self._generation_config.repetition_penalty = (
|
||||
request.parameters.repetition_penalty
|
||||
@ -211,19 +218,11 @@ class Slot:
|
||||
self._mask = attention_mask.clone()
|
||||
self._selector = selector
|
||||
|
||||
def pause(self, reset_on_pause: bool):
|
||||
def pause(self):
|
||||
"""Mark the current slot as paused for generation.
|
||||
|
||||
Note that the KV cache for this slot will still be filled.
|
||||
"""
|
||||
if reset_on_pause:
|
||||
# Drop the last token as it will be added back when resuming the slot
|
||||
self._generated_tokens -= 1
|
||||
# Since generated tokens are now part of the prefill, we need to reevaluate
|
||||
# max_new_tokens for the next generation
|
||||
self._generation_config.max_new_tokens = (
|
||||
self._max_new_tokens - self._generated_tokens
|
||||
)
|
||||
self._state = Slot.State.PAUSE
|
||||
|
||||
def resume(self):
|
||||
@ -340,16 +339,27 @@ class NeuronGenerator(Generator):
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
):
|
||||
self.model = model
|
||||
self.rebuild_cache_on_prefill = not self.model.continuous_batching
|
||||
if not isinstance(self.model, NeuronModelForCausalLM):
|
||||
raise ValueError("The model must be a NeuronModelForCausalLM.")
|
||||
if not model.neuron_config.continuous_batching:
|
||||
raise ValueError(
|
||||
"The neuron model must be compiled with continuous_batching=True."
|
||||
)
|
||||
# Specify padding and truncation options for decoder-only architecture
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
tokenizer.padding_side = "left"
|
||||
tokenizer.truncation_side = "left"
|
||||
self.tokenizer = tokenizer
|
||||
self.special_tokens = self.tokenizer.all_special_ids
|
||||
self.slots = [Slot(i, tokenizer) for i in range(self.model.batch_size)]
|
||||
self.slots = [
|
||||
Slot(i, tokenizer) for i in range(self.model.neuron_config.batch_size)
|
||||
]
|
||||
self.batch_id = 0
|
||||
|
||||
@property
|
||||
def on_device_sampling(self) -> bool:
|
||||
return getattr(self.model.neuron_config, "on_device_sampling", False)
|
||||
|
||||
@property
|
||||
def info(self) -> InfoResponse:
|
||||
"""Returns the expected InfoResponse."""
|
||||
@ -371,14 +381,22 @@ class NeuronGenerator(Generator):
|
||||
The maximum number of tokens the model supports.
|
||||
"""
|
||||
# Just check that the warmup request parameters match the model capacity
|
||||
batch_size = self.model.batch_size
|
||||
batch_size = self.model.neuron_config.batch_size
|
||||
if len(batch.requests) > batch_size:
|
||||
raise ValueError(
|
||||
f"Inconsistent batch_size configuration: Please make sure the batch_size in the compiled model (currently {batch_size}) matches the batch_size passed to TGI. The compiled model batch_size is usually in the neuron section of the model config.json file. You may also have passed it into optimum-cli during the compilation process. The batch size for TGI is usually set in the environment as MAX_BATCH_SIZE."
|
||||
f"Inconsistent batch_size configuration: Please make sure the batch_size in the compiled model (currently {batch_size}) matches the batch_size passed to TGI. The compiled model.neuron_config.batch_size is usually in the neuron section of the model config.json file. You may also have passed it into optimum-cli during the compilation process. The batch size for TGI is usually set in the environment as MAX_BATCH_SIZE."
|
||||
)
|
||||
self.prefill(batch)
|
||||
self.clear()
|
||||
return self.model.batch_size * self.model.max_length
|
||||
return (
|
||||
self.model.neuron_config.batch_size
|
||||
* self.model.neuron_config.sequence_length
|
||||
)
|
||||
|
||||
def max_prefill_length(self) -> int:
|
||||
if hasattr(self.model.neuron_config, "max_context_length"):
|
||||
return self.model.neuron_config.max_context_length
|
||||
return self.model.neuron_config.sequence_length
|
||||
|
||||
def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
|
||||
"""Prefill new requests.
|
||||
@ -398,7 +416,7 @@ class NeuronGenerator(Generator):
|
||||
if len(empty_slots) < len(batch.requests):
|
||||
raise ValueError(
|
||||
f"Cannot prefill {len(batch.requests)} new request(s) with only {len(empty_slots)} empty slots."
|
||||
f" Please align max_batch_size with the static batch size: {self.model.batch_size}."
|
||||
f" Please align max_batch_size with the static batch size: {self.model.neuron_config.batch_size}."
|
||||
)
|
||||
# Assign each request to an empty slot
|
||||
logger.debug(
|
||||
@ -412,14 +430,8 @@ class NeuronGenerator(Generator):
|
||||
logger.debug(
|
||||
f"Request {slot.request_id} assigned to slot {slot.id} with and max_new_tokens {slot.max_new_tokens}"
|
||||
)
|
||||
if self.rebuild_cache_on_prefill:
|
||||
# We will clear pending slots and prefill all slots
|
||||
prefill_slots = self.slots
|
||||
seq_ids = None
|
||||
else:
|
||||
# We only need to pass inputs for the new requests
|
||||
prefill_slots = new_slots
|
||||
seq_ids = torch.tensor([slot.id for slot in prefill_slots])
|
||||
prefill_slots = new_slots
|
||||
seq_ids = torch.tensor([slot.id for slot in prefill_slots])
|
||||
# Reconstruct the full inputs (without padding) as seen by the model.
|
||||
# This comprises:
|
||||
# - the inputs for new requests,
|
||||
@ -431,8 +443,10 @@ class NeuronGenerator(Generator):
|
||||
inputs.append(slot.cached_text)
|
||||
# Apply truncation, making sure we fit into static dimensions
|
||||
if slot.truncate == 0:
|
||||
max_length = self.model.max_length
|
||||
elif slot.truncate > max_length and slot.truncate < self.model.max_length:
|
||||
max_length = self.max_prefill_length()
|
||||
elif (
|
||||
slot.truncate > max_length and slot.truncate < self.max_prefill_length()
|
||||
):
|
||||
max_length = slot.truncate
|
||||
# Tokenize with padding and truncation
|
||||
padded_inputs = self.tokenizer(
|
||||
@ -444,13 +458,12 @@ class NeuronGenerator(Generator):
|
||||
)
|
||||
input_ids = padded_inputs.input_ids
|
||||
attention_mask = padded_inputs.attention_mask
|
||||
sampling_params = (
|
||||
torch.zeros(input_ids.shape[0], 3) if self.on_device_sampling else None
|
||||
)
|
||||
# Pause previously active slots during generation
|
||||
next_tokens = []
|
||||
for slot in active_slots:
|
||||
slot.pause(reset_on_pause=self.rebuild_cache_on_prefill)
|
||||
if self.rebuild_cache_on_prefill:
|
||||
# The slot will be reset, so we need to store its next token
|
||||
next_tokens.append(slot.next_token)
|
||||
slot.pause()
|
||||
# Each slot must be reset with the padded inputs and masks
|
||||
for i, slot in enumerate(prefill_slots):
|
||||
if slot.state != slot.state.EMPTY:
|
||||
@ -464,29 +477,33 @@ class NeuronGenerator(Generator):
|
||||
slot_input_ids,
|
||||
slot.generation_config,
|
||||
self.model,
|
||||
self.model.max_length,
|
||||
self.model.neuron_config.sequence_length,
|
||||
tokenizer=self.tokenizer,
|
||||
seed=slot.seed,
|
||||
)
|
||||
slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64)
|
||||
slot_attention_mask = attention_mask[i]
|
||||
slot.reset(slot_input_ids, slot_attention_mask, selector)
|
||||
if sampling_params is not None:
|
||||
sampling_params[i, 0] = slot.generation_config.top_k
|
||||
sampling_params[i, 1] = slot.generation_config.top_p
|
||||
sampling_params[i, 2] = slot.generation_config.temperature
|
||||
# Note: when rebuilding cache on prefill, the new tokens on paused slots will be ignored,
|
||||
# as they have already been generated and sent back in the last decode.
|
||||
model_inputs = self.model.prepare_inputs_for_prefill(
|
||||
input_ids, attention_mask, seq_ids
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
seq_ids=seq_ids,
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
logits = self.model(**model_inputs)[0]
|
||||
tokens_or_logits = self.model(**model_inputs)[0]
|
||||
generation, next_batch = self._generate_token(
|
||||
prefill_slots, self.batch_id, logits, input_ids
|
||||
prefill_slots, self.batch_id, tokens_or_logits, input_ids
|
||||
)
|
||||
self.batch_id += 1
|
||||
# Reactivate previously active slots for the next decode
|
||||
for i, slot in enumerate(active_slots):
|
||||
slot.resume()
|
||||
if self.rebuild_cache_on_prefill:
|
||||
# Append back the next token
|
||||
slot.append(next_tokens[i])
|
||||
logger.debug("Model ready for decoding")
|
||||
if next_batch is not None:
|
||||
logger.debug(
|
||||
@ -530,12 +547,8 @@ class NeuronGenerator(Generator):
|
||||
raise ValueError(
|
||||
"Unable to decode tokens for non-prefilled batches (probably due to a previous failure)"
|
||||
)
|
||||
if self.model.continuous_batching:
|
||||
decode_slots = active_slots
|
||||
seq_ids = torch.tensor([slot.id for slot in decode_slots])
|
||||
else:
|
||||
decode_slots = self.slots
|
||||
seq_ids = None
|
||||
decode_slots = active_slots
|
||||
seq_ids = torch.tensor([slot.id for slot in decode_slots])
|
||||
# Reconstruct input_ids and attention_mask from decode slots
|
||||
n_slots = len(decode_slots)
|
||||
input_ids = torch.full(
|
||||
@ -545,22 +558,32 @@ class NeuronGenerator(Generator):
|
||||
for slot in decode_slots:
|
||||
max_length = max(max_length, slot.attention_mask.size(-1))
|
||||
attention_mask = torch.zeros([n_slots, max_length], dtype=torch.int64)
|
||||
sampling_params = torch.zeros(n_slots, 3) if self.on_device_sampling else None
|
||||
for i, slot in enumerate(decode_slots):
|
||||
if slot.state != Slot.State.EMPTY:
|
||||
# input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached)
|
||||
input_ids[i, 0] = slot.next_token
|
||||
attention_mask[i, : slot.attention_mask.size(-1)] = slot.attention_mask
|
||||
if sampling_params is not None:
|
||||
sampling_params[i, 0] = slot.generation_config.top_k
|
||||
sampling_params[i, 1] = slot.generation_config.top_p
|
||||
sampling_params[i, 2] = slot.generation_config.temperature
|
||||
model_inputs = self.model.prepare_inputs_for_decode(
|
||||
input_ids, attention_mask, seq_ids
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
seq_ids=seq_ids,
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
tokens_or_logits = self.model(**model_inputs)[0]
|
||||
return self._generate_token(
|
||||
decode_slots, next_batch_id, tokens_or_logits, input_ids
|
||||
)
|
||||
logits = self.model(**model_inputs)[0]
|
||||
return self._generate_token(decode_slots, next_batch_id, logits, input_ids)
|
||||
|
||||
def _generate_token(
|
||||
self,
|
||||
slots: List[Slot],
|
||||
next_batch_id: int,
|
||||
logits: torch.Tensor,
|
||||
tokens_or_logits: torch.Tensor,
|
||||
input_ids: torch.LongTensor,
|
||||
) -> Tuple[List[Generation], CachedBatch]:
|
||||
generations = []
|
||||
@ -569,9 +592,12 @@ class NeuronGenerator(Generator):
|
||||
if slot.state != Slot.State.READY:
|
||||
continue
|
||||
request_id = slot.request_id
|
||||
next_token_logits = logits[i : i + 1, -1, :]
|
||||
slot_input_ids = input_ids[i : i + 1, :]
|
||||
next_token = slot.select(slot_input_ids, next_token_logits)
|
||||
if self.on_device_sampling:
|
||||
next_token = tokens_or_logits[i]
|
||||
else:
|
||||
next_token_logits = tokens_or_logits[i : i + 1, -1, :]
|
||||
next_token = slot.select(slot_input_ids, next_token_logits)
|
||||
next_token_text = slot.append(next_token)
|
||||
generated_text = None
|
||||
finish_reason = None
|
||||
@ -622,7 +648,7 @@ class NeuronGenerator(Generator):
|
||||
|
||||
def _cached_batch(self, batch_id: int, request_ids: List):
|
||||
size = len(request_ids)
|
||||
max_tokens = size * self.model.max_length
|
||||
max_tokens = size * self.model.neuron_config.sequence_length
|
||||
return CachedBatch(
|
||||
id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens
|
||||
)
|
||||
@ -671,8 +697,16 @@ class NeuronGenerator(Generator):
|
||||
Returns:
|
||||
A NeuronGenerator.
|
||||
"""
|
||||
config = AutoConfig.from_pretrained(model_id)
|
||||
neuron_config = getattr(config, "neuron", None)
|
||||
try:
|
||||
neuron_config = NeuronConfig.from_pretrained(model_id, revision=revision)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"NeuronConfig.from_pretrained failed for model %s, revision %s: %s",
|
||||
model_id,
|
||||
revision,
|
||||
e,
|
||||
)
|
||||
neuron_config = None
|
||||
start = time.time()
|
||||
if neuron_config is None:
|
||||
export_kwargs = get_export_kwargs_from_env()
|
||||
|
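Aside on the generator changes above: when on-device sampling is enabled, the Neuron generator now passes a per-slot sampling_params tensor with one row of [top_k, top_p, temperature] per slot, and emulates greedy decoding by forcing top_k=1, top_p=1.0, temperature=1.0. A hedged, standalone sketch of how such a tensor can be assembled; `SlotSamplingConfig` and `build_sampling_params` are illustrative names, not code from this diff:

```python
from dataclasses import dataclass
from typing import List

import torch


@dataclass
class SlotSamplingConfig:
    # Stand-in for a slot's generation_config; only the three fields
    # consumed by on-device sampling are modeled here.
    top_k: int = 1
    top_p: float = 1.0
    temperature: float = 1.0


def build_sampling_params(configs: List[SlotSamplingConfig]) -> torch.Tensor:
    # One row per slot: [top_k, top_p, temperature], matching the layout the
    # generator feeds to prepare_inputs_for_prefill / prepare_inputs_for_decode.
    params = torch.zeros(len(configs), 3)
    for i, cfg in enumerate(configs):
        params[i, 0] = cfg.top_k
        params[i, 1] = cfg.top_p
        params[i, 2] = cfg.temperature
    return params


# A greedy slot is encoded as [1, 1.0, 1.0]; a sampling slot keeps its own values.
print(build_sampling_params([SlotSamplingConfig(), SlotSamplingConfig(top_k=50, top_p=0.9, temperature=0.7)]))
```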
@ -6,10 +6,12 @@ from typing import Optional
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.constants import HF_HUB_CACHE
|
||||
from loguru import logger
|
||||
from transformers import AutoConfig
|
||||
|
||||
from optimum.neuron import NeuronModelForCausalLM
|
||||
from optimum.neuron.utils import get_hub_cached_entries
|
||||
from optimum.neuron.cache import get_hub_cached_entries
|
||||
from optimum.neuron.configuration_utils import NeuronConfig
|
||||
|
||||
|
||||
from .tgi_env import check_env_and_neuron_config_compatibility
|
||||
|
||||
|
||||
def get_export_kwargs_from_env():
|
||||
@ -24,7 +26,6 @@ def get_export_kwargs_from_env():
|
||||
num_cores = int(num_cores)
|
||||
auto_cast_type = os.environ.get("HF_AUTO_CAST_TYPE", None)
|
||||
return {
|
||||
"task": "text-generation",
|
||||
"batch_size": batch_size,
|
||||
"sequence_length": sequence_length,
|
||||
"num_cores": num_cores,
|
||||
@ -32,20 +33,15 @@ def get_export_kwargs_from_env():
|
||||
}
|
||||
|
||||
|
||||
def is_cached(model_id, neuron_config):
|
||||
def is_cached(model_id):
|
||||
# Look for cached entries for the specified model
|
||||
in_cache = False
|
||||
entries = get_hub_cached_entries(model_id, "inference")
|
||||
entries = get_hub_cached_entries(model_id)
|
||||
# Look for compatible entries
|
||||
for entry in entries:
|
||||
compatible = True
|
||||
for key, value in neuron_config.items():
|
||||
# Only weights can be different
|
||||
if key in ["checkpoint_id", "checkpoint_revision"]:
|
||||
continue
|
||||
if entry[key] != value:
|
||||
compatible = False
|
||||
if compatible:
|
||||
if check_env_and_neuron_config_compatibility(
|
||||
entry, check_compiler_version=True
|
||||
):
|
||||
in_cache = True
|
||||
break
|
||||
return in_cache
|
||||
@ -87,8 +83,16 @@ def fetch_model(
|
||||
revision = None
|
||||
# Download the model from the Hub (HUGGING_FACE_HUB_TOKEN must be set for a private or gated model)
|
||||
# Note that the model may already be present in the cache.
|
||||
config = AutoConfig.from_pretrained(model_id, revision=revision)
|
||||
neuron_config = getattr(config, "neuron", None)
|
||||
try:
|
||||
neuron_config = NeuronConfig.from_pretrained(model_id, revision=revision)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"NeuronConfig.from_pretrained failed for model %s, revision %s: %s",
|
||||
model_id,
|
||||
revision,
|
||||
e,
|
||||
)
|
||||
neuron_config = None
|
||||
if neuron_config is not None:
|
||||
if os.path.isdir(model_id):
|
||||
return model_id
|
||||
@ -99,16 +103,11 @@ def fetch_model(
|
||||
log_cache_size()
|
||||
return snapshot_download(model_id, revision=revision, ignore_patterns="*.bin")
|
||||
# Model needs to be exported: look for compatible cached entries on the hub
|
||||
export_kwargs = get_export_kwargs_from_env()
|
||||
export_config = NeuronModelForCausalLM.get_export_config(
|
||||
model_id, config, revision=revision, **export_kwargs
|
||||
)
|
||||
neuron_config = export_config.neuron
|
||||
if not is_cached(model_id, neuron_config):
|
||||
if not is_cached(model_id):
|
||||
hub_cache_url = "https://huggingface.co/aws-neuron/optimum-neuron-cache"
|
||||
neuron_export_url = "https://huggingface.co/docs/optimum-neuron/main/en/guides/export_model#exporting-neuron-models-using-neuronx-tgi"
|
||||
error_msg = (
|
||||
f"No cached version found for {model_id} with {neuron_config}."
|
||||
f"No cached version found for {model_id} with {get_export_kwargs_from_env()}."
|
||||
f"You can start a discussion to request it on {hub_cache_url}"
|
||||
f"Alternatively, you can export your own neuron model as explained in {neuron_export_url}"
|
||||
)
|
||||
@ -121,8 +120,10 @@ def fetch_model(
|
||||
# Prefetch weights, tokenizer and generation config so that they are in cache
|
||||
log_cache_size()
|
||||
start = time.time()
|
||||
snapshot_download(model_id, revision=revision, ignore_patterns="*.bin")
|
||||
snapshot_path = snapshot_download(
|
||||
model_id, revision=revision, ignore_patterns="*.bin"
|
||||
)
|
||||
end = time.time()
|
||||
logger.info(f"Model weights fetched in {end - start:.2f} s.")
|
||||
log_cache_size()
|
||||
return model_id
|
||||
return snapshot_path
|
||||
|
145
backends/neuron/tgi_env.py → backends/neuron/server/text_generation_server/tgi_env.py
Executable file → Normal file
@ -6,12 +6,11 @@ import os
|
||||
import sys
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from huggingface_hub import constants
|
||||
from transformers import AutoConfig
|
||||
|
||||
from optimum.neuron.modeling_decoder import get_available_cores
|
||||
from optimum.neuron.utils import get_hub_cached_entries
|
||||
from optimum.neuron.cache import get_hub_cached_entries
|
||||
from optimum.neuron.configuration_utils import NeuronConfig
|
||||
from optimum.neuron.utils.version_utils import get_neuronxcc_version
|
||||
from optimum.neuron.utils import map_torch_dtype
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -24,15 +23,9 @@ tgi_router_env_vars = [
|
||||
]
|
||||
tgi_server_env_vars = ["HF_NUM_CORES", "HF_AUTO_CAST_TYPE"]
|
||||
|
||||
env_config_peering = [
|
||||
("MAX_BATCH_SIZE", "batch_size"),
|
||||
("MAX_TOTAL_TOKENS", "sequence_length"),
|
||||
("HF_AUTO_CAST_TYPE", "auto_cast_type"),
|
||||
("HF_NUM_CORES", "num_cores"),
|
||||
]
|
||||
|
||||
# By the end of this script all env var should be specified properly
|
||||
env_vars = tgi_server_env_vars + tgi_router_env_vars
|
||||
tgi_env_vars = tgi_server_env_vars + tgi_router_env_vars
|
||||
|
||||
available_cores = get_available_cores()
|
||||
neuronxcc_version = get_neuronxcc_version()
|
||||
@ -93,9 +86,17 @@ def parse_cmdline_and_set_env(argv: List[str] = None) -> argparse.Namespace:
|
||||
|
||||
|
||||
def neuron_config_to_env(neuron_config):
|
||||
if isinstance(neuron_config, NeuronConfig):
|
||||
neuron_config = neuron_config.to_dict()
|
||||
with open(os.environ["ENV_FILEPATH"], "w") as f:
|
||||
for env_var, config_key in env_config_peering:
|
||||
f.write("export {}={}\n".format(env_var, neuron_config[config_key]))
|
||||
f.write("export MAX_BATCH_SIZE={}\n".format(neuron_config["batch_size"]))
|
||||
f.write("export MAX_TOTAL_TOKENS={}\n".format(neuron_config["sequence_length"]))
|
||||
f.write("export HF_NUM_CORES={}\n".format(neuron_config["tp_degree"]))
|
||||
config_key = (
|
||||
"auto_cast_type" if "auto_cast_type" in neuron_config else "torch_dtype"
|
||||
)
|
||||
auto_cast_type = neuron_config[config_key]
|
||||
f.write("export HF_AUTO_CAST_TYPE={}\n".format(auto_cast_type))
|
||||
max_input_tokens = os.getenv("MAX_INPUT_TOKENS")
|
||||
if not max_input_tokens:
|
||||
max_input_tokens = int(neuron_config["sequence_length"]) // 2
|
||||
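For reference, the rewritten neuron_config_to_env above boils down to a handful of key translations from the neuron config to TGI environment variables. A minimal sketch of that mapping as a pure function, mirroring the hunk but not taken from it; the helper name `neuron_config_to_env_dict` is ours:

```python
from typing import Dict, Optional


def neuron_config_to_env_dict(
    neuron_config: Dict, max_input_tokens: Optional[int] = None
) -> Dict[str, str]:
    # batch_size, sequence_length and tp_degree map directly onto the TGI env
    # vars written above; the cast type prefers auto_cast_type and falls back
    # to torch_dtype, as in the hunk.
    cast_key = "auto_cast_type" if "auto_cast_type" in neuron_config else "torch_dtype"
    if max_input_tokens is None:
        # MAX_INPUT_TOKENS defaults to half the sequence length when unset.
        max_input_tokens = int(neuron_config["sequence_length"]) // 2
    return {
        "MAX_BATCH_SIZE": str(neuron_config["batch_size"]),
        "MAX_TOTAL_TOKENS": str(neuron_config["sequence_length"]),
        "HF_NUM_CORES": str(neuron_config["tp_degree"]),
        "HF_AUTO_CAST_TYPE": str(neuron_config[cast_key]),
        "MAX_INPUT_TOKENS": str(max_input_tokens),
    }
```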
@ -111,7 +112,7 @@ def neuron_config_to_env(neuron_config):
|
||||
|
||||
|
||||
def sort_neuron_configs(dictionary):
|
||||
return -dictionary["num_cores"], -dictionary["batch_size"]
|
||||
return -dictionary["tp_degree"], -dictionary["batch_size"]
|
||||
|
||||
|
||||
def lookup_compatible_cached_model(
|
||||
@ -119,7 +120,7 @@ def lookup_compatible_cached_model(
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
# Reuse the same mechanic as the one in use to configure the tgi server part
|
||||
# The only difference here is that we stay as flexible as possible on the compatibility part
|
||||
entries = get_hub_cached_entries(model_id, "inference")
|
||||
entries = get_hub_cached_entries(model_id)
|
||||
|
||||
logger.debug(
|
||||
"Found %d cached entries for model %s, revision %s",
|
||||
@ -155,15 +156,15 @@ def lookup_compatible_cached_model(
|
||||
|
||||
|
||||
def check_env_and_neuron_config_compatibility(
|
||||
    neuron_config: Dict[str, Any], check_compiler_version: bool
    neuron_config_dict: Dict[str, Any], check_compiler_version: bool
) -> bool:
    logger.debug(
        "Checking the provided neuron config %s is compatible with the local setup and provided environment",
        neuron_config,
        neuron_config_dict,
    )

    # Local setup compat checks
    if neuron_config["num_cores"] > available_cores:
    if neuron_config_dict["tp_degree"] > available_cores:
        logger.debug(
            "Not enough neuron cores available to run the provided neuron config"
        )
@ -171,33 +172,65 @@ def check_env_and_neuron_config_compatibility(

    if (
        check_compiler_version
        and neuron_config["compiler_version"] != neuronxcc_version
        and neuron_config_dict["neuronxcc_version"] != neuronxcc_version
    ):
        logger.debug(
            "Compiler version conflict, the local one (%s) differs from the one used to compile the model (%s)",
            neuronxcc_version,
            neuron_config["compiler_version"],
            neuron_config_dict["neuronxcc_version"],
        )
        return False

    for env_var, config_key in env_config_peering:
        neuron_config_value = str(neuron_config[config_key])
        env_value = os.getenv(env_var, str(neuron_config_value))
    batch_size = os.getenv("MAX_BATCH_SIZE", None)
    if batch_size is not None and neuron_config_dict["batch_size"] < int(batch_size):
        logger.debug(
            "The provided MAX_BATCH_SIZE (%s) is higher than the neuron config batch size (%s)",
            os.getenv("MAX_BATCH_SIZE"),
            neuron_config_dict["batch_size"],
        )
        return False
    max_total_tokens = os.getenv("MAX_TOTAL_TOKENS", None)
    if max_total_tokens is not None and neuron_config_dict["sequence_length"] < int(
        max_total_tokens
    ):
        logger.debug(
            "The provided MAX_TOTAL_TOKENS (%s) is higher than the neuron config sequence length (%s)",
            max_total_tokens,
            neuron_config_dict["sequence_length"],
        )
        return False
    num_cores = os.getenv("HF_NUM_CORES", None)
    if num_cores is not None and neuron_config_dict["tp_degree"] < int(num_cores):
        logger.debug(
            "The provided HF_NUM_CORES (%s) is higher than the neuron config tp degree (%s)",
            num_cores,
            neuron_config_dict["tp_degree"],
        )
        return False
    auto_cast_type = os.getenv("HF_AUTO_CAST_TYPE", None)
    if auto_cast_type is not None:
        config_key = (
            "auto_cast_type"
            if "auto_cast_type" in neuron_config_dict
            else "torch_dtype"
        )
        neuron_config_value = map_torch_dtype(str(neuron_config_dict[config_key]))
        env_value = map_torch_dtype(auto_cast_type)
        if env_value != neuron_config_value:
            logger.debug(
                "The provided env var '%s' and the neuron config '%s' param differ (%s != %s)",
                env_var,
                config_key,
                "The provided auto cast type and the neuron config param differ (%s != %s)",
                env_value,
                neuron_config_value,
            )
            return False

    max_input_tokens = int(
        os.getenv("MAX_INPUT_TOKENS", os.getenv("MAX_INPUT_LENGTH", 0))
    )
    if max_input_tokens > 0:
        sequence_length = neuron_config["sequence_length"]
        if hasattr(neuron_config_dict, "max_context_length"):
            sequence_length = neuron_config_dict["max_context_length"]
        else:
            sequence_length = neuron_config_dict["sequence_length"]
        if max_input_tokens >= sequence_length:
            logger.debug(
                "Specified max input tokens is not compatible with config sequence length ( %s >= %s)",
@ -211,38 +244,29 @@ def check_env_and_neuron_config_compatibility(

def get_env_dict() -> Dict[str, str]:
    d = {}
    for k in env_vars:
    for k in tgi_env_vars:
        d[k] = os.getenv(k)
    return d


def main():
    """
    This script determines proper default TGI env variables for the neuron precompiled models to
    work properly
    :return:
    """
    args = parse_cmdline_and_set_env()

    for env_var in env_vars:
        if not os.getenv(env_var):
            break
    else:
        logger.info(
            "All env vars %s already set, skipping, user know what they are doing",
            env_vars,
def get_neuron_config_for_model(
    model_name_or_path: str, revision: Optional[str] = None
) -> NeuronConfig:
    try:
        neuron_config = NeuronConfig.from_pretrained(
            model_name_or_path, revision=revision
        )
        sys.exit(0)

    cache_dir = constants.HF_HUB_CACHE

    logger.info("Cache dir %s, model %s", cache_dir, args.model_id)

    config = AutoConfig.from_pretrained(args.model_id, revision=args.revision)
    neuron_config = getattr(config, "neuron", None)
    except Exception as e:
        logger.debug(
            "NeuronConfig.from_pretrained failed for model %s, revision %s: %s",
            model_name_or_path,
            revision,
            e,
        )
        neuron_config = None
    if neuron_config is not None:
        compatible = check_env_and_neuron_config_compatibility(
            neuron_config, check_compiler_version=False
            neuron_config.to_dict(), check_compiler_version=False
        )
        if not compatible:
            env_dict = get_env_dict()
@ -252,17 +276,6 @@ def main():
            logger.error(msg)
            raise Exception(msg)
    else:
        neuron_config = lookup_compatible_cached_model(args.model_id, args.revision)
        neuron_config = lookup_compatible_cached_model(model_name_or_path, revision)

    if not neuron_config:
        msg = (
            "No compatible neuron config found. Provided env {}, available cores {}, neuronxcc version {}"
        ).format(get_env_dict(), available_cores, neuronxcc_version)
        logger.error(msg)
        raise Exception(msg)

    neuron_config_to_env(neuron_config)


if __name__ == "__main__":
    main()
    return neuron_config
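The refactoring above turns the old `main()` into a reusable `get_neuron_config_for_model()` helper: it first tries `NeuronConfig.from_pretrained`, checks the result against the environment, and otherwise falls back to `lookup_compatible_cached_model`. A minimal sketch of how a caller could drive these helpers (the model id and env value below are illustrative, not part of the change):

```python
# Sketch only: exercises the helpers shown in the diff above, assuming they are
# importable from text_generation_server.tgi_env as the new entry point does.
import os

from text_generation_server.tgi_env import (
    get_neuron_config_for_model,
    neuron_config_to_env,
)

# Illustrative values: any exported neuron model path or hub id would do.
os.environ.setdefault("MAX_BATCH_SIZE", "4")
neuron_config = get_neuron_config_for_model("unsloth/Llama-3.2-1B-Instruct", revision=None)

if neuron_config is None:
    raise SystemExit("no compatible neuron config found")
# Writes `export ...` lines to the file pointed at by ENV_FILEPATH.
neuron_config_to_env(neuron_config)
```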
74
backends/neuron/tests/fixtures/model.py
vendored
@ -4,14 +4,12 @@ import subprocess
import sys
from tempfile import TemporaryDirectory

import huggingface_hub
import os
import pytest
from transformers import AutoTokenizer

from optimum.neuron import NeuronModelForCausalLM
from optimum.neuron.utils import synchronize_hub_cache
from optimum.neuron.version import __sdk_version__ as sdk_version
from optimum.neuron.version import __version__ as version

from optimum.neuron.cache import synchronize_hub_cache


logging.basicConfig(
@ -21,30 +19,14 @@ logging.basicConfig(
)
logger = logging.getLogger(__file__)


OPTIMUM_CACHE_REPO_ID = "optimum-internal-testing/neuron-testing-cache"


# All model configurations below will be added to the neuron_model_config fixture
MODEL_CONFIGURATIONS = {
    "gpt2": {
        "model_id": "gpt2",
        "export_kwargs": {
            "batch_size": 4,
            "sequence_length": 1024,
            "num_cores": 2,
            "auto_cast_type": "fp16",
        },
    },
    "llama": {
        "model_id": "NousResearch/Hermes-2-Theta-Llama-3-8B",
        "export_kwargs": {
            "batch_size": 4,
            "sequence_length": 2048,
            "num_cores": 2,
            "auto_cast_type": "fp16",
        },
    },
    "mistral": {
        "model_id": "optimum/mistral-1.1b-testing",
        "model_id": "unsloth/Llama-3.2-1B-Instruct",
        "export_kwargs": {
            "batch_size": 4,
            "sequence_length": 4096,
@ -58,7 +40,7 @@ MODEL_CONFIGURATIONS = {
            "batch_size": 4,
            "sequence_length": 4096,
            "num_cores": 2,
            "auto_cast_type": "fp16",
            "auto_cast_type": "bf16",
        },
    },
    "granite": {
@ -73,12 +55,6 @@ MODEL_CONFIGURATIONS = {
}


def get_hub_neuron_model_id(config_name: str):
    return (
        f"optimum-internal-testing/neuron-testing-{version}-{sdk_version}-{config_name}"
    )


def export_model(model_id, export_kwargs, neuron_model_path):
    export_command = [
        "optimum-cli",
@ -104,57 +80,35 @@ def export_model(model_id, export_kwargs, neuron_model_path):
def neuron_model_config(request):
    """Expose a pre-trained neuron model

    The fixture first makes sure the following model artifacts are present on the hub:
    - exported neuron model under optimum-internal-testing/neuron-testing-<version>-<name>,
    - cached artifacts under optimum-internal-testing/neuron-testing-cache.
    If not, it will export the model and push it to the hub.

    It then fetches the model locally and return a dictionary containing:
    The fixture exports a model locally and returns a dictionary containing:
    - a configuration name,
    - the original model id,
    - the export parameters,
    - the neuron model id,
    - the neuron model local path.

    For each exposed model, the local directory is maintained for the duration of the
    test session and cleaned up afterwards.
    The hub model artifacts are never cleaned up and persist accross sessions.
    They must be cleaned up manually when the optimum-neuron version changes.

    """
    config_name = request.param
    model_config = copy.deepcopy(MODEL_CONFIGURATIONS[request.param])
    model_id = model_config["model_id"]
    export_kwargs = model_config["export_kwargs"]
    neuron_model_id = get_hub_neuron_model_id(config_name)
    with TemporaryDirectory() as neuron_model_path:
        hub = huggingface_hub.HfApi()
        if hub.repo_exists(neuron_model_id):
            logger.info(f"Fetching {neuron_model_id} from the HuggingFace hub")
            hub.snapshot_download(neuron_model_id, local_dir=neuron_model_path)
        else:
            export_model(model_id, export_kwargs, neuron_model_path)
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            tokenizer.save_pretrained(neuron_model_path)
            del tokenizer
            # Create the test model on the hub
            hub.create_repo(neuron_model_id, private=True)
            hub.upload_folder(
                folder_path=neuron_model_path,
                repo_id=neuron_model_id,
                ignore_patterns=[NeuronModelForCausalLM.CHECKPOINT_DIR + "/*"],
            )
            # Make sure it is cached
            synchronize_hub_cache(cache_repo_id=OPTIMUM_CACHE_REPO_ID)
        export_model(model_id, export_kwargs, neuron_model_path)
        synchronize_hub_cache(cache_repo_id=OPTIMUM_CACHE_REPO_ID)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokenizer.save_pretrained(neuron_model_path)
        del tokenizer
        # Add dynamic parameters to the model configuration
        model_config["neuron_model_path"] = neuron_model_path
        model_config["neuron_model_id"] = neuron_model_id
        # Also add model configuration name to allow tests to adapt their expectations
        model_config["name"] = config_name
        # Yield instead of returning to keep a reference to the temporary directory.
        # It will go out of scope and be released only once all tests needing the fixture
        # have been completed.
        logger.info(f"{config_name} ready for testing ...")
        os.environ["CUSTOM_CACHE_REPO"] = OPTIMUM_CACHE_REPO_ID
        yield model_config
        logger.info(f"Done with {config_name}")
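With the hub round-trip removed, the fixture now just exports each entry of `MODEL_CONFIGURATIONS` into a temporary directory and yields its metadata. A hypothetical test consuming it would only request the fixture by name (the assertions below are illustrative, not part of the change):

```python
# Hypothetical consumer of the neuron_model_config fixture defined above.
from optimum.neuron import NeuronModelForCausalLM


def test_exported_model_loads(neuron_model_config):
    # The fixture yields a dict with the local export path and the export parameters.
    model = NeuronModelForCausalLM.from_pretrained(neuron_model_config["neuron_model_path"])
    assert model is not None
    assert neuron_model_config["export_kwargs"]["batch_size"] >= 1
```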
42
backends/neuron/tests/server/test_cached_model.py
Normal file
@ -0,0 +1,42 @@
import os
import pytest

from text_generation_server.generator import NeuronGenerator
from text_generation_server.model import fetch_model, is_cached


@pytest.fixture(scope="module")
def cached_model_id(neuron_model_config) -> str:
    """
    Fixture to provide a cached model ID for testing.
    This assumes the model is already cached in the local environment.
    """
    export_kwargs = neuron_model_config["export_kwargs"]
    os.environ["MAX_BATCH_SIZE"] = str(export_kwargs["batch_size"])
    os.environ["MAX_TOTAL_TOKENS"] = str(export_kwargs["sequence_length"])
    os.environ["HF_AUTO_CAST_TYPE"] = export_kwargs["auto_cast_type"]
    os.environ["HF_NUM_CORES"] = str(export_kwargs["num_cores"])
    yield neuron_model_config["model_id"]
    os.environ.pop("MAX_BATCH_SIZE", None)
    os.environ.pop("MAX_TOTAL_TOKENS", None)
    os.environ.pop("HF_AUTO_CAST_TYPE", None)
    os.environ.pop("HF_NUM_CORES", None)


def test_model_is_cached(cached_model_id):
    assert is_cached(cached_model_id), f"Model {cached_model_id} is not cached"


def test_fetch_cached_model(cached_model_id: str):
    model_path = fetch_model(cached_model_id)
    assert os.path.exists(
        model_path
    ), f"Model {cached_model_id} was not fetched successfully"
    assert os.path.isdir(model_path), f"Model {cached_model_id} is not a directory"


def test_generator_from_cached_model(cached_model_id: str):
    generator = NeuronGenerator.from_pretrained(model_id=cached_model_id)
    assert generator is not None, "Generator could not be created from cached model"
    assert generator.model is not None, "Generator model is not initialized"
    assert generator.tokenizer is not None, "Generator tokenizer is not initialized"
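The fixture above sets the TGI sizing variables through `os.environ` and pops them on teardown. For function-scoped variants, pytest's `monkeypatch` offers the same wiring with automatic restoration; a sketch (not part of the change, and note that `monkeypatch` is function-scoped, which is presumably why the module-scoped fixture above edits `os.environ` directly):

```python
# Sketch: the same env wiring as cached_model_id, expressed with monkeypatch.
import pytest


@pytest.fixture
def cached_model_env(neuron_model_config, monkeypatch) -> str:
    export_kwargs = neuron_model_config["export_kwargs"]
    monkeypatch.setenv("MAX_BATCH_SIZE", str(export_kwargs["batch_size"]))
    monkeypatch.setenv("MAX_TOTAL_TOKENS", str(export_kwargs["sequence_length"]))
    monkeypatch.setenv("HF_AUTO_CAST_TYPE", export_kwargs["auto_cast_type"])
    monkeypatch.setenv("HF_NUM_CORES", str(export_kwargs["num_cores"]))
    return neuron_model_config["model_id"]
```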
@ -9,13 +9,13 @@ def test_continuous_batching_two_requests(neuron_model_config):
    """
    neuron_model_path = neuron_model_config["neuron_model_path"]
    generator = NeuronGenerator.from_pretrained(neuron_model_path)
    assert generator.model.batch_size > 1
    assert generator.model.neuron_config.batch_size > 1
    input_text = "Once upon a time"
    max_new_tokens = 20
    # Prefill a single request, remembering the generated token
    tokens = {0: [], 1: []}
    request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens)
    max_length = generator.model.max_length
    max_length = generator.model.neuron_config.sequence_length
    batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
    generations, next_batch = generator.prefill(batch)
    assert next_batch.size == 1

@ -23,7 +23,7 @@ def _test_decode(config_name, generator, do_sample):
    request = create_request(
        id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=do_sample
    )
    max_length = generator.model.max_length
    max_length = generator.model.neuron_config.sequence_length
    batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
    generations, next_batch = generator.prefill(batch)
    # We already generated one token: call decode max_new_tokens - 1 times
@ -40,19 +40,15 @@ def _test_decode(config_name, generator, do_sample):
    assert output.finish_reason == 0
    if do_sample:
        expected_text = {
            "gpt2": " The sun was set",
            "llama": "George Orwell, 1984",
            "mistral": "The sky was",
            "qwen2": " A young woman with",
            "llama": " I sat alone in the café",
            "qwen2": " The air was so still",
            "granite": "1984, George Orwell",
        }[config_name]
        assert expected_text in output.text
    else:
        print(output.text)
        expected_text = {
            "gpt2": '\n\n"I\'m going to go to bed," I said.\n\n"I\'m going',
            "llama": " George Orwell’s classic dystopian novel, 1984, begins with this ominous sentence. The story",
            "mistral": "\nThe clocks were striking thirteen.\nThe clocks were striking thirteen.",
            "llama": " The world was holding its breath as the world's top scientists and engineers gathered at the secret underground facility",
            "qwen2": " I was sitting in my room, staring at the ceiling, when the door opened and in came a",
            "granite": "\n\nThis opening line from George Orwell's dystopian novel \"198",
        }[config_name]

@ -9,7 +9,7 @@ def test_prefill(neuron_model_config):
    neuron_model_path = neuron_model_config["neuron_model_path"]
    generator = NeuronGenerator.from_pretrained(neuron_model_path)
    max_batch_size = 4
    assert generator.model.batch_size >= max_batch_size
    assert generator.model.neuron_config.batch_size >= max_batch_size
    for num_requests in [1, max_batch_size]:
        for do_sample in [True, False]:
            mode = "sample" if do_sample else "greedy"
@ -34,7 +34,7 @@ def _test_prefill(config_name, generator, batch_size, do_sample):
            )
        )
    # Let's be pessimistic when estimating max_tokens
    max_length = generator.model.max_length
    max_length = generator.max_prefill_length()
    batch = Batch(
        id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length
    )
@ -46,17 +46,13 @@ def _test_prefill(config_name, generator, batch_size, do_sample):
    assert len(generations) == batch_size
    if do_sample:
        expectations = {
            "gpt2": [383, " The"],
            "llama": [10058, " George"],
            "mistral": [450, " The"],
            "qwen2": [362, " A"],
            "llama": [358, " I"],
            "qwen2": [576, " The"],
            "granite": [308, " ("],
        }[config_name]
    else:
        expectations = {
            "gpt2": [198, "\n"],
            "llama": [10058, " George"],
            "mistral": [13, "\n"],
            "llama": [578, " The"],
            "qwen2": [358, " I"],
            "granite": [203, "\n"],
        }[config_name]
@ -70,7 +66,7 @@ def test_prefill_truncate(neuron_model_config):
    config_name = neuron_model_config["name"]
    neuron_model_path = neuron_model_config["neuron_model_path"]
    generator = NeuronGenerator.from_pretrained(neuron_model_path)
    batch_size = generator.model.batch_size
    batch_size = generator.model.neuron_config.batch_size
    # We apply truncation to all requests but the first one
    truncate = [
        None,
@ -83,7 +79,7 @@ def test_prefill_truncate(neuron_model_config):
    requests = []
    for i in range(batch_size):
        requests.append(create_request(id=i, inputs=input_text, truncate=truncate[i]))
    max_length = generator.model.max_length
    max_length = generator.max_prefill_length()
    batch = Batch(
        id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length
    )
@ -91,12 +87,12 @@ def test_prefill_truncate(neuron_model_config):
    # Even if the input text is identical for all requests, the first generated token might
    # be different because of the truncation
    expectations = {
        "gpt2": [" He", " He", "\n", " He"],
        "llama": [" —", " The", " He", " He"],
        "mistral": [" He", "\n", " He", " He"],
        "llama": [" He", "iens", "\x08", " He"],
        "qwen2": [" He", " The", " He", " He"],
        "granite": ["\n", "\n", " I", " He"],
    }[config_name]
    for i, g in enumerate(generations):
        tokens = g.tokens
        assert tokens.texts[0] == expectations[i]
        assert (
            tokens.texts[0] == expectations[i]
        ), f"Request {i} expected [{expectations[i]}], got [{tokens.texts[0]}]"
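All of these server tests follow the same shape: build a `Batch` of requests, `prefill` it once, then decode until the requested number of tokens is reached. A condensed sketch of that loop (`create_request` is the test helper used above; the `generator.decode([next_batch])` call is an assumption about the decode API, which this diff does not show verbatim):

```python
# Sketch of the prefill/decode loop exercised by the tests above.
def greedy_generate(generator, create_request, Batch, prompt, max_new_tokens=20):
    request = create_request(id=0, inputs=prompt, max_new_tokens=max_new_tokens)
    max_length = generator.model.neuron_config.sequence_length
    batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
    generations, next_batch = generator.prefill(batch)  # produces the first token
    for _ in range(max_new_tokens - 1):  # then decode the remaining tokens
        generations, next_batch = generator.decode([next_batch])  # assumed signature
    return generations
```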
63
backends/neuron/tests/test_entry_point.py
Normal file
@ -0,0 +1,63 @@
import os
import pytest
from tempfile import TemporaryDirectory

from optimum.neuron.models.inference.nxd.backend.config import NxDNeuronConfig
from optimum.neuron.utils import map_torch_dtype

from text_generation_server.tgi_env import (
    get_neuron_config_for_model,
    lookup_compatible_cached_model,
    neuron_config_to_env,
)


def test_get_neuron_config_for_model(neuron_model_config):
    neuron_model_path = neuron_model_config["neuron_model_path"]
    export_kwargs = neuron_model_config["export_kwargs"]
    os.environ["MAX_BATCH_SIZE"] = str(export_kwargs["batch_size"])
    os.environ["MAX_TOTAL_TOKENS"] = str(export_kwargs["sequence_length"])
    os.environ["HF_AUTO_CAST_TYPE"] = export_kwargs["auto_cast_type"]
    os.environ["HF_NUM_CORES"] = str(export_kwargs["num_cores"])
    neuron_config = get_neuron_config_for_model(neuron_model_path)
    assert neuron_config is not None
    assert neuron_config.batch_size == export_kwargs["batch_size"]
    assert neuron_config.sequence_length == export_kwargs["sequence_length"]
    assert neuron_config.tp_degree == export_kwargs["num_cores"]
    if isinstance(neuron_config, NxDNeuronConfig):
        assert map_torch_dtype(neuron_config.torch_dtype) == map_torch_dtype(
            export_kwargs["auto_cast_type"]
        )
    else:
        assert map_torch_dtype(neuron_config.auto_cast_type) == map_torch_dtype(
            export_kwargs["auto_cast_type"]
        )


@pytest.mark.parametrize("model_id", ["unsloth/Llama-3.2-1B-Instruct"])
def test_lookup_compatible_cached_model(model_id: str):
    neuron_config = lookup_compatible_cached_model(model_id, None)
    assert neuron_config is not None


def test_neuron_config_to_env(neuron_model_config) -> None:
    neuron_model_path = neuron_model_config["neuron_model_path"]
    neuron_config = get_neuron_config_for_model(neuron_model_path)
    with TemporaryDirectory() as temp_dir:
        os.environ["ENV_FILEPATH"] = os.path.join(temp_dir, "env.sh")
        neuron_config_to_env(neuron_config)
        with open(os.environ["ENV_FILEPATH"], "r") as env_file:
            env_content = env_file.read()
        assert f"export MAX_BATCH_SIZE={neuron_config.batch_size}" in env_content
        assert (
            f"export MAX_TOTAL_TOKENS={neuron_config.sequence_length}"
            in env_content
        )
        assert f"export HF_NUM_CORES={neuron_config.tp_degree}" in env_content
        if hasattr(neuron_config, "torch_dtype"):
            auto_cast_type = str(map_torch_dtype(neuron_config.torch_dtype)).split(
                "."
            )[-1]
        else:
            auto_cast_type = neuron_config.auto_cast_type
        assert f"export HF_AUTO_CAST_TYPE={auto_cast_type}" in env_content
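`test_neuron_config_to_env` pins down the mapping between neuron config fields and the TGI variables written to `$ENV_FILEPATH`. Summarized as a sketch (illustrative only; the real writer is `neuron_config_to_env`):

```python
# Illustrative only: the field -> env var mapping asserted by the test above.
from optimum.neuron.utils import map_torch_dtype


def expected_env_lines(neuron_config) -> list:
    # HF_AUTO_CAST_TYPE comes from `torch_dtype` on NxD configs and from
    # `auto_cast_type` otherwise, mirroring the hasattr() branch in the test.
    if hasattr(neuron_config, "torch_dtype"):
        auto_cast_type = str(map_torch_dtype(neuron_config.torch_dtype)).split(".")[-1]
    else:
        auto_cast_type = neuron_config.auto_cast_type
    return [
        f"export MAX_BATCH_SIZE={neuron_config.batch_size}",
        f"export MAX_TOTAL_TOKENS={neuron_config.sequence_length}",
        f"export HF_NUM_CORES={neuron_config.tp_degree}",
        f"export HF_AUTO_CAST_TYPE={auto_cast_type}",
    ]
```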
@ -9,7 +9,7 @@ touch $ENV_FILEPATH

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )

${SCRIPT_DIR}/tgi_env.py $@
${SCRIPT_DIR}/tgi_entry_point.py $@

source $ENV_FILEPATH

53
backends/neuron/tgi_entry_point.py
Executable file
@ -0,0 +1,53 @@
#!/usr/bin/env python

import logging
import os
import sys


from text_generation_server.tgi_env import (
    available_cores,
    get_env_dict,
    get_neuron_config_for_model,
    neuron_config_to_env,
    neuronxcc_version,
    parse_cmdline_and_set_env,
    tgi_env_vars,
)


logger = logging.getLogger(__name__)


def main():
    """
    This script determines proper default TGI env variables for the neuron precompiled models to
    work properly
    :return:
    """
    args = parse_cmdline_and_set_env()

    for env_var in tgi_env_vars:
        if not os.getenv(env_var):
            break
    else:
        logger.info(
            "All env vars %s already set, skipping, user know what they are doing",
            tgi_env_vars,
        )
        sys.exit(0)

    neuron_config = get_neuron_config_for_model(args.model_id, args.revision)

    if not neuron_config:
        msg = (
            "No compatible neuron config found. Provided env {}, available cores {}, neuronxcc version {}"
        ).format(get_env_dict(), available_cores, neuronxcc_version)
        logger.error(msg)
        raise Exception(msg)

    neuron_config_to_env(neuron_config)


if __name__ == "__main__":
    main()
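The early exit in `main()` relies on Python's `for`/`else`: the `else` branch runs only when the loop finishes without `break`, i.e. when every variable in `tgi_env_vars` is already set. A minimal illustration of the idiom (the variable list here is just a subset for the example; the real list comes from `tgi_env_vars`):

```python
# Minimal illustration of the for/else pattern used in main() above.
import os

required = ["MAX_BATCH_SIZE", "MAX_TOTAL_TOKENS", "HF_NUM_CORES", "HF_AUTO_CAST_TYPE"]
for name in required:
    if not os.getenv(name):
        break  # something is missing: fall through and derive defaults
else:
    # Reached only when the loop completed without break.
    print("all TGI env vars already set, nothing to do")
```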
@ -10,7 +10,7 @@
"name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
},
"version": "3.3.1-dev0"
"version": "3.3.2-dev0"
},
"paths": {
"/": {

@ -20,7 +20,7 @@ hf_token=YOUR_HF_ACCESS_TOKEN

docker run --runtime=habana --cap-add=sys_nice --ipc=host \
-p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
ghcr.io/huggingface/text-generation-inference:3.3.1-gaudi \
ghcr.io/huggingface/text-generation-inference:3.3.2-gaudi \
--model-id $model
```

@ -52,7 +52,7 @@ hf_token=YOUR_ACCESS_TOKEN

docker run --runtime=habana --cap-add=sys_nice --ipc=host \
-p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
ghcr.io/huggingface/text-generation-inference:3.3.1-gaudi \
ghcr.io/huggingface/text-generation-inference:3.3.2-gaudi \
--model-id $model
<text-generation-inference-launcher-arguments>
```
@ -115,7 +115,7 @@ docker run -p 8080:80 \
-e BATCH_BUCKET_SIZE=256 \
-e PREFILL_BATCH_BUCKET_SIZE=4 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
ghcr.io/huggingface/text-generation-inference:3.3.1-gaudi \
ghcr.io/huggingface/text-generation-inference:3.3.2-gaudi \
--model-id $model \
--sharded true --num-shard 8 \
--max-input-tokens 1024 --max-total-tokens 2048 \
@ -141,7 +141,7 @@ docker run -p 8080:80 \
-v $volume:/data \
-e PREFILL_BATCH_BUCKET_SIZE=1 \
-e BATCH_BUCKET_SIZE=1 \
ghcr.io/huggingface/text-generation-inference:3.3.1-gaudi \
ghcr.io/huggingface/text-generation-inference:3.3.2-gaudi \
--model-id $model \
--max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
--max-total-tokens 8192 --max-batch-size 4
@ -208,7 +208,7 @@ docker run --runtime=habana --ipc=host --cap-add=sys_nice \
-e PROF_PATH=/tmp/hpu_profile \
-e PROF_RANKS=0 \
-e PROF_RECORD_SHAPES=True \
ghcr.io/huggingface/text-generation-inference:3.3.1-gaudi \
ghcr.io/huggingface/text-generation-inference:3.3.2-gaudi \
--model-id $model
```

@ -31,7 +31,7 @@ deployment instructions in the model card:
The service is launched simply by running the text-generation-inference container with two sets of parameters:

```
docker run <system_parameters> ghcr.io/huggingface/text-generation-inference:3.3.1-neuron <service_parameters>
docker run <system_parameters> ghcr.io/huggingface/text-generation-inference:3.3.2-neuron <service_parameters>
```

- system parameters are used to map ports, volumes and devices between the host and the service,

@ -19,6 +19,6 @@ docker run --gpus all \
--shm-size 1g \
-e HF_TOKEN=$token \
-p 8080:80 \
-v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.1 \
-v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.2 \
--model-id $model
```

@ -19,7 +19,7 @@ bitsandbytes is a library used to apply 8-bit and 4-bit quantization to models.
In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below 👇

```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.1 --model-id $model --quantize bitsandbytes
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.2 --model-id $model --quantize bitsandbytes
```

4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load.
@ -27,7 +27,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingf
In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below 👇

```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.1 --model-id $model --quantize bitsandbytes-nf4
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.2 --model-id $model --quantize bitsandbytes-nf4
```

You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes).
@ -48,7 +48,7 @@ $$({\hat{W}_{l}}^{*} = argmin_{\hat{W_{l}}} ||W_{l}X-\hat{W}_{l}X||^{2}_{2})$$
TGI allows you to both run an already GPTQ quantized model (see available models [here](https://huggingface.co/models?search=gptq)) or quantize a model of your choice using quantization script. You can run a quantized model by simply passing --quantize like below 👇

```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.1 --model-id $model --quantize gptq
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.2 --model-id $model --quantize gptq
```

Note that TGI's GPTQ implementation doesn't use [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) under the hood. However, models quantized using AutoGPTQ or Optimum can still be served by TGI.
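The GPTQ objective quoted in the hunk header above, $\hat{W}_{l}^{*} = argmin_{\hat{W}_{l}} ||W_{l}X-\hat{W}_{l}X||^{2}_{2}$, simply asks for the quantized weights that best reproduce the layer's outputs on calibration data. A toy NumPy sketch of evaluating that reconstruction error (illustrative only, not TGI's quantization code):

```python
# Toy illustration of the layer-wise GPTQ objective: compare outputs of the
# original weights W and a naively rounded 4-bit version W_hat on calibration data X.
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(16, 64))   # original layer weights
X = rng.normal(size=(64, 128))  # calibration activations

scale = np.abs(W).max() / 7     # symmetric 4-bit grid: integers in [-7, 7]
W_hat = np.clip(np.round(W / scale), -7, 7) * scale

err = np.linalg.norm(W @ X - W_hat @ X) ** 2  # ||W X - W_hat X||^2
print(f"reconstruction error: {err:.3f}")
```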
@ -11,7 +11,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
--device=/dev/kfd --device=/dev/dri --group-add video \
--ipc=host --shm-size 256g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.3.1-rocm \
ghcr.io/huggingface/text-generation-inference:3.3.2-rocm \
--model-id $model
```

@ -12,7 +12,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm --privileged --cap-add=sys_nice \
--device=/dev/dri \
--ipc=host --shm-size 1g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.3.1-intel-xpu \
ghcr.io/huggingface/text-generation-inference:3.3.2-intel-xpu \
--model-id $model --cuda-graphs 0
```

@ -29,7 +29,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm --privileged --cap-add=sys_nice \
--device=/dev/dri \
--ipc=host --shm-size 1g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.3.1-intel-cpu \
ghcr.io/huggingface/text-generation-inference:3.3.2-intel-cpu \
--model-id $model --cuda-graphs 0
```

@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run

docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.3.1 \
ghcr.io/huggingface/text-generation-inference:3.3.2 \
--model-id $model
```

@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run

docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.3.1 \
ghcr.io/huggingface/text-generation-inference:3.3.2 \
--model-id $model
```

@ -96,7 +96,7 @@ curl 127.0.0.1:8080/generate \
To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.

```bash
docker run ghcr.io/huggingface/text-generation-inference:3.3.1 --help
docker run ghcr.io/huggingface/text-generation-inference:3.3.2 --help
```

</Tip>
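Besides `curl`, the launched container can be queried from Python; a hedged example using `huggingface_hub` (the endpoint and prompt mirror the quickstart defaults above):

```python
# Query a locally running TGI container (port 8080 as in the docker examples above).
from huggingface_hub import InferenceClient

client = InferenceClient("http://127.0.0.1:8080")
output = client.text_generation("What is Deep Learning?", max_new_tokens=20)
print(output)
```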
@ -163,7 +163,7 @@ hub = {

# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
    image_uri=get_huggingface_llm_image_uri("huggingface",version="3.3.1"),
    image_uri=get_huggingface_llm_image_uri("huggingface",version="3.3.2"),
    env=hub,
    role=role,
)

@ -28,15 +28,6 @@ logger = logging.getLogger(__file__)

# All model configurations below will be added to the neuron_model_config fixture
MODEL_CONFIGURATIONS = {
    "gpt2": {
        "model_id": "gpt2",
        "export_kwargs": {
            "batch_size": 4,
            "sequence_length": 1024,
            "num_cores": 2,
            "auto_cast_type": "fp16",
        },
    },
    "llama": {
        "model_id": "unsloth/Llama-3.2-1B-Instruct",
        "export_kwargs": {
@ -46,15 +37,6 @@ MODEL_CONFIGURATIONS = {
            "auto_cast_type": "fp16",
        },
    },
    "mistral": {
        "model_id": "optimum/mistral-1.1b-testing",
        "export_kwargs": {
            "batch_size": 4,
            "sequence_length": 4096,
            "num_cores": 2,
            "auto_cast_type": "bf16",
        },
    },
    "qwen2": {
        "model_id": "Qwen/Qwen2.5-0.5B",
        "export_kwargs": {
@ -17,7 +17,7 @@
"id": "",
"model": "google/gemma-3-4b-it",
"object": "chat.completion",
"system_fingerprint": "3.3.1-dev0-native",
"system_fingerprint": "3.3.2-dev0-native",
"usage": {
"completion_tokens": 42,
"prompt_tokens": 277,

@ -17,7 +17,7 @@
"id": "",
"model": "google/gemma-3-4b-it",
"object": "chat.completion",
"system_fingerprint": "3.3.1-dev0-native",
"system_fingerprint": "3.3.2-dev0-native",
"usage": {
"completion_tokens": 62,
"prompt_tokens": 277,

@ -17,7 +17,7 @@
"id": "",
"model": "google/gemma-3-4b-it",
"object": "chat.completion",
"system_fingerprint": "3.3.1-dev0-native",
"system_fingerprint": "3.3.2-dev0-native",
"usage": {
"completion_tokens": 67,
"prompt_tokens": 277,

@ -17,7 +17,7 @@
"id": "",
"model": "google/gemma-3-4b-it",
"object": "chat.completion",
"system_fingerprint": "3.3.1-dev0-native",
"system_fingerprint": "3.3.2-dev0-native",
"usage": {
"completion_tokens": 72,
"prompt_tokens": 275,

@ -17,7 +17,7 @@
"id": "",
"model": "google/gemma-3-4b-it",
"object": "chat.completion",
"system_fingerprint": "3.3.1-dev0-native",
"system_fingerprint": "3.3.2-dev0-native",
"usage": {
"completion_tokens": 80,
"prompt_tokens": 279,

@ -14,7 +14,7 @@
"id": "",
"model": "google/gemma-3-4b-it",
"object": "chat.completion",
"system_fingerprint": "3.3.1-dev0-native",
"system_fingerprint": "3.3.2-dev0-native",
"usage": {
"completion_tokens": 35,
"prompt_tokens": 32,

@ -14,7 +14,7 @@
"id": "",
"model": "google/gemma-3-4b-it",
"object": "chat.completion",
"system_fingerprint": "3.3.1-dev0-native",
"system_fingerprint": "3.3.2-dev0-native",
"usage": {
"completion_tokens": 44,
"prompt_tokens": 37,

@ -18,7 +18,7 @@
"id": "",
"model": "unsloth/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion",
"system_fingerprint": "3.3.1-dev0-native",
"system_fingerprint": "3.3.2-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 45,
@ -44,7 +44,7 @@
"id": "",
"model": "unsloth/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion",
"system_fingerprint": "3.3.1-dev0-native",
"system_fingerprint": "3.3.2-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 45,

@ -17,7 +17,7 @@
"id": "",
"model": "unsloth/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion",
"system_fingerprint": "3.3.1-dev0-native",
"system_fingerprint": "3.3.2-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 45,

@ -20,9 +20,7 @@ async def test_model_single_request(tgi_service):
    )
    assert response.details.generated_tokens == 17
    greedy_expectations = {
        "gpt2": "\n\nDeep learning is a new field of research that has been around for a while",
        "llama": " and How Does it Work?\nDeep learning is a subset of machine learning that uses artificial",
        "mistral": "\nWhat is Deep Learning?\nDeep Learning is a type of machine learning that",
        "llama": " and how does it work?\nDeep learning is a subset of machine learning that uses artificial",
        "qwen2": " - Part 1\n\nDeep Learning is a subset of Machine Learning that is based on",
        "granite": "\n\nDeep Learning is a subset of Machine Learning, which is a branch of Art",
    }
@ -79,9 +77,7 @@ async def test_model_multiple_requests(tgi_service, neuron_generate_load):

    assert len(responses) == 4
    expectations = {
        "gpt2": "Deep learning is a new field of research that has been around for a while",
        "llama": "Deep learning is a subset of machine learning that uses artificial",
        "mistral": "Deep Learning is a type of machine learning that",
        "qwen2": "Deep Learning is a subset of Machine Learning that is based on",
        "granite": "Deep Learning is a subset of Machine Learning, which is a branch of Art",
    }
@ -27,10 +27,6 @@ impl Env {
            docker_label: option_env!("DOCKER_LABEL").unwrap_or("N/A"),
        }
    }

    pub fn should_start_a_single_hpu_shard(&self) -> bool {
        self.hpu_env != "N/A" && std::env::var("ATTENTION").as_deref() != Ok("paged")
    }
}

impl fmt::Display for Env {

@ -1590,11 +1590,6 @@ fn spawn_shards(
) -> Result<(), LauncherError> {
    // Start shard processes
    for rank in 0..num_shard {
        if rank != 0 && env_runtime::Env::new().should_start_a_single_hpu_shard() {
            tracing::info!("Running on HPU, the launcher will not do any sharding as actual sharding is done in the server");
            break;
        }

        let model_id = args.model_id.clone();
        let revision = args.revision.clone();
        let uds_path = args.shard_uds_path.clone();
@ -1670,10 +1665,6 @@ fn spawn_shards(
                if shard_ready == num_shard {
                    break;
                }
                if env_runtime::Env::new().should_start_a_single_hpu_shard() {
                    tracing::info!("HPU detected, shard is ready");
                    break;
                }
            }
            Err(TryRecvError::Empty) => {
                sleep(Duration::from_millis(100));